From 0eb99442224bb12bedc7150139796e8de64a7e1f Mon Sep 17 00:00:00 2001 From: Gery Debongnie Date: Mon, 24 Mar 2014 09:50:46 +0100 Subject: [PATCH] [IMP] refactoring of readgroup method. It should have the same functionality, but is much simpler. This is necessary to implement eager groupby (orm.py) bzr revid: ged@openerp.com-20140324085046-zpfpcostivf8454q --- openerp/osv/orm.py | 230 ++++++++++++++++++++------------------------- 1 file changed, 100 insertions(+), 130 deletions(-) diff --git a/openerp/osv/orm.py b/openerp/osv/orm.py index 64fa06783ab..cc8e7aa066b 100644 --- a/openerp/osv/orm.py +++ b/openerp/osv/orm.py @@ -2222,7 +2222,7 @@ class BaseModel(object): self._name, order_part) return groupby_terms, orderby_terms - def read_group(self, cr, uid, domain, fields, groupby, offset=0, limit=None, context=None, orderby=False): + def read_group(self, cr, uid, domain, fields, groupby, offset=0, limit=None, context={}, orderby=False): """ Get the list of records in list view grouped by the given ``groupby`` fields @@ -2252,91 +2252,80 @@ class BaseModel(object): * if user tries to bypass access rules for read on the requested object """ - context = context or {} self.check_access_rights(cr, uid, 'read') - if not fields: - fields = self._columns.keys() - query = self._where_calc(cr, uid, domain, context=context) - self._apply_ir_rules(cr, uid, query, 'read', context=context) - - # Take care of adding join(s) if groupby is an '_inherits'ed field - groupby_list = groupby - qualified_groupby_field = groupby - if groupby: - if isinstance(groupby, list): - groupby = groupby[0] - splitted_groupby = groupby.split(':') - if len(splitted_groupby) == 2: - groupby = splitted_groupby[0] - groupby_function = splitted_groupby[1] - else: - groupby_function = False - qualified_groupby_field = self._inherits_join_calc(groupby, query) - - if groupby: - assert not groupby or groupby in fields, "Fields in 'groupby' must appear in the list of fields to read (perhaps it's missing in the list view?)" - groupby_def = self._columns.get(groupby) or (self._inherit_fields.get(groupby) and self._inherit_fields.get(groupby)[2]) - assert groupby_def and groupby_def._classic_write, "Fields in 'groupby' must be regular database-persisted fields (no function or related fields), or function fields with store=True" - - # TODO it seems fields_get can be replaced by _all_columns (no need for translation) + # Step 0 : preparing some useful variables + fields = fields or self._columns.keys() fget = self.fields_get(cr, uid, fields) - group_by_params = {} - select_terms = [] - groupby_type = None - if groupby: - if fget.get(groupby): - groupby_type = fget[groupby]['type'] - if groupby_type in ('date', 'datetime'): - if groupby_function: - interval = groupby_function - else: - interval = 'month' + if isinstance(groupby, basestring): + groupby = [groupby] + split_groupby = groupby[0].split(':') if groupby else None + first_groupby = split_groupby[0] if split_groupby else None + groupby_function = split_groupby[1] if split_groupby and len(split_groupby) == 2 else None + interval = groupby_function if groupby_function else 'month' + groupby_type = fget[first_groupby]['type'] if first_groupby else None + if groupby_type in ('date', 'datetime'): + dt_format = DEFAULT_SERVER_DATETIME_FORMAT if groupby_type == 'datetime' else DEFAULT_SERVER_DATE_FORMAT + tz_convert = groupby_type == 'datetime' and context.get('tz') in pytz.all_timezones + time_display_format = { + 'day': 'dd MMM YYYY', + 'week': "'W'w YYYY", + 'month': 'MMMM YYYY', + 'quarter': 'QQQ YYYY', + 'year': 'YYYY'}[interval] + time_interval = { + 'day': dateutil.relativedelta.relativedelta(months=3), + 'week': datetime.timedelta(days=7), + 'month': dateutil.relativedelta.relativedelta(months=1), + 'quarter': dateutil.relativedelta.relativedelta(months=3), + 'year': dateutil.relativedelta.relativedelta(years=1)}[interval] - if interval == 'day': - display_format = 'dd MMM YYYY' - elif interval == 'week': - display_format = "'W'w YYYY" - elif interval == 'month': - display_format = 'MMMM YYYY' - elif interval == 'quarter': - display_format = 'QQQ YYYY' - elif interval == 'year': - display_format = 'YYYY' + query = self._where_calc(cr, uid, domain, context=context) - if groupby_type == 'datetime' and context.get('tz') in pytz.all_timezones: - # Convert groupby result to user TZ to avoid confusion! - # PostgreSQL is compatible with all pytz timezone names, so we can use them - # directly for conversion, starting with timestamps stored in UTC. - timezone = context.get('tz', 'UTC') - qualified_groupby_field = "timezone('%s', timezone('UTC',%s))" % (timezone, qualified_groupby_field) - qualified_groupby_field = "date_trunc('%s', %s)" % (interval, qualified_groupby_field) - elif groupby_type == 'boolean': - qualified_groupby_field = "coalesce(%s,false)" % qualified_groupby_field - select_terms.append("%s as %s " % (qualified_groupby_field, groupby)) - else: + # Step 1: security stuff + # add relevant ir_rules to the where clause, perform some basic sanity checks + self._apply_ir_rules(cr, uid, query, 'read', context=context) + if first_groupby: + assert first_groupby in fields, "Fields in 'groupby' must appear in the list of fields to read (perhaps it's missing in the list view?)" + groupby_def = self._columns.get(first_groupby) or (self._inherit_fields.get(first_groupby) and self._inherit_fields.get(first_groupby)[2]) + assert groupby_def and groupby_def._classic_write, "Fields in 'groupby' must be regular database-persisted fields (no function or related fields), or function fields with store=True" + if not (first_groupby in fget): # Don't allow arbitrary values, as this would be a SQL injection vector! raise except_orm(_('Invalid group_by'), - _('Invalid group_by specification: "%s".\nA group_by specification must be a list of valid fields.')%(groupby,)) + _('Invalid group_by specification: "%s".\nA group_by specification must be a list of valid fields.')%(first_groupby,)) + # Step 2: preparing the query: + # compute aggregated fields, format groupbys, adjust timezone, ... aggregated_fields = [ f for f in fields if f not in ('id', 'sequence', groupby) if fget[f]['type'] in ('integer', 'float') if (f in self._all_columns and getattr(self._all_columns[f].column, '_classic_write'))] - for f in aggregated_fields: - group_operator = fget[f].get('group_operator', 'sum') - qualified_field = self._inherits_join_calc(f, query) - select_terms.append("%s(%s) AS %s" % (group_operator, qualified_field, f)) - order = orderby or groupby or '' - groupby_terms, orderby_terms = self._read_group_prepare(order, aggregated_fields, groupby, qualified_groupby_field, query, groupby_type) + field_formatter = lambda f: (fget[f].get('group_operator', 'sum'), self._inherits_join_calc(f, query), f) + select_terms = ["%s(%s) AS %s" % field_formatter(f) for f in aggregated_fields] + + qualified_groupby_field = self._inherits_join_calc(first_groupby, query) if first_groupby else None + + if groupby_type in ('date', 'datetime'): + if tz_convert: + # Convert groupby result to user TZ to avoid confusion! + # PostgreSQL is compatible with all pytz timezone names, so we can use them + # directly for conversion, starting with timestamps stored in UTC. + timezone = context.get('tz', 'UTC') + qualified_groupby_field = "timezone('%s', timezone('UTC',%s))" % (timezone, qualified_groupby_field) + qualified_groupby_field = "date_trunc('%s', %s)" % (interval, qualified_groupby_field) + + if groupby_type == 'boolean': + qualified_groupby_field = "coalesce(%s,false)" % qualified_groupby_field + + if first_groupby: + select_terms.append("%s as %s " % (qualified_groupby_field, first_groupby)) + + order = orderby or first_groupby or '' + groupby_terms, orderby_terms = self._read_group_prepare(order, aggregated_fields, first_groupby, qualified_groupby_field, query, groupby_type) from_clause, where_clause, where_clause_params = query.get_sql() - if len(groupby_list) < 2 and context.get('group_by_no_leaf'): - count_field = '_' - else: - count_field = groupby prefix_terms = lambda prefix, terms: (prefix + " " + ",".join(terms)) if terms else '' prefix_term = lambda prefix, term: ('%s %s' % (prefix, term)) if term else '' @@ -2352,7 +2341,7 @@ class BaseModel(object): %(offset)s """ % { 'table': self._table, - 'count_field': count_field, + 'count_field': '_' if (len(groupby) < 2 and context.get('group_by_no_leaf')) else first_groupby, 'extra_fields': prefix_terms(',', select_terms), 'from': from_clause, 'where': prefix_term('WHERE', where_clause), @@ -2362,73 +2351,54 @@ class BaseModel(object): 'offset': prefix_term('OFFSET', int(offset) if limit else None), } cr.execute(query, where_clause_params) - alldata = {} - fetched_data = cr.dictfetchall() - data_ids = [] - for r in fetched_data: - for fld, val in r.items(): - if val is None: r[fld] = False - alldata[r['id']] = r - data_ids.append(r['id']) - del r['id'] + if not first_groupby: + return {r.pop('id'): r for r in cr.dictfetchall() } - if groupby: - data = self.read(cr, uid, data_ids, [groupby], context=context) - # restore order of the search as read() uses the default _order (this is only for groups, so the footprint of data should be small): - data_dict = dict((d['id'], d[groupby] ) for d in data) - result = [{'id': i, groupby: data_dict[i]} for i in data_ids] - else: - result = [{'id': i} for i in data_ids] + none_to_false = lambda record: {k: (False if v is None else v) for k,v in record.iteritems() } + fetched_data = map(none_to_false, cr.dictfetchall()) - for d in result: - if groupby: - d['__domain'] = [(groupby, '=', alldata[d['id']][groupby] or False)] + domain - if not isinstance(groupby_list, (str, unicode)): - if groupby or not context.get('group_by_no_leaf', False): - d['__context'] = {'group_by': groupby_list[1:]} - if groupby and groupby in fget: - groupby_type = fget[groupby]['type'] - if d[groupby] and groupby_type in ('date', 'datetime'): - groupby_datetime = alldata[d['id']][groupby] - if isinstance(groupby_datetime, basestring): - _default = datetime.datetime(1970, 1, 1) # force starts of month - groupby_datetime = dateutil.parser.parse(groupby_datetime, default=_default) - tz_convert = groupby_type == 'datetime' and context.get('tz') in pytz.all_timezones - if tz_convert: - groupby_datetime = pytz.timezone(context['tz']).localize(groupby_datetime) - d[groupby] = babel.dates.format_date( - groupby_datetime, format=display_format, locale=context.get('lang', 'en_US')) - domain_dt_begin = groupby_datetime - if interval == 'quarter': - domain_dt_end = groupby_datetime + dateutil.relativedelta.relativedelta(months=3) - elif interval == 'month': - domain_dt_end = groupby_datetime + dateutil.relativedelta.relativedelta(months=1) - elif interval == 'week': - domain_dt_end = groupby_datetime + datetime.timedelta(days=7) - elif interval == 'day': - domain_dt_end = groupby_datetime + datetime.timedelta(days=1) - else: - domain_dt_end = groupby_datetime + dateutil.relativedelta.relativedelta(years=1) - if tz_convert: - # the time boundaries were all computed in the apparent TZ of the user, - # so we need to convert them to UTC to have proper server-side values. - domain_dt_begin = domain_dt_begin.astimezone(pytz.utc) - domain_dt_end = domain_dt_end.astimezone(pytz.utc) - dt_format = DEFAULT_SERVER_DATETIME_FORMAT if groupby_type == 'datetime' else DEFAULT_SERVER_DATE_FORMAT - d['__domain'] = [(groupby, '>=', domain_dt_begin.strftime(dt_format)), - (groupby, '<', domain_dt_end.strftime(dt_format))] + domain - del alldata[d['id']][groupby] - d.update(alldata[d['id']]) - del d['id'] + data_ids = [r['id'] for r in fetched_data] + data_dict = {d['id']: d[first_groupby] for d in self.read(cr, uid, data_ids, [first_groupby], context=context)} + sorted_data = [{first_groupby: data_dict[id]} for id in data_ids] - if groupby and groupby in self._group_by_full: - result = self._read_group_fill_results(cr, uid, domain, groupby, groupby_list, + def format_result (fromquery, fromread): + result = { + '__domain': [(first_groupby, '=', fromquery[first_groupby] or False)] + domain, + '__context': {'group_by': groupby[1:]} + } + result.update(fromquery) + result.update(fromread) + + if groupby_type in ('date', 'datetime') and fromquery[first_groupby]: + groupby_datetime = fromquery[first_groupby] + if isinstance(groupby_datetime, basestring): + groupby_datetime = datetime.datetime.strptime(groupby_datetime, dt_format) + if tz_convert: + groupby_datetime = pytz.timezone(context['tz']).localize(groupby_datetime) + result[first_groupby] = babel.dates.format_date( + groupby_datetime, format=time_display_format, locale=context.get('lang', 'en_US')) + domain_dt_begin = groupby_datetime + domain_dt_end = groupby_datetime + time_interval + if tz_convert: + # the time boundaries were all computed in the apparent TZ of the user, + # so we need to convert them to UTC to have proper server-side values. + domain_dt_begin = domain_dt_begin.astimezone(pytz.utc) + domain_dt_end = domain_dt_end.astimezone(pytz.utc) + result['__domain'] = [(first_groupby, '>=', domain_dt_begin.strftime(dt_format)), + (first_groupby, '<', domain_dt_end.strftime(dt_format))] + domain + del result['id'] + return result + + result = map(format_result, fetched_data, sorted_data) + + if first_groupby in self._group_by_full: + result = self._read_group_fill_results(cr, uid, domain, first_groupby, groupby, aggregated_fields, result, read_group_order=order, context=context) - return result + def _inherits_join_add(self, current_model, parent_model_name, query): """ Add missing table SELECT and JOIN clause to ``query`` for reaching the parent table (no duplicates)