From 97034fb99849fbba12e2ef09491f85195bfae62d Mon Sep 17 00:00:00 2001 From: Markus Zolliker Date: Thu, 30 Jan 2020 10:24:40 +0100 Subject: [PATCH] implement SECoP proxy modules A proxy module is a module with a known structure, but accessed over a SECoP connection. For the configuration, a Frappy module class has to be given. The proxy class is created from this, but does not inherit from it. However, the class of the returned object will be subclass of the SECoP base classes (Readable, Drivable etc.). A possible extension might be, that instead of the Frappy class, the JSON module description can be given, as a separate file or directly in the config file. Or we might offer a tool to convert the JSON description to a python class. Change-Id: I9212d9f3fe82ec56dfc08611d0e1efc0b0112271 Reviewed-on: https://forge.frm2.tum.de/review/c/sine2020/secop/playground/+/22386 Tested-by: JenkinsCodeReview Reviewed-by: Markus Zolliker --- cfg/ppms.cfg | 4 +- cfg/ppms_proxy_test.cfg | 22 ++++ secop/__init__.py | 2 + secop/client/__init__.py | 202 ++++++++++++++++++---------------- secop/datatypes.py | 106 +++++++++++++++++- secop/errors.py | 18 ++- secop/modules.py | 13 ++- secop/params.py | 21 ++-- secop/parse.py | 3 +- secop/proxy.py | 230 +++++++++++++++++++++++++++++++++++++++ test/test_datatypes.py | 41 +++++++ test/test_modules.py | 2 +- 12 files changed, 550 insertions(+), 114 deletions(-) create mode 100644 cfg/ppms_proxy_test.cfg create mode 100644 secop/proxy.py diff --git a/cfg/ppms.cfg b/cfg/ppms.cfg index b48d4a8..0316d92 100644 --- a/cfg/ppms.cfg +++ b/cfg/ppms.cfg @@ -8,8 +8,8 @@ bindport = 5000 [module tt] class = secop_psi.ppms.Temp -.description = main temperature -.iodev = ppms +description = main temperature +iodev = ppms [module mf] class = secop_psi.ppms.Field diff --git a/cfg/ppms_proxy_test.cfg b/cfg/ppms_proxy_test.cfg new file mode 100644 index 0000000..bbaeea0 --- /dev/null +++ b/cfg/ppms_proxy_test.cfg @@ -0,0 +1,22 @@ +[node filtered.PPMS.psi.ch] +description = filtered PPMS at PSI + +[interface tcp] +type = tcp +bindto = 0.0.0.0 +bindport = 5002 + +[module secnode] +class = secop.SecNode +description = a SEC node +uri = tcp://localhost:5000 + +[module mf] +class = secop.Proxy +remote_class = secop_psi.ppms.Field +description = magnetic field +iodev = secnode +value.min = -0.1 +value.max = 0.1 +target.min = -8 +target.max = 8 diff --git a/secop/__init__.py b/secop/__init__.py index a120438..d90e39d 100644 --- a/secop/__init__.py +++ b/secop/__init__.py @@ -25,8 +25,10 @@ # allow to import the most important classes from 'secop' from secop.datatypes import * +from secop.lib.enum import Enum from secop.modules import Module, Readable, Writable, Drivable, Communicator, Attached from secop.params import Parameter, Command, Override from secop.metaclass import Done from secop.iohandler import IOHandler, IOHandlerBase from secop.stringio import StringIO, HasIodev +from secop.proxy import SecNode, Proxy, proxy_class diff --git a/secop/client/__init__.py b/secop/client/__init__.py index a535b93..9bfa570 100644 --- a/secop/client/__init__.py +++ b/secop/client/__init__.py @@ -64,8 +64,8 @@ class Logger: error = warning = critical = info -class CallbackMixin: - """abstract mixin +class CallbackObject: + """abstract definition for a target object for callbacks this is mainly for documentation, but it might be extended and used as a mixin for objects registered as a callback @@ -94,33 +94,115 @@ class CallbackMixin: """ -class SecopClient: +class ProxyClient: + """common functionality for proxy clients""" + + CALLBACK_NAMES = ('updateEvent', 'descriptiveDataChange', 'nodeStateChange', 'unhandledMessage') + online = False # connected or reconnecting since a short time + validate_data = False + _state = 'disconnected' # further possible values: 'connecting', 'reconnecting', 'connected' + + def __init__(self): + self.callbacks = {cbname: defaultdict(list) for cbname in self.CALLBACK_NAMES} + # caches (module, parameter) = value, timestamp, readerror (internal names!) + self.cache = {} + + def register(self, key, obj=None, **kwds): + """register callback functions + + - kwds each key must be a valid callback name defined in self.CALLBACK_NAMES + - kwds values are the callback functions + - if obj is not None, use its methods named from the callback name, if not given in kwds + - key might be either: + 1) None: general callback (all callbacks) + 2) : callbacks related to a module (not called for 'unhandledMessage') + 3) (, ): callback for specified parameter (only called for 'updateEvent') + """ + for cbname in self.CALLBACK_NAMES: + cbfunc = kwds.pop(cbname, None) + if obj and cbfunc is None: + cbfunc = getattr(obj, cbname, None) + if not cbfunc: + continue + cbdict = self.callbacks[cbname] + cbdict[key].append(cbfunc) + + # immediately call for some callback types + if cbname == 'updateEvent': + if key is None: + for (mname, pname), data in self.cache.items(): + cbfunc(mname, pname, *data) + else: + data = self.cache.get(key, None) + if data: + cbfunc(*key, *data) # case single parameter + else: # case key = module + for (mname, pname), data in self.cache.items(): + if mname == key: + cbfunc(mname, pname, *data) + elif cbname == 'nodeStateChange': + cbfunc(self.online, self._state) + if kwds: + raise TypeError('unknown callback: %s' % (', '.join(kwds))) + + def callback(self, key, cbname, *args): + """perform callbacks + + key=None: + key=: callbacks for specified module + key=(, is not compatible, i.e. there + exists a value which is valid for ourselfs, but not for + """ + raise NotImplementedError + class Stub(DataType): """incomplete datatype, to be replaced with a proper one later during module load @@ -182,6 +192,9 @@ class FloatRange(DataType): value = float(value) except Exception: raise BadValueError('Can not __call__ %r to float' % value) + if math.isinf(value): + raise BadValueError('FloatRange does not accept infinity') + prec = max(abs(value * self.relative_resolution), self.absolute_resolution) if self.min - prec <= value <= self.max + prec: return min(max(value, self.min), self.max) @@ -215,6 +228,12 @@ class FloatRange(DataType): return ' '.join([self.fmtstr % value, unit]) return self.fmtstr % value + def compatible(self, other): + if not isinstance(other, (FloatRange, ScaledInteger)): + raise BadValueError('incompatible datatypes') + # avoid infinity + other(max(sys.float_info.min, self.min)) + other(min(sys.float_info.max, self.max)) class IntRange(DataType): @@ -266,6 +285,15 @@ class IntRange(DataType): def format_value(self, value, unit=None): return '%d' % value + def compatible(self, other): + if isinstance(other, IntRange): + other(self.min) + other(self.max) + return + # this will accept some EnumType, BoolType + for i in range(self.min, self.max + 1): + other(i) + class ScaledInteger(DataType): """Scaled integer int type @@ -365,6 +393,12 @@ class ScaledInteger(DataType): return ' '.join([self.fmtstr % value, unit]) return self.fmtstr % value + def compatible(self, other): + if not isinstance(other, (FloatRange, ScaledInteger)): + raise BadValueError('incompatible datatypes') + other(self.min) + other(self.max) + class EnumType(DataType): @@ -408,6 +442,10 @@ class EnumType(DataType): def format_value(self, value, unit=None): return '%s<%s>' % (self._enum[value].name, self._enum[value].value) + def compatible(self, other): + for m in self._enum.members: + other(m) + class BLOBType(DataType): properties = { @@ -438,7 +476,7 @@ class BLOBType(DataType): def __call__(self, value): """return the validated (internal) value or raise""" if not isinstance(value, bytes): - raise BadValueError('%r has the wrong type!' % value) + raise BadValueError('%s has the wrong type!' % repr(value)) size = len(value) if size < self.minbytes: raise BadValueError( @@ -464,6 +502,13 @@ class BLOBType(DataType): def format_value(self, value, unit=None): return repr(value) + def compatible(self, other): + try: + if self.minbytes < other.minbytes or self.maxbytes > other.maxbytes: + raise BadValueError('incompatible datatypes') + except AttributeError: + raise BadValueError('incompatible datatypes') + class StringType(DataType): properties = { @@ -494,7 +539,7 @@ class StringType(DataType): def __call__(self, value): """return the validated (internal) value or raise""" if not isinstance(value, str): - raise BadValueError('%r has the wrong type!' % value) + raise BadValueError('%s has the wrong type!' % repr(value)) if not self.isUTF8: try: value.encode('ascii') @@ -527,6 +572,14 @@ class StringType(DataType): def format_value(self, value, unit=None): return repr(value) + def compatible(self, other): + try: + if self.minchars < other.minchars or self.maxchars > other.maxchars or \ + self.isUTF8 > other.isUTF8: + raise BadValueError('incompatible datatypes') + except AttributeError: + raise BadValueError('incompatible datatypes') + # TextType is a special StringType intended for longer texts (i.e. embedding \n), # whereas StringType is supposed to not contain '\n' @@ -578,6 +631,11 @@ class BoolType(DataType): def format_value(self, value, unit=None): return repr(bool(value)) + def compatible(self, other): + other(False) + other(True) + + Stub.fix_datatypes() # @@ -673,6 +731,14 @@ class ArrayOf(DataType): return ' '.join([res, unit]) return res + def compatible(self, other): + try: + if self.minlen < other.minlen or self.maxlen > other.maxlen: + raise BadValueError('incompatible datatypes') + self.members.compatible(other.members) + except AttributeError: + raise BadValueError('incompatible datatypes') + class TupleOf(DataType): @@ -729,6 +795,15 @@ class TupleOf(DataType): return '(%s)' % (', '.join([sub.format_value(elem) for sub, elem in zip(self.members, value)])) + def compatible(self, other): + if not isinstance(other, TupleOf): + raise BadValueError('incompatible datatypes') + if len(self.members) != len(other.members) : + raise BadValueError('incompatible datatypes') + for a, b in zip(self.members, other.members): + a.compatible(b) + + class StructOf(DataType): @@ -763,7 +838,7 @@ class StructOf(DataType): return res def __repr__(self): - opt = self.optional if self.optional else '' + opt = ', optional=%r' % self.optional if self.optional else '' return 'StructOf(%s%s)' % (', '.join( ['%s=%s' % (n, repr(st)) for n, st in list(self.members.items())]), opt) @@ -808,6 +883,17 @@ class StructOf(DataType): def format_value(self, value, unit=None): return '{%s}' % (', '.join(['%s=%s' % (k, self.members[k].format_value(v)) for k, v in sorted(value.items())])) + def compatible(self, other): + try: + mandatory = set(other.members) - set(other.optional) + for k, m in self.members.items(): + m.compatible(other.members[k]) + mandatory.discard(k) + if mandatory: + raise BadValueError('incompatible datatypes') + except (AttributeError, TypeError, KeyError): + raise BadValueError('incompatible datatypes') + class CommandType(DataType): IS_COMMAND = True @@ -858,6 +944,16 @@ class CommandType(DataType): # actually I have no idea what to do here! raise NotImplementedError + def compatible(self, other): + try: + if self.argument != other.argument: # not both are None + self.argument.compatible(other.argument) + if self.result != other.result: # not both are None + other.result.compatible(self.result) + except AttributeError: + raise BadValueError('incompatible datatypes') + + # internally used datatypes (i.e. only for programming the SEC-node) class DataTypeType(DataType): diff --git a/secop/errors.py b/secop/errors.py index e262425..3cbbda3 100644 --- a/secop/errors.py +++ b/secop/errors.py @@ -70,6 +70,11 @@ class NoSuchModuleError(SECoPError): name = 'NoSuchModule' +# pylint: disable=redefined-builtin +class NotImplementedError(NotImplementedError, SECoPError): + pass + + class NoSuchParameterError(SECoPError): pass @@ -122,6 +127,16 @@ class HardwareError(SECoPError): pass +def make_secop_error(name, text): + errcls = EXCEPTIONS.get(name, InternalError) + return errcls(text) + + +def secop_error(exception): + if isinstance(exception, SECoPError): + return exception + return InternalError(repr(exception)) + EXCEPTIONS = dict( NoSuchModule=NoSuchModuleError, @@ -137,8 +152,9 @@ EXCEPTIONS = dict( IsError=IsErrorError, Disabled=DisabledError, SyntaxError=ProtocolError, + NotImplementedError=NotImplementedError, InternalError=InternalError, -# internal short versions (candidates for spec) + # internal short versions (candidates for spec) Protocol=ProtocolError, Internal=InternalError, ) diff --git a/secop/modules.py b/secop/modules.py index 934f0fc..c971a64 100644 --- a/secop/modules.py +++ b/secop/modules.py @@ -140,6 +140,13 @@ class Module(HasProperties, metaclass=ModuleMeta): for aname, aobj in self.accessibles.items(): # make a copy of the Parameter/Command object aobj = aobj.copy() + if isinstance(aobj, Parameter): + # fix default properties poll and needscfg + if aobj.poll is None: + aobj.properties['poll'] = bool(aobj.handler) + if aobj.needscfg is None: + aobj.properties['needscfg'] = not aobj.poll + if aobj.export: if aobj.export is True: predefined_obj = PREDEFINED_ACCESSIBLES.get(aname, None) @@ -200,7 +207,7 @@ class Module(HasProperties, metaclass=ModuleMeta): self.writeDict[pname] = pobj.value else: if pobj.default is None: - if not pobj.poll: + if pobj.needscfg: raise ConfigError('Module %s: Parameter %r has no default ' 'value and was not given in config!' % (self.name, pname)) @@ -349,7 +356,7 @@ class Readable(Module): def startModule(self, started_callback): """start basic polling thread""" - if issubclass(self.pollerClass, BasicPoller): + if self.pollerClass and issubclass(self.pollerClass, BasicPoller): # use basic poller for legacy code mkthread(self.__pollThread, started_callback) else: @@ -479,4 +486,4 @@ class Attached(Property): super().__init__('attached module', StringType()) def __repr__(self): - return 'Attached(%r)' % self.description + return 'Attached(%s)' % (repr(self.attrname) if self.attrname else '') diff --git a/secop/params.py b/secop/params.py index 36b4955..4bd9541 100644 --- a/secop/params.py +++ b/secop/params.py @@ -83,10 +83,16 @@ class Parameter(Accessible): from the config file if specified there poll can be: - - False or 0 (never poll this parameter), this is the default - - True or 1 (poll this parameter) - - for any other integer, the meaning depends on the used poller - meaning for the default simple poller: + - None: will be converted to True/False if handler is/is not None + - False or 0 (never poll this parameter) + - True or > 0 (poll this parameter) + - the exact meaning depends on the used poller + meaning for secop.poller.Poller: + - 1 or True (AUTO), converted to SLOW (readonly=False), DYNAMIC('status' and 'value') or REGULAR(else) + - 2 (SLOW), polled with lower priority and a multiple of pollperiod + - 3 (REGULAR), polled with pollperiod + - 4 (DYNAMIC), polled with pollperiod, if not BUSY, else with a fraction of pollperiod + meaning for the basicPoller: - True or 1 (poll this every pollinterval) - positive int (poll every N(th) pollinterval) - negative int (normally poll every N(th) pollinterval, if module is busy, poll every pollinterval) @@ -110,7 +116,8 @@ class Parameter(Accessible): ValueType(), export=False, default=None, mandatory=False), 'export': Property('Is this parameter accessible via SECoP? (vs. internal parameter)', OrType(BoolType(), StringType()), export=False, default=True), - 'poll': Property('Polling indicator', IntRange(), export=False, default=False), + 'poll': Property('Polling indicator', NoneOr(IntRange()), export=False, default=None), + 'needscfg': Property('needs value in config', NoneOr(BoolType()), export=False, default=None), 'optional': Property('[Internal] is this parameter optional?', BoolType(), export=False, settable=False, default=False), 'handler': Property('[internal] overload the standard read and write functions', @@ -139,9 +146,6 @@ class Parameter(Accessible): datatype.setProperty('unit', unit) super(Parameter, self).__init__(**kwds) - if self.handler and not self.poll: - self.properties['poll'] = True - if self.readonly and self.initwrite: raise ProgrammingError('can not have both readonly and initwrite!') @@ -182,6 +186,7 @@ class UnusedClass: # do not derive anything from this! pass + class Parameters(OrderedDict): """class storage for Parameters""" def __init__(self, *args, **kwds): diff --git a/secop/parse.py b/secop/parse.py index ad3ad97..0ce2a8b 100644 --- a/secop/parse.py +++ b/secop/parse.py @@ -68,7 +68,7 @@ class Parser: def parse_string(self, orgtext): # handle quoted and unquoted strings correctly text = orgtext.strip() - if text[0] in ('"', u"'"): + if text[0] in ('"', "'"): # quoted string quote = text[0] idx = 0 @@ -160,7 +160,6 @@ class Parser: return self.parse_string(orgtext) def parse(self, orgtext): - print("parsing %r" % orgtext) res, rem = self.parse_sub(orgtext) if rem and rem[0] in ',;': return self.parse_sub('[%s]' % orgtext) diff --git a/secop/proxy.py b/secop/proxy.py new file mode 100644 index 0000000..1686a72 --- /dev/null +++ b/secop/proxy.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +# ***************************************************************************** +# +# This program is free software; you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the Free Software +# Foundation; either version 2 of the License, or (at your option) any later +# version. +# +# This program is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more +# details. +# +# You should have received a copy of the GNU General Public License along with +# this program; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# +# Module authors: +# Markus Zolliker +# +# ***************************************************************************** +"""SECoP proxy modules""" + +from secop.lib import get_class +from secop.modules import Module, Writable, Readable, Drivable, Attached +from secop.datatypes import StringType +from secop.protocol.dispatcher import make_update +from secop.properties import Property +from secop.client import SecopClient, decode_msg, encode_msg_frame +from secop.params import Parameter, Command +from secop.errors import ConfigError, make_secop_error, secop_error + + + +class ProxyModule(Module): + properties = { + 'iodev': Attached(), + 'module': + Property('remote module name', datatype=StringType(), default=''), + } + + _consistency_check_done = False + _secnode = None + + def updateEvent(self, module, parameter, value, timestamp, readerror): + pobj = self.parameters[parameter] + pobj.timestamp = timestamp + # should be done here: deal with clock differences + if readerror: + readerror = make_secop_error(*readerror) + if not readerror: + try: + pobj.value = value # store the value even in case of a validation error + pobj.value = pobj.datatype(value) + except Exception as e: + readerror = secop_error(e) + pobj.readerror = readerror + self.DISPATCHER.broadcast_event(make_update(self.name, pobj)) + + def initModule(self): + if not self.module: + self.properties['module'] = self.name + self._secnode = self._iodev.secnode + self._secnode.register(self.module, self) + super().initModule() + + def descriptiveDataChange(self, module, moddesc): + if module is None: + return # do not care about the node for now + self._check_descriptive_data() + + def _check_descriptive_data(self): + params = self.parameters.copy() + cmds = self.commands.copy() + moddesc = self._secnode.modules[self.module] + remoteparams = moddesc['parameters'].copy() + remotecmds = moddesc['commands'].copy() + while params: + pname, pobj = params.popitem() + props = remoteparams.get(pname, None) + if props is None: + self.log.warning('remote parameter %s:%s does not exist' % (self.module, pname)) + continue + dt = props['datatype'] + try: + if pobj.readonly: + dt.compatible(pobj.datatype) + else: + if props['readonly']: + self.log.warning('remote parameter %s:%s is read only' % (self.module, pname)) + pobj.datatype.compatible(dt) + try: + dt.compatible(pobj.datatype) + except Exception: + self.log.warning('remote parameter %s:%s is not fully compatible: %r != %r' + % (self.module, pname, pobj.datatype, dt)) + except Exception: + self.log.warning('remote parameter %s:%s has an incompatible datatype: %r != %r' + % (self.module, pname, pobj.datatype, dt)) + while cmds: + cname, cobj = cmds.popitem() + props = remotecmds.get(cname) + if props is None: + self.log.warning('remote command %s:%s does not exist' % (self.module, cname)) + continue + dt = props['datatype'] + try: + cobj.datatype.compatible(dt) + except Exception: + self.log.warning('remote command %s:%s is not compatible: %r != %r' + % (self.module, pname, pobj.datatype, dt)) + # what to do if descriptive data does not match? + # we might raise an exception, but this would lead to a reconnection, + # which might not help. + # for now, the error message must be enough + + def nodeStateChange(self, online, state): + if online and not self._consistency_check_done: + self._check_descriptive_data() + self._consistency_check_done = True + + +class ProxyReadable(ProxyModule, Readable): + pass + + +class ProxyWritable(ProxyModule, Writable): + pass + + +class ProxyDrivable(ProxyModule, Drivable): + pass + + +PROXY_CLASSES = [ProxyDrivable, ProxyWritable, ProxyReadable, ProxyModule] + + +class SecNode(Module): + properties = { + 'uri': + Property('uri of a SEC node', datatype=StringType()), + } + commands = { + 'request': + Command('send a request', argument=StringType(), result=StringType()) + } + + def earlyInit(self): + self.secnode = SecopClient(self.uri, self.log) + self.secnode.register(None, self) # for nodeStateChange + + def startModule(self, started_callback): + self.secnode.spawn_connect(started_callback) + + def do_request(self, msg): + """for test purposes""" + reply = self.secnode.request(*decode_msg(msg.encode('utf-8'))) + return encode_msg_frame(*reply).decode('utf-8') + + +def proxy_class(remote_class, name=None): + """create a proxy class based on the definition of remote class + + remote class is . of a class used on the remote node + if name is not given, 'Proxy' + is used + """ + rcls = get_class(remote_class) + if name is None: + name = rcls.__name__ + + for proxycls in PROXY_CLASSES: + if issubclass(rcls, proxycls.__bases__[-1]): + # avoid 'should not be redefined' warning + proxycls.accessibles = {} + break + else: + raise ConfigError('%r is no SECoP module class' % remote_class) + + parameters = {} + commands = {} + attrs = dict(parameters=parameters, commands=commands, properties=rcls.properties) + + for aname, aobj in rcls.accessibles.items(): + if isinstance(aobj, Parameter): + pobj = aobj.copy() + parameters[aname] = pobj + pobj.properties['poll'] = False + pobj.properties['handler'] = None + pobj.properties['needscfg'] = False + + def rfunc(self, pname=aname): + value, _, readerror = self._secnode.getParameter(self.name, pname) + if readerror: + raise readerror + return value + + attrs['read_' + aname] = rfunc + + if not pobj.readonly: + + def wfunc(self, value, pname=aname): + value, _, readerror = self._secnode.setParameter(self.name, pname, value) + if readerror: + raise make_secop_error(*readerror) + return value + + attrs['write_' + aname] = wfunc + + elif isinstance(aobj, Command): + cobj = aobj.copy() + commands[aname] = cobj + + def cfunc(self, arg=None, cname=aname): + return self._secnode.execCommand(self.name, cname, arg) + + attrs['do_' + aname] = cfunc + + else: + raise ConfigError('do not now about %r in %s.accessibles' % (aobj, remote_class)) + + return type(name, (proxycls,), attrs) + + +def Proxy(name, logger, cfgdict, srv): + """create a Proxy object based on remote_class + + title cased as it acts like a class + """ + remote_class = cfgdict.pop('remote_class') + return proxy_class(remote_class)(name, logger, cfgdict, srv) diff --git a/test/test_datatypes.py b/test/test_datatypes.py index 3cc9b11..7f56890 100644 --- a/test/test_datatypes.py +++ b/test/test_datatypes.py @@ -609,3 +609,44 @@ def test_get_datatype(): get_datatype({'type': 'struct', 'members': {}}) with pytest.raises(ValueError): get_datatype({'type': 'struct', 'members':[1,2,3]}) + + +@pytest.mark.parametrize('dt, contained_in', [ + (FloatRange(-10, 10), FloatRange()), + (IntRange(-10, 10), FloatRange()), + (IntRange(-10, 10), IntRange(-20, 10)), + (StringType(), StringType(isUTF8=True)), + (StringType(10, 10), StringType()), + (ArrayOf(StringType(), 3, 5), ArrayOf(StringType(), 3, 6)), + (TupleOf(StringType(), BoolType()), TupleOf(StringType(), IntRange())), + (StructOf(a=FloatRange(-1,1)), StructOf(a=FloatRange(), b=BoolType(), optional=['b'])), +]) +def test_oneway_compatible(dt, contained_in): + dt.compatible(contained_in) + with pytest.raises(ValueError): + contained_in.compatible(dt) + +@pytest.mark.parametrize('dt1, dt2', [ + (FloatRange(-5.5, 5.5), ScaledInteger(10, -5.5, 5.5)), + (IntRange(0,1), BoolType()), + (IntRange(-10, 10), IntRange(-10, 10)), +]) +def test_twoway_compatible(dt1, dt2): + dt1.compatible(dt1) + dt2.compatible(dt2) + +@pytest.mark.parametrize('dt1, dt2', [ + (StringType(), FloatRange()), + (IntRange(-10, 10), StringType()), + (StructOf(a=BoolType(), b=BoolType()), ArrayOf(StringType(), 2)), + (ArrayOf(BoolType(), 2), TupleOf(BoolType(), StringType())), + (TupleOf(BoolType(), BoolType()), StructOf(a=BoolType(), b=BoolType())), + (ArrayOf(StringType(), 3), ArrayOf(BoolType(), 3)), + (TupleOf(StringType(), BoolType()), TupleOf(BoolType(), BoolType())), + (StructOf(a=FloatRange(-1, 1), b=StringType()), StructOf(a=FloatRange(), b=BoolType())), +]) +def test_incompatible(dt1, dt2): + with pytest.raises(ValueError): + dt1.compatible(dt2) + with pytest.raises(ValueError): + dt2.compatible(dt1) diff --git a/test/test_modules.py b/test/test_modules.py index 4dbbf86..3c96779 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -203,7 +203,7 @@ def test_ModuleMeta(): assert set(cfg['value'].keys()) == {'group', 'export', 'relative_resolution', 'visibility', 'unit', 'default', 'datatype', 'fmtstr', 'absolute_resolution', 'poll', 'max', 'min', 'readonly', 'constant', - 'description'} + 'description', 'needscfg'} # check on the level of classes # this checks Newclass1 too, as it is inherited by Newclass2