[ADD] big bit on new import: pretty much everything but o2m

bzr revid: xmo@openerp.com-20120919114047-w4paoim95oxr91zb
This commit is contained in:
Xavier Morel 2012-09-19 13:40:47 +02:00
parent b5c89ad4f4
commit 9805c665c8
9 changed files with 1507 additions and 18 deletions

View File

@ -40,6 +40,7 @@ import wizard
import ir_config_parameter
import osv_memory_autovacuum
import ir_mail_server
import ir_fields
# vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:

View File

@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
import functools
import operator
import warnings
from openerp.osv import orm, fields
from openerp.tools.translate import _
class ConversionNotFound(ValueError): pass
class ir_fields_converter(orm.Model):
_name = 'ir.fields.converter'
def to_field(self, cr, uid, model, column, fromtype=str, context=None):
""" Fetches a converter for the provided column object, from the
specified type.
A converter is simply a callable taking a value of type ``fromtype``
(or a composite of ``fromtype``, e.g. list or dict) and returning a
value acceptable for a write() on the column ``column``.
By default, tries to get a method on itself with a name matching the
pattern ``_$fromtype_$column._type`` and returns it.
:param cr: openerp cursor
:param uid: ID of user calling the converter
:param column: column object to generate a value for
:type column: :class:`fields._column`
:param type fromtype: type to convert to something fitting for ``column``
:param context: openerp request context
:return: a function (fromtype -> column.write_type), if a converter is found
:rtype: Callable | None
"""
# FIXME: return None
converter = getattr(
self, '_%s_to_%s' % (fromtype.__name__, column._type))
if not converter: return None
return functools.partial(
converter, cr, uid, model, column, context=context)
def _str_to_boolean(self, cr, uid, model, column, value, context=None):
return value.lower() not in ('', '0', 'false', 'off')
def _str_to_integer(self, cr, uid, model, column, value, context=None):
if not value: return False
return int(value)
def _str_to_float(self, cr, uid, model, column, value, context=None):
if not value: return False
return float(value)
def _str_to_char(self, cr, uid, model, column, value, context=None):
return value or False
def _str_to_text(self, cr, uid, model, column, value, context=None):
return value or False
def _get_translations(self, cr, uid, types, src, context):
Translations = self.pool['ir.translation']
tnx_ids = Translations.search(
cr, uid, [('type', 'in', types), ('src', '=', src)], context=context)
tnx = Translations.read(cr, uid, tnx_ids, ['value'], context=context)
return map(operator.itemgetter('value'), tnx)
def _str_to_selection(self, cr, uid, model, column, value, context=None):
selection = column.selection
if not isinstance(selection, (tuple, list)):
# FIXME: Don't pass context to avoid translations?
# Or just copy context & remove lang?
selection = selection(model, cr, uid)
for item, label in selection:
labels = self._get_translations(
cr, uid, ('selection', 'model'), label, context=context)
labels.append(label)
if value == unicode(item) or value in labels:
return item
raise ValueError(
_(u"Value '%s' not found in selection field '%%(field)s'") % (
value))
def db_id_for(self, cr, uid, model, column, subfield, value, context=None):
""" Finds a database id for the reference ``value`` in the referencing
subfield ``subfield`` of the provided column of the provided model.
:param cr: OpenERP cursor
:param uid: OpenERP user id
:param model: model to which the column belongs
:param column: relational column for which references are provided
:param subfield: a relational subfield allowing building of refs to
existing records: ``None`` for a name_get/name_search,
``id`` for an external id and ``.id`` for a database
id
:param value: value of the reference to match to an actual record
:param context: OpenERP request context
:return: a pair of the matched database identifier (if any) and the
translated user-readable name for the field
:rtype: (ID|None, unicode)
"""
id = None
RelatedModel = self.pool[column._obj]
if subfield == '.id':
field_type = _(u"database id")
try: tentative_id = int(value)
except ValueError: tentative_id = value
if RelatedModel.search(cr, uid, [('id', '=', tentative_id)],
context=context):
id = tentative_id
elif subfield == 'id':
field_type = _(u"external id")
if '.' in value:
module, xid = value.split('.', 1)
else:
module, xid = '', value
ModelData = self.pool['ir.model.data']
try:
md_id = ModelData._get_id(cr, uid, module, xid)
model_data = ModelData.read(cr, uid, [md_id], ['res_id'],
context=context)
if model_data:
id = model_data[0]['res_id']
except ValueError: pass # leave id is None
elif subfield is None:
field_type = _(u"name")
ids = RelatedModel.name_search(
cr, uid, name=value, operator='=', context=context)
if ids:
if len(ids) > 1:
warnings.warn(
_(u"Found multiple matches for field '%%(field)s' (%d matches)")
% (len(ids)), orm.ImportWarning)
id, _name = ids[0]
else:
raise Exception(u"Unknown sub-field '%s'" % subfield)
return id, field_type
def _referencing_subfield(self, record):
""" Checks the record for the subfields allowing referencing (an
existing record in an other table), errors out if it finds potential
conflicts (multiple referencing subfields) or non-referencing subfields
returns the name of the correct subfield.
:param record:
:return: the record subfield to use for referencing
:rtype: str
"""
# Can import by name_get, external id or database id
allowed_fields = set([None, 'id', '.id'])
fieldset = set(record.iterkeys())
if fieldset - allowed_fields:
raise ValueError(
_(u"Can not create Many-To-One records indirectly, import the field separately"))
if len(fieldset) > 1:
raise ValueError(
_(u"Ambiguous specification for field '%(field)s', only provide one of name, external id or database id"))
# only one field left possible, unpack
[subfield] = fieldset
return subfield
def _str_to_many2one(self, cr, uid, model, column, values, context=None):
# Should only be one record, unpack
[record] = values
subfield = self._referencing_subfield(record)
reference = record[subfield]
id, subfield_type = self.db_id_for(
cr, uid, model, column, subfield, reference, context=context)
if id is None:
raise ValueError(
_(u"No matching record found for %(field_type)s '%(value)s' in field '%%(field)s'")
% {'field_type': subfield_type, 'value': reference})
return id
def _str_to_many2many(self, cr, uid, model, column, value, context=None):
[record] = value
subfield = self._referencing_subfield(record)
ids = []
for reference in record[subfield].split(','):
id, subfield_type = self.db_id_for(
cr, uid, model, column, subfield, reference, context=context)
if id is None:
raise ValueError(
_(u"No matching record found for %(field_type)s '%(value)s' in field '%%(field)s'")
% {'field_type': subfield_type, 'value': reference})
ids.append(id)
return [(6, 0, ids)]
def _str_to_one2many(self, cr, uid, model, column, value, context=None):
return value

