From 6c23c63afe714e384dc4c8caab300857486a0691 Mon Sep 17 00:00:00 2001 From: Olivier Dony Date: Tue, 30 Mar 2010 19:28:06 +0200 Subject: [PATCH] [IMP] refactoring of check_access_rule to avoid SQL injection and simplify code bzr revid: odo@openerp.com-20100330172806-p1zkvrmupw5zosai --- bin/osv/orm.py | 70 +++++++++++++++++++++++--------------------------- 1 file changed, 32 insertions(+), 38 deletions(-) diff --git a/bin/osv/orm.py b/bin/osv/orm.py index 697eb01f4e2..32efe32558b 100644 --- a/bin/osv/orm.py +++ b/bin/osv/orm.py @@ -2572,17 +2572,16 @@ class orm(orm_template): return '"%s"' % (f,) fields_pre2 = map(convert_field, fields_pre) order_by = self._parent_order or self._order - for i in range(0, len(ids), cr.IN_MAX): - sub_ids = ids[i:i+cr.IN_MAX] + for sub_ids in cr.split_for_in_conditions(ids): if d1: - cr.execute('SELECT %s FROM %s WHERE %s.id = ANY (%%s) AND %s ORDER BY %s' % \ + cr.execute('SELECT %s FROM %s WHERE %s.id IN %%s AND %s ORDER BY %s' % \ (','.join(fields_pre2 + [self._table + '.id']), ','.join(tables), self._table, ' and '.join(d1), order_by),[sub_ids,]+d2) - if not cr.rowcount == len({}.fromkeys(sub_ids)): + if cr.rowcount != len(sub_ids): raise except_orm(_('AccessError'), _('You try to bypass an access rule while reading (Document type: %s).') % self._description) else: - cr.execute('SELECT %s FROM \"%s\" WHERE id = ANY (%%s) ORDER BY %s' % + cr.execute('SELECT %s FROM \"%s\" WHERE id IN %%s ORDER BY %s' % (','.join(fields_pre2 + ['id']), self._table, order_by), (sub_ids,)) res.extend(cr.dictfetchall()) @@ -2734,29 +2733,24 @@ class orm(orm_template): if res and res[0]: raise except_orm('ConcurrencyException', _('Records were modified in the meanwhile')) - def check_access_rule(self, cr, uid, ids, mode, context=None): - d1, d2, tables = self.pool.get('ir.rule').domain_get(cr, uid, self._name, mode, context=context) - if d1: - d1 = ' and '+' and '.join(d1) - - for i in range(0, len(ids), cr.IN_MAX): - sub_ids = ids[i:i+cr.IN_MAX] - if d1: - cr.execute('SELECT '+self._table+'.id FROM '+','.join(tables)+' ' \ - 'WHERE '+self._table+'.id IN %s'+d1, (tuple(sub_ids),d2)) - if not cr.rowcount == len(sub_ids): + def check_access_rule(self, cr, uid, ids, operation, context=None): + """Verifies that the operation given by ``operation`` is allowed for the user + according to ir.rules. + @param ``operation``: one of ``'read'``, ``'write'``, ``'unlink'`` + @raise ``except_orm``: if current ir.rules do not permit this operation. + @return: ``None`` if the operation is allowed + """ + where_clause, where_params, tables = self.pool.get('ir.rule').domain_get(cr, uid, self._name, operation, context=context) + if where_clause: + where_clause = ' and ' + ' and '.join(where_clause) + for sub_ids in cr.split_for_in_conditions(ids): + cr.execute('SELECT ' + self._table + '.id FROM ' + ','.join(tables) + + ' WHERE ' + self._table + '.id IN %s' + where_clause, + [sub_ids] + where_params) + if cr.rowcount != len(sub_ids): raise except_orm(_('AccessError'), - _('You try to bypass an access rule to '+mode+ - ' (Document type: %s).') % self._name) - else: - cr.execute('SELECT id FROM "'+self._table+'" WHERE id IN %s', - (tuple(sub_ids),)) - if not cr.rowcount == len(sub_ids): - raise except_orm(_('AccessError'), - _('You try to ' +mode+ ' a record that doesn\'t exist (Document type: %s).') - % self._name) - #TODO: this is a SQL injection pattern again, need to refactor it - return ','.join(map(str,ids)) + _('Operation prohibited by access rules (Operation: %s, Document type: %s).') + % (operation, self._name)) def unlink(self, cr, uid, ids, context=None): if not ids: @@ -2787,12 +2781,13 @@ class orm(orm_template): # ids2 = [x[self._inherits[key]] for x in res] # self.pool.get(key).unlink(cr, uid, ids2) - ids_str = self.check_access_rule(cr, uid, ids, 'unlink', context=context) - cr.execute('delete from '+self._table+' ' \ - 'where id in ('+ids_str+')', ids) + self.check_access_rule(cr, uid, ids, 'unlink', context=context) + for sub_ids in cr.split_for_in_conditions(ids): + cr.execute('delete from ' + self._table + ' ' \ + 'where id in %s', sub_ids) for order, object, store_ids, fields in result_store: - if object<>self._name: + if object != self._name: obj = self.pool.get(object) cr.execute('select id from '+obj._table+' where id in ('+','.join(map(str, store_ids))+')') rids = map(lambda x: x[0], cr.fetchall()) @@ -2888,9 +2883,10 @@ class orm(orm_template): upd1.append(user) if len(upd0): - ids_str = self.check_access_rule(cr, user, ids, 'write', context=context) - cr.execute('update '+self._table+' set '+string.join(upd0, ',')+' ' \ - 'where id in ('+ids_str+')', upd1) + self.check_access_rule(cr, user, ids, 'write', context=context) + for sub_ids in cr.split_for_in_conditions(ids): + cr.execute('update ' + self._table + ' set ' + ','.join(upd0) + ' ' \ + 'where id in %s', upd1 + [sub_ids]) if totranslate: # TODO: optimize @@ -2921,11 +2917,9 @@ class orm(orm_template): for table in self._inherits: col = self._inherits[table] nids = [] - for i in range(0, len(ids), cr.IN_MAX): - sub_ids = ids[i:i+cr.IN_MAX] - ids_str = string.join(map(str, sub_ids), ',') + for sub_ids in cr.split_for_in_conditions(ids): cr.execute('select distinct "'+col+'" from "'+self._table+'" ' \ - 'where id in ('+ids_str+')', upd1) + 'where id in %s', (sub_ids,)) nids.extend([x[0] for x in cr.fetchall()]) v = {}