diff --git a/bin/addons/base/ir/ir_rule.py b/bin/addons/base/ir/ir_rule.py index ba508b48920..23981458606 100644 --- a/bin/addons/base/ir/ir_rule.py +++ b/bin/addons/base/ir/ir_rule.py @@ -21,11 +21,14 @@ from osv import fields,osv import time +from operator import itemgetter +from functools import partial import tools from tools.safe_eval import safe_eval as eval class ir_rule(osv.osv): _name = 'ir.rule' + _MODES = ['read', 'write', 'create', 'unlink'] def _domain_force_get(self, cr, uid, ids, field_name, arg, context={}): res = {} @@ -45,13 +48,7 @@ class ir_rule(osv.osv): return res def _check_model_obj(self, cr, uid, ids, context={}): - model_obj = self.pool.get('ir.model') - for rule in self.browse(cr, uid, ids, context): - model = model_obj.browse(cr, uid, rule.model_id.id, context).model - obj = self.pool.get(model) - if isinstance(obj, osv.osv_memory): - return False - return True + return not any(isinstance(self.pool.get(rule.model_id.model), osv.osv_memory) for rule in self.browse(cr, uid, ids, context)) _columns = { 'name': fields.char('Name', size=128, select=1), @@ -90,6 +87,8 @@ class ir_rule(osv.osv): @tools.cache() def _compute_domain(self, cr, uid, model_name, mode="read"): + if mode not in self._MODES: + raise ValueError('Invalid mode: %r' % (mode,)) group_rule = {} global_rules = [] @@ -115,6 +114,24 @@ class ir_rule(osv.osv): dom += self.domain_create(cr, uid, value) return dom + def clear_cache(self, cr, uid): + cr.execute("""SELECT DISTINCT m.model + FROM ir_rule r + JOIN ir_model m + ON r.model_id = m.id + WHERE r.global + OR EXISTS (SELECT 1 + FROM rule_group_rel g_rel + JOIN res_groups_users_rel u_rel + ON g_rel.group_id = u_rel.gid + WHERE g_rel.rule_group_id = r.id + AND u_rel.uid = %s) + """, (uid,)) + models = map(itemgetter(0), cr.fetchall()) + clear = partial(self._compute_domain.clear_cache, cr.dbname, uid) + [clear(model, mode) for model in models for mode in self._MODES] + + def domain_get(self, cr, uid, model_name, mode='read', context={}): dom = self._compute_domain(cr, uid, model_name, mode=mode) if dom: diff --git a/bin/addons/base/res/res_company.py b/bin/addons/base/res/res_company.py index c076eef612a..4b6ec1dc00a 100644 --- a/bin/addons/base/res/res_company.py +++ b/bin/addons/base/res/res_company.py @@ -133,12 +133,12 @@ class res_company(osv.osv): ids = self._get_company_children(cr, uid, company) return ids + @tools.cache() def _get_company_children(self, cr, uid=None, company=None): if not company: return [] ids = self.search(cr, uid, [('parent_id','child_of',[company])]) return ids - _get_company_children = tools.cache()(_get_company_children) def _get_partner_hierarchy(self, cr, uid, company_id, context={}): if company_id: @@ -168,7 +168,6 @@ class res_company(osv.osv): def write(self, cr, *args, **argv): self.cache_restart(cr) - # Restart the cache on the company_get method return super(res_company, self).write(cr, *args, **argv) def _get_euro(self, cr, uid, context={}): diff --git a/bin/addons/base/res/res_user.py b/bin/addons/base/res/res_user.py index 9040f90be95..c71982527d8 100644 --- a/bin/addons/base/res/res_user.py +++ b/bin/addons/base/res/res_user.py @@ -20,9 +20,9 @@ ############################################################################## from osv import fields,osv -from osv.orm import except_orm, browse_record +from osv.orm import browse_record import tools -import operator +from functools import partial import pytz import pooler from tools.translate import _ @@ -55,7 +55,6 @@ class groups(osv.osv): raise osv.except_osv(_('Error'), _('The name of the group can not start with "-"')) res = super(groups, self).write(cr, uid, ids, vals, context=context) - # Restart the cache on the company_get method self.pool.get('ir.model.access').call_cache_clearing_methods(cr) return res @@ -332,11 +331,14 @@ class users(osv.osv): 'groups_id': _get_group, 'address_id': False, } + + @tools.cache() def company_get(self, cr, uid, uid2, context=None): return self._get_company(cr, uid, context=context, uid2=uid2) - company_get = tools.cache()(company_get) def write(self, cr, uid, ids, values, context=None): + if not hasattr(ids, '__iter__'): + ids = [ids] if ids == [uid]: for key in values.keys(): if not (key in ('view', 'password','signature','action_id', 'company_id') or key.startswith('context_')): @@ -344,9 +346,13 @@ class users(osv.osv): else: uid = 1 res = super(users, self).write(cr, uid, ids, values, context=context) + + # clear caches linked to the users self.company_get.clear_cache(cr.dbname) - # Restart the cache on the company_get method self.pool.get('ir.model.access').call_cache_clearing_methods(cr) + clear = partial(self.pool.get('ir.rule').clear_cache, cr) + map(clear, ids) + return res def unlink(self, cr, uid, ids, context=None):