diff --git a/frappy/modules.py b/frappy/modules.py index 1c5aa86..171da21 100644 --- a/frappy/modules.py +++ b/frappy/modules.py @@ -31,7 +31,7 @@ from frappy.datatypes import ArrayOf, BoolType, EnumType, FloatRange, \ IntRange, StatusType, StringType, TextType, TupleOf, DiscouragedConversion, \ NoneOr from frappy.errors import BadValueError, CommunicationFailedError, ConfigError, \ - ProgrammingError, SECoPError, secop_error + ProgrammingError, SECoPError, secop_error, RangeError from frappy.lib import formatException, mkthread, UniqueObject from frappy.lib.enum import Enum from frappy.params import Accessible, Command, Parameter @@ -150,18 +150,35 @@ class HasAccessibles(HasProperties): new_rfunc.__module__ = cls.__module__ cls.wrappedAttributes[rname] = new_rfunc + cname = 'check_' + pname + for postfix in ('_limits', '_min', '_max'): + limname = pname + postfix + if limname in accessibles: + # find the base class, where the parameter is defined first. + # we have to check all bases, as they may not be treated yet when + # not inheriting from HasAccessibles + base = next(b for b in reversed(base.__mro__) if limname in b.__dict__) + if cname not in base.__dict__: + # there is no check method yet at this class + # add check function to the class where the limit was defined + setattr(base, cname, lambda self, value, pname=pname: self.checkLimits(value, pname)) + + cfuncs = tuple(filter(None, (b.__dict__.get(cname) for b in cls.__mro__))) wname = 'write_' + pname wfunc = getattr(cls, wname, None) if wfunc: # allow write method even when parameter is readonly, but internally writable - def new_wfunc(self, value, pname=pname, wfunc=wfunc): + def new_wfunc(self, value, pname=pname, wfunc=wfunc, check_funcs=cfuncs): with self.accessLock: pobj = self.accessibles[pname] - self.log.debug('validate %r for %r', value, pname) + self.log.debug('convert %r to datatype of %r', value, pname) # we do not need to handle errors here, we do not # want to make a parameter invalid, when a write failed new_value = pobj.datatype(value) + for c in check_funcs: + if c(self, value): + break new_value = wfunc(self, new_value) self.log.debug('write_%s(%r) returned %r', pname, value, new_value) if new_value is Done: # TODO: to be removed when all code using Done is updated @@ -175,7 +192,11 @@ class HasAccessibles(HasProperties): new_wfunc = None else: - def new_wfunc(self, value, pname=pname): + def new_wfunc(self, value, pname=pname, check_funcs=cfuncs): + value = self.accessibles[pname].datatype(value) + for c in check_funcs: + if c(self, value): + break setattr(self, pname, value) return value @@ -418,8 +439,27 @@ class Module(HasAccessibles): self.errorCallbacks[pname] = [] if not pobj.hasDatatype(): - errors.append('%s needs a datatype' % pname) - continue + head, _, postfix = pname.rpartition('_') + if postfix not in ('min', 'max', 'limits'): + errors.append('%s needs a datatype' % pname) + continue + # when datatype is not given, properties are set automagically + pobj.setProperty('readonly', False) + baseparam = self.parameters.get(head) + if not baseparam: + errors.append('parameter %r is given, but not %r' % (pname, head)) + continue + dt = baseparam.datatype + if dt is None: + continue # an error will be reported on baseparam + if postfix == 'limits': + pobj.setProperty('datatype', TupleOf(dt, dt)) + pobj.setProperty('default', (dt.min, dt.max)) + else: + pobj.setProperty('datatype', dt) + pobj.setProperty('default', getattr(dt, postfix)) + if not pobj.description: + pobj.setProperty('description', 'limit for %s' % pname) if pobj.value is None: if pobj.needscfg: @@ -761,6 +801,28 @@ class Module(HasAccessibles): raise ValueError('remote handler not found') self.remoteLogHandler.set_conn_level(self, conn, level) + def checkLimits(self, value, parametername='target'): + """check for limits + + :param value: the value to be checked for _min <= value <= _max + :param parametername: parameter name, default is 'target' + + raises RangeError in case the value is not valid + + This method is called automatically and needs therefore rarely to be + called by the programmer. It might be used in a check_ method, + when no automatic super call is desired. + """ + try: + min_, max_ = getattr(self, parametername + '_limits') + except AttributeError: + min_ = getattr(self, parametername + '_min', float('-inf')) + max_ = getattr(self, parametername + '_max', float('inf')) + if not min_ <= value <= max_: + if min_ > max_: + raise RangeError('invalid limits: [%g, %g]' % (min_, max_)) + raise RangeError('limits violation: %g outside [%g, %g]' % (value, min_, max_)) + class Readable(Module): """basic readable module""" diff --git a/test/test_modules.py b/test/test_modules.py index 65e2129..1b05137 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -27,7 +27,7 @@ import threading import pytest from frappy.datatypes import BoolType, FloatRange, StringType, IntRange, ScaledInteger -from frappy.errors import ProgrammingError, ConfigError +from frappy.errors import ProgrammingError, ConfigError, RangeError from frappy.modules import Communicator, Drivable, Readable, Module from frappy.params import Command, Parameter from frappy.rwhandler import ReadHandler, WriteHandler, nopoll @@ -783,3 +783,122 @@ def test_omit_unchanged_within(): mod2 = Mod('mod2', LoggerStub(), {'description': '', 'omit_unchanged_within': 0.125}, srv) assert mod2.parameters['a'].omit_unchanged_within == 0.125 + + +stdlim = { + 'a_min': -1, 'a_max': 2, + 'b_min': 0, + 'c_max': 10, + 'd_limits': (-1, 1), +} + + +class Lim(Module): + a = Parameter('', FloatRange(-10, 10), readonly=False, default=0) + a_min = Parameter() + a_max = Parameter() + + b = Parameter('', FloatRange(0, None), readonly=False, default=0) + b_min = Parameter() + + c = Parameter('', IntRange(None, 100), readonly=False, default=0) + c_max = Parameter() + + d = Parameter('', FloatRange(-5, 5), readonly=False, default=0) + d_limits = Parameter() + + e = Parameter('', IntRange(0, 8), readonly=False, default=0) + + def check_e(self, value): + if value % 2: + raise RangeError('e must not be odd') + + +def test_limit_defaults(): + + srv = ServerStub({}) + + mod = Lim('mod', LoggerStub(), {'description': 'test'}, srv) + + assert mod.a_min == -10 + assert mod.a_max == 10 + assert isinstance(mod.a_min, float) + assert isinstance(mod.a_max, float) + + assert mod.b_min == 0 + assert isinstance(mod.b_min, float) + + assert mod.c_max == 100 + assert isinstance(mod.c_max, int) + + assert mod.d_limits == (-5, 5) + assert isinstance(mod.d_limits[0], float) + assert isinstance(mod.d_limits[1], float) + + +@pytest.mark.parametrize('limits, pname, good, bad', [ + (stdlim, 'a', [-1, 2, 0], [-2, 3]), + (stdlim, 'b', [0, 1e99], [-1, -1e99]), + (stdlim, 'c', [-999, 0, 10], [11, 999]), + (stdlim, 'd', [-1, 0.1, 1], [-1.001, 1.001]), + ({'a_min': 0, 'a_max': -1}, 'a', [], [0, -1]), + (stdlim, 'e', [0, 2, 4, 6, 8], [-1, 1, 7, 9]), +]) +def test_limits(limits, pname, good, bad): + + srv = ServerStub({}) + + mod = Lim('mod', LoggerStub(), {'description': 'test'}, srv) + mod.check_a = 0 # this should not harm. check_a is never called on the instance + + for k, v in limits.items(): + setattr(mod, k, v) + + for v in good: + getattr(mod, 'write_' + pname)(v) + for v in bad: + with pytest.raises(RangeError): + getattr(mod, 'write_' + pname)(v) + + +def test_limit_inheritance(): + srv = ServerStub({}) + + class Base(Module): + a = Parameter('', FloatRange(), readonly=False, default=0) + + def check_a(self, value): + if int(value * 4) != value * 4: + raise ValueError('value is not a multiple of 0.25') + + class Mixin: + a_min = Parameter() + a_max = Parameter() + + class Mod(Mixin, Base): + def check_a(self, value): + if value == 0: + raise ValueError('value must not be 0') + + mod = Mod('mod', LoggerStub(), {'description': 'test', 'a_min': {'value': -1}, 'a_max': {'value': 1}}, srv) + + for good in [-1, -0.75, 0.25, 1]: + mod.write_a(good) + + for bad in [-2, -0.1, 0, 0.9, 1.1]: + with pytest.raises(ValueError): + mod.write_a(bad) + + class Mod2(Mixin, Base): + def check_a(self, value): + if value == 0: + raise ValueError('value must not be 0') + return True # indicates stop checking + + mod2 = Mod2('mod2', LoggerStub(), {'description': 'test', 'a_min': {'value': -1}, 'a_max': {'value': 1}}, srv) + + for good in [-2, -1, -0.75, 0.25, 1, 1.1]: + mod2.write_a(good) + + with pytest.raises(ValueError): + mod2.write_a(0)