View File

@ -1588,19 +1588,32 @@ def field_to_dict(model, cr, user, field, context=None):
class column_info(object):
"""Struct containing details about an osv column, either one local to
its model, or one inherited via _inherits.
""" Struct containing details about an osv column, either one local to
its model, or one inherited via _inherits.
:attr name: name of the column
:attr column: column instance, subclass of osv.fields._column
:attr parent_model: if the column is inherited, name of the model
that contains it, None for local columns.
:attr parent_column: the name of the column containing the m2o
relationship to the parent model that contains
this column, None for local columns.
:attr original_parent: if the column is inherited, name of the original
parent model that contains it i.e in case of multilevel
inheritence, None for local columns.
.. attribute:: name
name of the column
.. attribute:: column
column instance, subclass of :class:`_column`
.. attribute:: parent_model
if the column is inherited, name of the model that contains it,
``None`` for local columns.
.. attribute:: parent_column
the name of the column containing the m2o relationship to the
parent model that contains this column, ``None`` for local columns.
.. attribute:: original_parent
if the column is inherited, name of the original parent model that
contains it i.e in case of multilevel inheritance, ``None`` for
local columns.
"""
def __init__(self, name, column, parent_model=None, parent_column=None, original_parent=None):
self.name = name
@ -1609,5 +1622,10 @@ class column_info(object):
self.parent_column = parent_column
self.original_parent = original_parent
def __str__(self):
return '%s(%s, %s, %s, %s, %s)' % (
self.__name__, self.name, self.column,
self.parent_model, self.parent_column, self.original_parent)
# vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:

View File

@ -52,13 +52,17 @@ import re
import simplejson
import time
import types
import psycopg2
from lxml import etree
import warnings
import fields
import openerp
import openerp.netsvc as netsvc
import openerp.tools as tools
from openerp.tools.config import config
from openerp.tools.misc import CountingStream
from openerp.tools.safe_eval import safe_eval as eval
from openerp.tools.translate import _
from openerp import SUPERUSER_ID
@ -1242,7 +1246,7 @@ class BaseModel(object):
* The last item is currently unused, with no specific semantics
:param fields: list of fields to import
:param data: data to import
:param datas: data to import
:param mode: 'init' or 'update' for record creation
:param current_module: module name
:param noupdate: flag for record creation
@ -1438,6 +1442,199 @@ class BaseModel(object):
self._parent_store_compute(cr)
return position, 0, 0, 0
def load(self, cr, uid, fields, data, context=None):
"""
:param cr: cursor for the request
:param int uid: ID of the user attempting the data import
:param fields: list of fields to import, at the same index as the corresponding data
:type fields: list(str)
:param data: row-major matrix of data to import
:type data: list(list(str))
:param dict context:
:returns:
"""
cr.execute('SAVEPOINT model_load')
messages = []
fields = map(fix_import_export_id_paths, fields)
ModelData = self.pool['ir.model.data']
mode = 'init'
current_module = ''
noupdate = False
ids = []
for id, xid, record, info in self._convert_records(cr, uid,
self._extract_records(cr, uid, fields, data,
context=context, log=messages.append),
context=context, log=messages.append):
cr.execute('SAVEPOINT model_load_save')
try:
ids.append(ModelData._update(cr, uid, self._name,
current_module, record, mode=mode, xml_id=xid,
noupdate=noupdate, res_id=id, context=context))
cr.execute('RELEASE SAVEPOINT model_load_save')
except psycopg2.Error, e:
# Failed to write, log to messages, rollback savepoint (to
# avoid broken transaction) and keep going
cr.execute('ROLLBACK TO SAVEPOINT model_load_save')
messages.append(dict(info, type="error", message=str(e)))
if any(message['type'] == 'error' for message in messages):
cr.execute('ROLLBACK TO SAVEPOINT model_load')
return False, messages
return ids, messages
def _extract_records(self, cr, uid, fields_, data,
context=None, log=lambda a: None):
""" Generates record dicts from the data iterable.
The result is a generator of dicts mapping field names to raw
(unconverted, unvalidated) values.
For relational fields, if sub-fields were provided the value will be
a list of sub-records
The following sub-fields may be set on the record (by key):
* None is the name_get for the record (to use with name_create/name_search)
* "id" is the External ID for the record
* ".id" is the Database ID for the record
:param ImportLogger logger:
"""
columns = dict((k, v.column) for k, v in self._all_columns.iteritems())
# Fake columns to avoid special cases in extractor
columns[None] = fields.char('rec_name')
columns['id'] = fields.char('External ID')
columns['.id'] = fields.integer('Database ID')
# m2o fields can't be on multiple lines so exclude them from the
# is_relational field rows filter, but special-case it later on to
# be handled with relational fields (as it can have subfields)
is_relational = lambda field: columns[field]._type in ('one2many', 'many2many', 'many2one')
get_o2m_values = itemgetter_tuple(
[index for index, field in enumerate(fields_)
if columns[field[0]]._type == 'one2many'])
get_nono2m_values = itemgetter_tuple(
[index for index, field in enumerate(fields_)
if columns[field[0]]._type != 'one2many'])
# Checks if the provided row has any non-empty non-relational field
def only_o2m_values(row, f=get_nono2m_values, g=get_o2m_values):
return any(g(row)) and not any(f(row))
rows = CountingStream(data)
while True:
row = next(rows, None)
if row is None: return
record_row_index = rows.index
# copy non-relational fields to record dict
record = dict((field[0], value)
for field, value in itertools.izip(fields_, row)
if not is_relational(field[0]))
# Get all following rows which have relational values attached to
# the current record (no non-relational values)
# WARNING: replaces existing ``rows``
record_span, _rows = span(only_o2m_values, rows)
# stitch record row back on for relational fields
record_span = itertools.chain([row], record_span)
for relfield in set(
field[0] for field in fields_
if is_relational(field[0])):
column = columns[relfield]
# FIXME: how to not use _obj without relying on fields_get?
Model = self.pool[column._obj]
# copy stream to reuse for next relational field
fieldrows, record_span = itertools.tee(record_span)
# get only cells for this sub-field, should be strictly
# non-empty, field path [None] is for name_get column
indices, subfields = zip(*((index, field[1:] or [None])
for index, field in enumerate(fields_)
if field[0] == relfield))
# return all rows which have at least one value for the
# subfields of relfield
relfield_data = filter(any, map(itemgetter_tuple(indices), fieldrows))
record[relfield] = [subrecord
for subrecord, _subinfo in Model._extract_records(
cr, uid, subfields, relfield_data,
context=context, log=log)]
# Ensure full consumption of the span (and therefore advancement of
# ``rows``) even if there are no relational fields. Needs two as
# the code above stiched the row back on (so first call may only
# get the stiched row without advancing the underlying operator row
# itself)
next(record_span, None)
next(record_span, None)
# old rows consumption (by iterating the span) should be done here,
# at this point the old ``rows`` is 1 past `span` (either on the
# next record row or past ``StopIteration``, so wrap new ``rows``
# (``_rows``) in a counting stream indexed 1-before the old
# ``rows``
rows = CountingStream(_rows, rows.index - 1)
yield record, {'rows': {'from': record_row_index,'to': rows.index}}
def _convert_records(self, cr, uid, records,
context=None, log=lambda a: None):
""" Converts records from the source iterable (recursive dicts of
strings) into forms which can be written to the database (via
self.create or (ir.model.data)._update)
:param ImportLogger parent_logger:
:returns: a list of triplets of (id, xid, record)
:rtype: list((int|None, str|None, dict))
"""
Converter = self.pool['ir.fields.converter']
columns = dict((k, v.column) for k, v in self._all_columns.iteritems())
converters = dict(
(k, Converter.to_field(cr, uid, self, column, context=context))
for k, column in columns.iteritems())
stream = CountingStream(records)
for record, extras in stream:
dbid = False
xid = False
converted = {}
# name_get/name_create
if None in record: pass
# xid
if 'id' in record:
xid = record['id']
# dbid
if '.id' in record:
try:
dbid = int(record['.id'])
except ValueError:
# in case of overridden id column
dbid = record['.id']
if not self.search(cr, uid, [('id', '=', dbid)], context=context):
log(dict(extras,
type='error',
record=stream.index,
field='.id',
message=_(u"Unknown database identifier '%s'") % dbid))
dbid = False
for field, strvalue in record.iteritems():
if field in (None, 'id', '.id'): continue
message_base = dict(extras, record=stream.index, field=field)
with warnings.catch_warnings(record=True) as w:
try:
converted[field] = converters[field](strvalue)
for warning in w:
log(dict(message_base, type='warning',
message=unicode(warning.message) % message_base))
except ValueError, e:
log(dict(message_base,
type='error',
message=unicode(e) % message_base
))
yield dbid, xid, converted, dict(extras, record=stream.index)
def get_invalid_fields(self, cr, uid):
return list(self._invalids)
@ -5108,5 +5305,32 @@ class AbstractModel(BaseModel):
_auto = False # don't create any database backend for AbstractModels
_register = False # not visible in ORM registry, meant to be python-inherited only
def span(predicate, iterable):
""" Splits the iterable between the longest prefix of ``iterable`` whose
elements satisfy ``predicate`` and the rest.
If called with a list, equivalent to::
takewhile(predicate, lst), dropwhile(predicate, lst)
:param callable predicate:
:param iterable:
:rtype: (iterable, iterable)
"""
it1, it2 = itertools.tee(iterable)
return (itertools.takewhile(predicate, it1),
itertools.dropwhile(predicate, it2))
def itemgetter_tuple(items):
""" Fixes itemgetter inconsistency (useful in some cases) of not returning
a tuple if len(items) == 1: always returns an n-tuple where n = len(items)
"""
if len(items) == 0:
return lambda a: ()
if len(items) == 1:
return lambda gettable: (gettable[items[0]],)
return operator.itemgetter(*items)
class ImportWarning(Warning):
""" Used to send warnings upwards the stack during the import process
"""
pass
# vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
from . import test_export, test_import
from . import test_export, test_import, test_load
fast_suite = [
]
@ -8,6 +8,7 @@ fast_suite = [
checks = [
test_export,
test_import,
test_load,
]
# vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:

View File

@ -0,0 +1,987 @@
# -*- coding: utf-8 -*-
import openerp.modules.registry
import openerp
from openerp.tests import common
from openerp.tools.misc import mute_logger
def message(msg, type='error', from_=0, to_=0, record=0, field='value'):
return {
'type': type,
'rows': {'from': from_, 'to': to_},
'record': record,
'field': field,
'message': msg
}
def error(row, message, record=None, **kwargs):
""" Failed import of the record ``record`` at line ``row``, with the error
message ``message``
:param str message:
:param dict record:
"""
return (
-1, dict(record or {}, **kwargs),
"Line %d : %s" % (row, message),
'')
def values(seq, field='value'):
return [item[field] for item in seq]
class ImporterCase(common.TransactionCase):
model_name = False
def __init__(self, *args, **kwargs):
super(ImporterCase, self).__init__(*args, **kwargs)
self.model = None
def setUp(self):
super(ImporterCase, self).setUp()
self.model = self.registry(self.model_name)
self.registry('ir.model.data').clear_caches()
def import_(self, fields, rows, context=None):
return self.model.load(
self.cr, openerp.SUPERUSER_ID, fields, rows, context=context)
def read(self, fields=('value',), domain=(), context=None):
return self.model.read(
self.cr, openerp.SUPERUSER_ID,
self.model.search(self.cr, openerp.SUPERUSER_ID, domain, context=context),
fields=fields, context=context)
def browse(self, domain=(), context=None):
return self.model.browse(
self.cr, openerp.SUPERUSER_ID,
self.model.search(self.cr, openerp.SUPERUSER_ID, domain, context=context),
context=context)
def xid(self, record):
ModelData = self.registry('ir.model.data')
ids = ModelData.search(
self.cr, openerp.SUPERUSER_ID,
[('model', '=', record._table_name), ('res_id', '=', record.id)])
if ids:
d = ModelData.read(
self.cr, openerp.SUPERUSER_ID, ids, ['name', 'module'])[0]
if d['module']:
return '%s.%s' % (d['module'], d['name'])
return d['name']
name = dict(record.name_get())[record.id]
# fix dotted name_get results, otherwise xid lookups blow up
name = name.replace('.', '-')
ModelData.create(self.cr, openerp.SUPERUSER_ID, {
'name': name,
'model': record._table_name,
'res_id': record.id,
'module': '__test__'
})
return '__test__.' + name
class test_ids_stuff(ImporterCase):
model_name = 'export.integer'
def test_create_with_id(self):
ids, messages = self.import_(['.id', 'value'], [['42', '36']])
self.assertIs(ids, False)
self.assertEqual(messages, [{
'type': 'error',
'rows': {'from': 0, 'to': 0},
'record': 0,
'field': '.id',
'message': u"Unknown database identifier '42'",
}])
def test_create_with_xid(self):
ids, messages = self.import_(['id', 'value'], [['somexmlid', '42']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
self.assertEqual(
'somexmlid',
self.xid(self.browse()[0]))
def test_update_with_id(self):
id = self.model.create(self.cr, openerp.SUPERUSER_ID, {'value': 36})
self.assertEqual(
36,
self.model.browse(self.cr, openerp.SUPERUSER_ID, id).value)
ids, messages = self.import_(['.id', 'value'], [[str(id), '42']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
self.assertEqual(
[42], # updated value to imported
values(self.read()))
def test_update_with_xid(self):
self.import_(['id', 'value'], [['somexmlid', '36']])
self.assertEqual([36], values(self.read()))
self.import_(['id', 'value'], [['somexmlid', '1234567']])
self.assertEqual([1234567], values(self.read()))
class test_boolean_field(ImporterCase):
model_name = 'export.boolean'
def test_empty(self):
self.assertEqual(
self.import_(['value'], []),
([], []))
def test_exported(self):
ids, messages = self.import_(['value'], [['False'], ['True'], ])
self.assertEqual(len(ids), 2)
self.assertFalse(messages)
records = self.read()
self.assertEqual([
False,
True,
], values(records))
def test_falses(self):
ids, messages = self.import_(
['value'],
[[u'0'], [u'off'],
[u'false'], [u'FALSE'],
[u'OFF'], [u''],
])
self.assertEqual(len(ids), 6)
self.assertFalse(messages)
self.assertEqual([
False,
False,
False,
False,
False,
False,
],
values(self.read()))
def test_trues(self):
ids, messages = self.import_(
['value'],
[['no'],
['None'],
['nil'],
['()'],
['f'],
['#f'],
# Problem: OpenOffice (and probably excel) output localized booleans
['VRAI'],
])
self.assertEqual(len(ids), 7)
# FIXME: should warn for values which are not "true", "yes" or "1"
self.assertFalse(messages)
self.assertEqual(
[True] * 7,
values(self.read()))
class test_integer_field(ImporterCase):
model_name = 'export.integer'
def test_none(self):
self.assertEqual(
self.import_(['value'], []),
([], []))
def test_empty(self):
ids, messages = self.import_(['value'], [['']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
self.assertEqual(
[False],
values(self.read()))
def test_zero(self):
ids, messages = self.import_(['value'], [['0']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
ids, messages = self.import_(['value'], [['-0']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
self.assertEqual([False, False], values(self.read()))
def test_positives(self):
ids, messages = self.import_(['value'], [
['1'],
['42'],
[str(2**31-1)],
['12345678']
])
self.assertEqual(len(ids), 4)
self.assertFalse(messages)
self.assertEqual([
1, 42, 2**31-1, 12345678
], values(self.read()))
def test_negatives(self):
ids, messages = self.import_(['value'], [
['-1'],
['-42'],
[str(-(2**31 - 1))],
[str(-(2**31))],
['-12345678']
])
self.assertEqual(len(ids), 5)
self.assertFalse(messages)
self.assertEqual([
-1, -42, -(2**31 - 1), -(2**31), -12345678
], values(self.read()))
@mute_logger('openerp.sql_db')
def test_out_of_range(self):
ids, messages = self.import_(['value'], [[str(2**31)]])
self.assertIs(ids, False)
self.assertEqual(messages, [{
'type': 'error',
'rows': {'from': 0, 'to': 0},
'record': 0,
'message': "integer out of range\n"
}])
ids, messages = self.import_(['value'], [[str(-2**32)]])
self.assertIs(ids, False)
self.assertEqual(messages, [{
'type': 'error',
'rows': {'from': 0, 'to': 0},
'record': 0,
'message': "integer out of range\n"
}])
def test_nonsense(self):
ids, messages = self.import_(['value'], [['zorglub']])
self.assertIs(ids, False)
self.assertEqual(messages, [{
'type': 'error',
'rows': {'from': 0, 'to': 0},
'record': 0,
'field': 'value',
'message': u"invalid literal for int() with base 10: 'zorglub'",
}])
class test_float_field(ImporterCase):
model_name = 'export.float'
def test_none(self):
self.assertEqual(
self.import_(['value'], []),
([], []))
def test_empty(self):
ids, messages = self.import_(['value'], [['']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
self.assertEqual(
[False],
values(self.read()))
def test_zero(self):
ids, messages = self.import_(['value'], [['0']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
ids, messages = self.import_(['value'], [['-0']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
self.assertEqual([False, False], values(self.read()))
def test_positives(self):
ids, messages = self.import_(['value'], [
['1'],
['42'],
[str(2**31-1)],
['12345678'],
[str(2**33)],
['0.000001'],
])
self.assertEqual(len(ids), 6)
self.assertFalse(messages)
self.assertEqual([
1, 42, 2**31-1, 12345678, 2.0**33, .000001
], values(self.read()))
def test_negatives(self):
ids, messages = self.import_(['value'], [
['-1'],
['-42'],
[str(-2**31 + 1)],
[str(-2**31)],
['-12345678'],
[str(-2**33)],
['-0.000001'],
])
self.assertEqual(len(ids), 7)
self.assertFalse(messages)
self.assertEqual([
-1, -42, -(2**31 - 1), -(2**31), -12345678, -2.0**33, -.000001
], values(self.read()))
def test_nonsense(self):
ids, messages = self.import_(['value'], [['foobar']])
self.assertIs(ids, False)
self.assertEqual(messages, [{
'type': 'error',
'rows': {'from': 0, 'to': 0},
'record': 0,
'field': 'value',
'message': u"invalid literal for float(): foobar",
}])
class test_string_field(ImporterCase):
model_name = 'export.string.bounded'
def test_empty(self):
ids, messages = self.import_(['value'], [['']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
self.assertEqual([False], values(self.read()))
def test_imported(self):
ids, messages = self.import_(['value'], [
[u'foobar'],
[u'foobarbaz'],
[u'Með suð í eyrum við spilum endalaust'],
[u"People 'get' types. They use them all the time. Telling "
u"someone he can't pound a nail with a banana doesn't much "
u"surprise him."]
])
self.assertEqual(len(ids), 4)
self.assertFalse(messages)
self.assertEqual([
u"foobar",
u"foobarbaz",
u"Með suð í eyrum ",
u"People 'get' typ",
], values(self.read()))
class test_unbound_string_field(ImporterCase):
model_name = 'export.string'
def test_imported(self):
ids, messages = self.import_(['value'], [
[u'í dag viðrar vel til loftárása'],
# ackbar.jpg
[u"If they ask you about fun, you tell them fun is a filthy"
u" parasite"]
])
self.assertEqual(len(ids), 2)
self.assertFalse(messages)
self.assertEqual([
u"í dag viðrar vel til loftárása",
u"If they ask you about fun, you tell them fun is a filthy parasite"
], values(self.read()))
class test_text(ImporterCase):
model_name = 'export.text'
def test_empty(self):
ids, messages = self.import_(['value'], [['']])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
self.assertEqual([False], values(self.read()))
def test_imported(self):
s = (u"Breiðskífa er notað um útgefna hljómplötu sem inniheldur "
u"stúdíóupptökur frá einum flytjanda. Breiðskífur eru oftast "
u"milli 25-80 mínútur og er lengd þeirra oft miðuð við 33⅓ "
u"snúninga 12 tommu vínylplötur (sem geta verið allt að 30 mín "
u"hvor hlið).\n\nBreiðskífur eru stundum tvöfaldar og eru þær þá"
u" gefnar út á tveimur geisladiskum eða tveimur vínylplötum.")
ids, messages = self.import_(['value'], [[s]])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
self.assertEqual([s], values(self.read()))
class test_selection(ImporterCase):
model_name = 'export.selection'
translations_fr = [
("Qux", "toto"),
("Bar", "titi"),
("Foo", "tete"),
]
def test_imported(self):
ids, messages = self.import_(['value'], [
['Qux'],
['Bar'],
['Foo'],
['2'],
])
self.assertEqual(len(ids), 4)
self.assertFalse(messages)
self.assertEqual([3, 2, 1, 2], values(self.read()))
def test_imported_translated(self):
self.registry('res.lang').create(self.cr, openerp.SUPERUSER_ID, {
'name': u'Français',
'code': 'fr_FR',
'translatable': True,
'date_format': '%d.%m.%Y',
'decimal_point': ',',
'thousand_sep': ' ',
})
Translations = self.registry('ir.translation')
for source, value in self.translations_fr:
Translations.create(self.cr, openerp.SUPERUSER_ID, {
'name': 'export.selection,value',
'lang': 'fr_FR',
'type': 'selection',
'src': source,
'value': value
})
ids, messages = self.import_(['value'], [
['toto'],
['tete'],
['titi'],
], context={'lang': 'fr_FR'})
self.assertEqual(len(ids), 3)
self.assertFalse(messages)
self.assertEqual([3, 1, 2], values(self.read()))
ids, messages = self.import_(['value'], [['Foo']], context={'lang': 'fr_FR'})
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
def test_invalid(self):
ids, messages = self.import_(['value'], [['Baz']])
self.assertIs(ids, False)
self.assertEqual(messages, [{
'type': 'error',
'rows': {'from': 0, 'to': 0},
'record': 0,
'field': 'value',
'message': "Value 'Baz' not found in selection field 'value'",
}])
ids, messages = self.import_(['value'], [[42]])
self.assertIs(ids, False)
self.assertEqual(messages, [{
'type': 'error',
'rows': {'from': 0, 'to': 0},
'record': 0,
'field': 'value',
'message': "Value '42' not found in selection field 'value'",
}])
class test_selection_function(ImporterCase):
model_name = 'export.selection.function'
translations_fr = [
("Corge", "toto"),
("Grault", "titi"),
("Whee", "tete"),
("Moog", "tutu"),
]
def test_imported(self):
""" import uses fields_get, so translates import label (may or may not
be good news) *and* serializes the selection function to reverse it:
import does not actually know that the selection field uses a function
"""
# NOTE: conflict between a value and a label => ?
ids, messages = self.import_(['value'], [
['3'],
["Grault"],
])
self.assertEqual(len(ids), 2)
self.assertFalse(messages)
self.assertEqual(
['3', '1'],
values(self.read()))
def test_translated(self):
""" Expects output of selection function returns translated labels
"""
self.registry('res.lang').create(self.cr, openerp.SUPERUSER_ID, {
'name': u'Français',
'code': 'fr_FR',
'translatable': True,
'date_format': '%d.%m.%Y',
'decimal_point': ',',
'thousand_sep': ' ',
})
Translations = self.registry('ir.translation')
for source, value in self.translations_fr:
Translations.create(self.cr, openerp.SUPERUSER_ID, {
'name': 'export.selection,value',
'lang': 'fr_FR',
'type': 'selection',
'src': source,
'value': value
})
ids, messages = self.import_(['value'], [
['toto'],
['tete'],
], context={'lang': 'fr_FR'})
self.assertIs(ids, False)
self.assertEqual(messages, [{
'type': 'error',
'rows': {'from': 1, 'to': 1},
'record': 1,
'field': 'value',
'message': "Value 'tete' not found in selection field 'value'",
}])
ids, messages = self.import_(['value'], [['Wheee']], context={'lang': 'fr_FR'})
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
class test_m2o(ImporterCase):
model_name = 'export.many2one'
def test_by_name(self):
# create integer objects
integer_id1 = self.registry('export.integer').create(
self.cr, openerp.SUPERUSER_ID, {'value': 42})
integer_id2 = self.registry('export.integer').create(
self.cr, openerp.SUPERUSER_ID, {'value': 36})
# get its name
name1 = dict(self.registry('export.integer').name_get(
self.cr, openerp.SUPERUSER_ID,[integer_id1]))[integer_id1]
name2 = dict(self.registry('export.integer').name_get(
self.cr, openerp.SUPERUSER_ID,[integer_id2]))[integer_id2]
ids , messages = self.import_(['value'], [
# import by name_get
[name1],
[name1],
[name2],
])
self.assertFalse(messages)
self.assertEqual(len(ids), 3)
# correct ids assigned to corresponding records
self.assertEqual([
(integer_id1, name1),
(integer_id1, name1),
(integer_id2, name2),],
values(self.read()))
def test_by_xid(self):
ExportInteger = self.registry('export.integer')
integer_id = ExportInteger.create(
self.cr, openerp.SUPERUSER_ID, {'value': 42})
xid = self.xid(ExportInteger.browse(
self.cr, openerp.SUPERUSER_ID, [integer_id])[0])
ids, messages = self.import_(['value/id'], [[xid]])
self.assertFalse(messages)
self.assertEqual(len(ids), 1)
b = self.browse()
self.assertEqual(42, b[0].value.value)
def test_by_id(self):
integer_id = self.registry('export.integer').create(
self.cr, openerp.SUPERUSER_ID, {'value': 42})
ids, messages = self.import_(['value/.id'], [[integer_id]])
self.assertFalse(messages)
self.assertEqual(len(ids), 1)
b = self.browse()
self.assertEqual(42, b[0].value.value)
def test_by_names(self):
integer_id1 = self.registry('export.integer').create(
self.cr, openerp.SUPERUSER_ID, {'value': 42})
integer_id2 = self.registry('export.integer').create(
self.cr, openerp.SUPERUSER_ID, {'value': 42})
name1 = dict(self.registry('export.integer').name_get(
self.cr, openerp.SUPERUSER_ID,[integer_id1]))[integer_id1]
name2 = dict(self.registry('export.integer').name_get(
self.cr, openerp.SUPERUSER_ID,[integer_id2]))[integer_id2]
# names should be the same
self.assertEqual(name1, name2)
ids, messages = self.import_(['value'], [[name2]])
self.assertEqual(
messages,
[message(u"Found multiple matches for field 'value' (2 matches)",
type='warning')])
self.assertEqual(len(ids), 1)
self.assertEqual([
(integer_id1, name1)
], values(self.read()))
def test_fail_by_implicit_id(self):
""" Can't implicitly import records by id
"""
# create integer objects
integer_id1 = self.registry('export.integer').create(
self.cr, openerp.SUPERUSER_ID, {'value': 42})
integer_id2 = self.registry('export.integer').create(
self.cr, openerp.SUPERUSER_ID, {'value': 36})
# Because name_search all the things. Fallback schmallback
ids, messages = self.import_(['value'], [
# import by id, without specifying it
[integer_id1],
[integer_id2],
[integer_id1],
])
self.assertEqual(messages, [
message(u"No matching record found for name '%s' in field 'value'" % id,
from_=index, to_=index, record=index)
for index, id in enumerate([integer_id1, integer_id2, integer_id1])])
self.assertIs(ids, False)
def test_sub_field(self):
""" Does not implicitly create the record, does not warn that you can't
import m2o subfields (at all)...
"""
ids, messages = self.import_(['value/value'], [['42']])
self.assertEqual(messages, [
message(u"Can not create Many-To-One records indirectly, import "
u"the field separately")])
self.assertIs(ids, False)
def test_fail_noids(self):
ids, messages = self.import_(['value'], [['nameisnoexist:3']])
self.assertEqual(messages, [message(
u"No matching record found for name 'nameisnoexist:3' "
u"in field 'value'")])
self.assertIs(ids, False)
ids, messages = self.import_(['value/id'], [['noxidhere']])
self.assertEqual(messages, [message(
u"No matching record found for external id 'noxidhere' "
u"in field 'value'")])
self.assertIs(ids, False)
ids, messages = self.import_(['value/.id'], [['66']])
self.assertEqual(messages, [message(
u"No matching record found for database id '66' "
u"in field 'value'")])
self.assertIs(ids, False)
def test_fail_multiple(self):
ids, messages = self.import_(
['value', 'value/id'],
[['somename', 'somexid']])
self.assertEqual(messages, [message(
u"Ambiguous specification for field 'value', only provide one of "
u"name, external id or database id")])
self.assertIs(ids, False)
class test_m2m(ImporterCase):
model_name = 'export.many2many'
# apparently, one and only thing which works is a
# csv_internal_sep-separated list of ids, xids, or names (depending if
# m2m/.id, m2m/id or m2m[/anythingelse]
def test_ids(self):
id1 = self.registry('export.many2many.other').create(
self.cr, openerp.SUPERUSER_ID, {'value': 3, 'str': 'record0'})
id2 = self.registry('export.many2many.other').create(
self.cr, openerp.SUPERUSER_ID, {'value': 44, 'str': 'record1'})
id3 = self.registry('export.many2many.other').create(
self.cr, openerp.SUPERUSER_ID, {'value': 84, 'str': 'record2'})
id4 = self.registry('export.many2many.other').create(
self.cr, openerp.SUPERUSER_ID, {'value': 9, 'str': 'record3'})
id5 = self.registry('export.many2many.other').create(
self.cr, openerp.SUPERUSER_ID, {'value': 99, 'str': 'record4'})
ids, messages = self.import_(['value/.id'], [
['%d,%d' % (id1, id2)],
['%d,%d,%d' % (id1, id3, id4)],
['%d,%d,%d' % (id1, id2, id3)],
['%d' % id5]
])
self.assertFalse(messages)
self.assertEqual(len(ids), 4)
ids = lambda records: [record.id for record in records]
b = self.browse()
self.assertEqual(ids(b[0].value), [id1, id2])
self.assertEqual(values(b[0].value), [3, 44])
self.assertEqual(ids(b[2].value), [id1, id2, id3])
self.assertEqual(values(b[2].value), [3, 44, 84])
def test_noids(self):
ids, messages = self.import_(['value/.id'], [['42']])
self.assertEqual(messages, [message(
u"No matching record found for database id '42' in field "
u"'value'")])
self.assertIs(ids, False)
def test_xids(self):
M2O_o = self.registry('export.many2many.other')
id1 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 3, 'str': 'record0'})
id2 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 44, 'str': 'record1'})
id3 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 84, 'str': 'record2'})
id4 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 9, 'str': 'record3'})
records = M2O_o.browse(self.cr, openerp.SUPERUSER_ID, [id1, id2, id3, id4])
ids, messages = self.import_(['value/id'], [
['%s,%s' % (self.xid(records[0]), self.xid(records[1]))],
['%s' % self.xid(records[3])],
['%s,%s' % (self.xid(records[2]), self.xid(records[1]))],
])
self.assertFalse(messages)
self.assertEqual(len(ids), 3)
b = self.browse()
self.assertEqual(values(b[0].value), [3, 44])
self.assertEqual(values(b[2].value), [44, 84])
def test_noxids(self):
ids, messages = self.import_(['value/id'], [['noxidforthat']])
self.assertEqual(messages, [message(
u"No matching record found for external id 'noxidforthat' "
u"in field 'value'")])
self.assertIs(ids, False)
def test_names(self):
M2O_o = self.registry('export.many2many.other')
id1 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 3, 'str': 'record0'})
id2 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 44, 'str': 'record1'})
id3 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 84, 'str': 'record2'})
id4 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 9, 'str': 'record3'})
records = M2O_o.browse(self.cr, openerp.SUPERUSER_ID, [id1, id2, id3, id4])
name = lambda record: dict(record.name_get())[record.id]
ids, messages = self.import_(['value'], [
['%s,%s' % (name(records[1]), name(records[2]))],
['%s,%s,%s' % (name(records[0]), name(records[1]), name(records[2]))],
['%s,%s' % (name(records[0]), name(records[3]))],
])
self.assertFalse(messages)
self.assertEqual(len(ids), 3)
b = self.browse()
self.assertEqual(values(b[1].value), [3, 44, 84])
self.assertEqual(values(b[2].value), [3, 9])
def test_nonames(self):
ids, messages = self.import_(['value'], [['wherethem2mhavenonames']])
self.assertEqual(messages, [message(
u"No matching record found for name 'wherethem2mhavenonames' in "
u"field 'value'")])
self.assertIs(ids, False)
def test_import_to_existing(self):
M2O_o = self.registry('export.many2many.other')
id1 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 3, 'str': 'record0'})
id2 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 44, 'str': 'record1'})
id3 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 84, 'str': 'record2'})
id4 = M2O_o.create(self.cr, openerp.SUPERUSER_ID, {'value': 9, 'str': 'record3'})
xid = 'myxid'
ids, messages = self.import_(['id', 'value/.id'], [[xid, '%d,%d' % (id1, id2)]])
self.assertFalse(messages)
self.assertEqual(len(ids), 1)
ids, messages = self.import_(['id', 'value/.id'], [[xid, '%d,%d' % (id3, id4)]])
self.assertFalse(messages)
self.assertEqual(len(ids), 1)
b = self.browse()
self.assertEqual(len(b), 1)
# TODO: replacement of existing m2m values is correct?
self.assertEqual(values(b[0].value), [84, 9])
class test_o2m(ImporterCase):
model_name = 'export.one2many'
def test_name_get(self):
# FIXME: bloody hell why can't this just name_create the record?
self.assertRaises(
IndexError,
self.import_,
['const', 'value'],
[['5', u'Java is a DSL for taking large XML files'
u' and converting them to stack traces']])
def test_single(self):
ids, messages = self.import_(['const', 'value/value'], [
['5', '63']
])
self.assertEqual(len(ids), 1)
self.assertFalse(messages)
(b,) = self.browse()
self.assertEqual(b.const, 5)
self.assertEqual(values(b.value), [63])
def test_multicore(self):
ids, messages = self.import_(['const', 'value/value'], [
['5', '63'],
['6', '64'],
])
self.assertEqual(len(ids), 2)
self.assertFalse(messages)
b1, b2 = self.browse()
self.assertEqual(b1.const, 5)
self.assertEqual(values(b1.value), [63])
self.assertEqual(b2.const, 6)
self.assertEqual(values(b2.value), [64])
def test_multisub(self):
ids, messages = self.import_(['const', 'value/value'], [
['5', '63'],
['', '64'],
['', '65'],
['', '66'],
])
self.assertEqual(len(ids), 4)
self.assertFalse(messages)
(b,) = self.browse()
self.assertEqual(values(b.value), [63, 64, 65, 66])
def test_multi_subfields(self):
ids, messages = self.import_(['value/str', 'const', 'value/value'], [
['this', '5', '63'],
['is', '', '64'],
['the', '', '65'],
['rhythm', '', '66'],
])
self.assertEqual(len(ids), 4)
self.assertFalse(messages)
(b,) = self.browse()
self.assertEqual(values(b.value), [63, 64, 65, 66])
self.assertEqual(
values(b.value, 'str'),
'this is the rhythm'.split())
def test_link_inline(self):
id1 = self.registry('export.one2many.child').create(self.cr, openerp.SUPERUSER_ID, {
'str': 'Bf', 'value': 109
})
id2 = self.registry('export.one2many.child').create(self.cr, openerp.SUPERUSER_ID, {
'str': 'Me', 'value': 262
})
try:
self.import_(['const', 'value/.id'], [
['42', '%d,%d' % (id1, id2)]
])
self.fail("Should have raised a valueerror")
except ValueError, e:
# should be Exception(Database ID doesn't exist: export.one2many.child : $id1,$id2)
self.assertIs(type(e), ValueError)
self.assertEqual(
e.args[0],
"invalid literal for int() with base 10: '%d,%d'" % (id1, id2))
def test_link(self):
id1 = self.registry('export.one2many.child').create(self.cr, openerp.SUPERUSER_ID, {
'str': 'Bf', 'value': 109
})
id2 = self.registry('export.one2many.child').create(self.cr, openerp.SUPERUSER_ID, {
'str': 'Me', 'value': 262
})
ids, messages = self.import_(['const', 'value/.id'], [
['42', str(id1)],
['', str(id2)],
])
self.assertEqual(len(ids), 2)
self.assertFalse(messages)
# No record values alongside id => o2m resolution skipped altogether,
# creates 2 records => remove/don't import columns sideshow columns,
# get completely different semantics
b, b1 = self.browse()
self.assertEqual(b.const, 42)
self.assertEqual(values(b.value), [])
self.assertEqual(b1.const, 4)
self.assertEqual(values(b1.value), [])
def test_link_2(self):
O2M_c = self.registry('export.one2many.child')
id1 = O2M_c.create(self.cr, openerp.SUPERUSER_ID, {
'str': 'Bf', 'value': 109
})
id2 = O2M_c.create(self.cr, openerp.SUPERUSER_ID, {
'str': 'Me', 'value': 262
})
ids, messages = self.import_(['const', 'value/.id', 'value/value'], [
['42', str(id1), '1'],
['', str(id2), '2'],
])
self.assertEqual(len(ids), 2)
self.assertFalse(messages)
(b,) = self.browse()
# if an id (db or xid) is provided, expectations that objects are
# *already* linked and emits UPDATE (1, id, {}).
# Noid => CREATE (0, ?, {})
# TODO: xid ignored aside from getting corresponding db id?
self.assertEqual(b.const, 42)
self.assertEqual(values(b.value), [])
# FIXME: updates somebody else's records?
self.assertEqual(
O2M_c.read(self.cr, openerp.SUPERUSER_ID, id1),
{'id': id1, 'str': 'Bf', 'value': 1, 'parent_id': False})
self.assertEqual(
O2M_c.read(self.cr, openerp.SUPERUSER_ID, id2),
{'id': id2, 'str': 'Me', 'value': 2, 'parent_id': False})
class test_o2m_multiple(ImporterCase):
model_name = 'export.one2many.multiple'
def test_multi_mixed(self):
ids, messages = self.import_(['const', 'child1/value', 'child2/value'], [
['5', '11', '21'],
['', '12', '22'],
['', '13', '23'],
['', '14', ''],
])
self.assertEqual(len(ids), 4)
self.assertFalse(messages)
# Oh yeah, that's the stuff
(b, b1, b2) = self.browse()
self.assertEqual(values(b.child1), [11])
self.assertEqual(values(b.child2), [21])
self.assertEqual(values(b1.child1), [12])
self.assertEqual(values(b1.child2), [22])
self.assertEqual(values(b2.child1), [13, 14])
self.assertEqual(values(b2.child2), [23])
def test_multi(self):
ids, messages = self.import_(['const', 'child1/value', 'child2/value'], [
['5', '11', '21'],
['', '12', ''],
['', '13', ''],
['', '14', ''],
['', '', '22'],
['', '', '23'],
])
self.assertEqual(len(ids), 6)
self.assertFalse(messages)
# What the actual fuck?
(b, b1) = self.browse()
self.assertEqual(values(b.child1), [11, 12, 13, 14])
self.assertEqual(values(b.child2), [21])
self.assertEqual(values(b1.child2), [22, 23])
def test_multi_fullsplit(self):
ids, messages = self.import_(['const', 'child1/value', 'child2/value'], [
['5', '11', ''],
['', '12', ''],
['', '13', ''],
['', '14', ''],
['', '', '21'],
['', '', '22'],
['', '', '23'],
])
self.assertEqual(len(ids), 7)
self.assertFalse(messages)
# oh wow
(b, b1) = self.browse()
self.assertEqual(b.const, 5)
self.assertEqual(values(b.child1), [11, 12, 13, 14])
self.assertEqual(b1.const, 36)
self.assertEqual(values(b1.child2), [21, 22, 23])
# function, related, reference: written to db as-is...
# => function uses @type for value coercion/conversion

View File

@ -2,12 +2,12 @@
# > PYTHONPATH=. python2 openerp/tests/test_misc.py
import unittest2
from ..tools import misc
class test_misc(unittest2.TestCase):
class append_content_to_html(unittest2.TestCase):
""" Test some of our generic utility functions """
def test_append_to_html(self):
from openerp.tools import append_content_to_html
test_samples = [
('<!DOCTYPE...><HTML encoding="blah">some <b>content</b></HtMl>', '--\nYours truly', True,
'<!DOCTYPE...><html encoding="blah">some <b>content</b>\n<pre>--\nYours truly</pre>\n</html>'),
@ -15,7 +15,37 @@ class test_misc(unittest2.TestCase):
'<html><body>some <b>content</b>\n\n\n<p>--</p>\n<p>Yours truly</p>\n\n\n</body></html>'),
]
for html, content, flag, expected in test_samples:
self.assertEqual(append_content_to_html(html,content,flag), expected, 'append_content_to_html is broken')
self.assertEqual(misc.append_content_to_html(html,content,flag), expected, 'append_content_to_html is broken')
class test_countingstream(unittest2.TestCase):
def test_empty_stream(self):
s = misc.CountingStream(iter([]))
self.assertEqual(s.index, -1)
self.assertIsNone(next(s, None))
self.assertEqual(s.index, 0)
def test_single(self):
s = misc.CountingStream(xrange(1))
self.assertEqual(s.index, -1)
self.assertEqual(next(s, None), 0)
self.assertIsNone(next(s, None))
self.assertEqual(s.index, 1)
def test_full(self):
s = misc.CountingStream(xrange(42))
for _ in s:
pass
self.assertEqual(s.index, 42)
def test_repeated(self):
""" Once the CountingStream has stopped iterating, the index should not
increase anymore (the internal state should not be allowed to change)
"""
s = misc.CountingStream(iter([]))
self.assertIsNone(next(s, None))
self.assertEqual(s.index, 0)
self.assertIsNone(next(s, None))
self.assertEqual(s.index, 0)
if __name__ == '__main__':
unittest2.main()
unittest2.main()

View File

@ -1220,4 +1220,38 @@ class mute_logger(object):
with self:
return func(*args, **kwargs)
return deco
_ph = object()
class CountingStream(object):
""" Stream wrapper counting the number of element it has yielded. Similar
role to ``enumerate``, but for use when the iteration process of the stream
isn't fully under caller control (the stream can be iterated from multiple
points including within a library)
``start`` allows overriding the starting index (the index before the first
item is returned).
On each iteration (call to :meth:`~.next`), increases its :attr:`~.index`
by one.
.. attribute:: index
``int``, index of the last yielded element in the stream. If the stream
has ended, will give an index 1-past the stream
"""
def __init__(self, stream, start=-1):
self.stream = iter(stream)
self.index = start
self.stopped = False
def __iter__(self):
return self
def next(self):
if self.stopped: raise StopIteration()
self.index += 1
val = next(self.stream, _ph)
if val is _ph:
self.stopped = True
raise StopIteration()
return val
# vim:expandtab:smartindent:tabstop=4:softtabstop=4:shiftwidth=4:

View File

@ -116,6 +116,7 @@ setuptools.setup(
extras_require = {
'SSL' : ['pyopenssl'],
},
tests_require = ['unittest2'],
**py2exe_options()
)