diff --git a/openerp/addons/base/ir/ir_filters.py b/openerp/addons/base/ir/ir_filters.py index f716a4f8d43..40ef7d41a86 100644 --- a/openerp/addons/base/ir/ir_filters.py +++ b/openerp/addons/base/ir/ir_filters.py @@ -19,6 +19,7 @@ # ############################################################################## +from openerp import exceptions from osv import osv, fields from tools.translate import _ @@ -49,6 +50,37 @@ class ir_filters(osv.osv): my_acts = self.read(cr, uid, act_ids, ['name', 'is_default', 'domain', 'context', 'user_id']) return my_acts + def _check_global_default(self, cr, uid, vals, matching_filters, context=None): + """ _check_global_default(cursor, UID, dict, list(dict), dict) -> None + + Checks if there is a global default for the model_id requested. + + If there is, and the default is different than the record being written + (-> we're not updating the current global default), raise an error + to avoid users unknowingly overwriting existing global defaults (they + have to explicitly remove the current default before setting a new one) + + This method should only be called if ``vals`` is trying to set + ``is_default`` + + :raises openerp.exceptions.Warning: if there is an existing default and + we're not updating it + """ + existing_default = self.search(cr, uid, [ + ('model_id', '=', vals['model_id']), + ('user_id', '=', False), + ('is_default', '=', True)], context=context) + + if not existing_default: return + if matching_filters and \ + (matching_filters[0]['id'] == existing_default[0]): + return + + raise exceptions.Warning( + _("There is already a global filter set as default for %(model)s, delete or change it before setting a new default") % { + 'model': vals['model_id'] + }) + def create_or_replace(self, cr, uid, vals, context=None): lower_name = vals['name'].lower() matching_filters = [f for f in self.get_filters(cr, uid, vals['model_id']) @@ -58,17 +90,23 @@ class ir_filters(osv.osv): # or f.user_id.id == vals.user_id if (f['user_id'] and f['user_id'][0]) == vals.get('user_id', False)] - if 'user_id' in vals and vals.get('is_default'): - act_ids = self.search(cr, uid, [('model_id', '=', vals['model_id']), - ('user_id', '=', vals['user_id'])], - context=context) - self.write(cr, uid, act_ids, {'is_default': False}, context=context) + if vals.get('is_default'): + if vals.get('user_id'): + act_ids = self.search(cr, uid, [('model_id', '=', vals['model_id']), + ('user_id', '=', vals['user_id'])], + context=context) + self.write(cr, uid, act_ids, {'is_default': False}, context=context) + else: + self._check_global_default( + cr, uid, vals, matching_filters, context=None) + # When a filter exists for the same (name, model, user) triple, we simply # replace its definition. if matching_filters: self.write(cr, uid, matching_filters[0]['id'], vals, context) return matching_filters[0]['id'] + return self.create(cr, uid, vals, context) _sql_constraints = [ diff --git a/openerp/tests/test_ir_filters.py b/openerp/tests/test_ir_filters.py index 98714d7f46b..eaf5eb93d3c 100644 --- a/openerp/tests/test_ir_filters.py +++ b/openerp/tests/test_ir_filters.py @@ -2,7 +2,8 @@ import functools import openerp -import common +from openerp import exceptions +from . import common class Fixtures(object): def __init__(self, *args): @@ -173,3 +174,67 @@ class TestOwnDefaults(common.TransactionCase): dict(name='a', user_id=self.USER, is_default=True, domain='[]', context='{}'), dict(name='b', user_id=self.USER, is_default=False, domain='[]', context='{}'), ]) + +class TestGlobalDefaults(common.TransactionCase): + USER_ID = 3 + + @fixtures( + ('ir.filters', dict(name='a', user_id=False, model_id='ir.filters')), + ('ir.filters', dict(name='b', user_id=False, model_id='ir.filters')), + ) + def test_new_filter_not_default(self): + """ + When creating a @is_default filter with existing non-default filters, + the new filter gets the flag + """ + Filters = self.registry('ir.filters') + Filters.create_or_replace(self.cr, self.USER_ID, { + 'name': 'c', + 'model_id': 'ir.filters', + 'user_id': False, + 'is_default': True, + }) + filters = Filters.get_filters(self.cr, self.USER_ID, 'ir.filters') + + self.assertItemsEqual(map(noid, filters), [ + dict(name='a', user_id=False, is_default=False, domain='[]', context='{}'), + dict(name='b', user_id=False, is_default=False, domain='[]', context='{}'), + dict(name='c', user_id=False, is_default=True, domain='[]', context='{}'), + ]) + + @fixtures( + ('ir.filters', dict(name='a', user_id=False, model_id='ir.filters')), + ('ir.filters', dict(name='b', is_default=True, user_id=False, model_id='ir.filters')), + ) + def test_new_filter_existing_default(self): + """ + When creating a @is_default filter where an existing filter is already + @is_default, an error should be generated + """ + Filters = self.registry('ir.filters') + with self.assertRaises(exceptions.Warning): + Filters.create_or_replace(self.cr, self.USER_ID, { + 'name': 'c', + 'model_id': 'ir.filters', + 'user_id': False, + 'is_default': True, + }) + + @fixtures( + ('ir.filters', dict(name='a', user_id=False, model_id='ir.filters')), + ('ir.filters', dict(name='b', is_default=True, user_id=False, model_id='ir.filters')), + ) + def test_update_filter_set_default(self): + """ + When updating an existing filter to @is_default, if an other filter + already has the flag an error should be generated + """ + Filters = self.registry('ir.filters') + + with self.assertRaises(exceptions.Warning): + Filters.create_or_replace(self.cr, self.USER_ID, { + 'name': 'a', + 'model_id': 'ir.filters', + 'user_id': False, + 'is_default': True, + })