improve handling of module init methods

- complain when super call is omitted (this is a common programming
  error in Mixins)
- redesign waiting mechanism for startup

+ rename MultiEvent method 'setfunc' to 'get_trigger'

Change-Id: Ica27a75597321f2571a604a7a55448cffb1bec5e
Reviewed-on: https://forge.frm2.tum.de/review/c/sine2020/secop/playground/+/27369
Tested-by: Jenkins Automated Tests <pedersen+jenkins@frm2.tum.de>
Reviewed-by: Enrico Faulhaber <enrico.faulhaber@frm2.tum.de>
Reviewed-by: Markus Zolliker <markus.zolliker@psi.ch>
This commit is contained in:
zolliker 2021-12-22 15:19:21 +01:00
parent f13e29aad2
commit 8f7fb1e45b
14 changed files with 94 additions and 78 deletions

View File

@ -110,7 +110,7 @@ class MultiEvent(threading.Event):
def waiting_for(self): def waiting_for(self):
return set(event.name for event in self.events) return set(event.name for event in self.events)
def setfunc(self, timeout=None, name=None): def get_trigger(self, timeout=None, name=None):
"""create a new single event and return its set method """create a new single event and return its set method
as a convenience method as a convenience method

View File

@ -257,6 +257,9 @@ class Module(HasAccessibles):
self.name = name self.name = name
self.valueCallbacks = {} self.valueCallbacks = {}
self.errorCallbacks = {} self.errorCallbacks = {}
self.earlyInitDone = False
self.initModuleDone = False
self.startModuleDone = False
errors = [] errors = []
# handle module properties # handle module properties
@ -523,11 +526,25 @@ class Module(HasAccessibles):
return False return False
def earlyInit(self): def earlyInit(self):
# may be overriden in derived classes to init stuff """initialise module with stuff to be done before all modules are created"""
self.log.debug('empty %s.earlyInit()' % self.__class__.__name__) self.earlyInitDone = True
def initModule(self): def initModule(self):
self.log.debug('empty %s.initModule()' % self.__class__.__name__) """initialise module with stuff to be done after all modules are created"""
self.initModuleDone = True
def startModule(self, start_events):
"""runs after init of all modules
when a thread is started, a trigger function may signal that it
has finished its initial work
start_events.get_trigger(<timeout>) creates such a trigger and
registers it in the server for waiting
<timeout> defaults to 30 seconds
"""
if self.writeDict:
mkthread(self.writeInitParams, start_events.get_trigger())
self.startModuleDone = True
def pollOneParam(self, pname): def pollOneParam(self, pname):
"""poll parameter <pname> with proper error handling""" """poll parameter <pname> with proper error handling"""
@ -562,15 +579,6 @@ class Module(HasAccessibles):
if started_callback: if started_callback:
started_callback() started_callback()
def startModule(self, started_callback):
"""runs after init of all modules
started_callback to be called when the thread spawned by startModule
has finished its initial work
might return a timeout value, if different from default
"""
mkthread(self.writeInitParams, started_callback)
class Readable(Module): class Readable(Module):
"""basic readable module""" """basic readable module"""
@ -590,13 +598,13 @@ class Readable(Module):
pollinterval = Parameter('sleeptime between polls', FloatRange(0.1, 120), pollinterval = Parameter('sleeptime between polls', FloatRange(0.1, 120),
default=5, readonly=False) default=5, readonly=False)
def startModule(self, started_callback): def startModule(self, start_events):
"""start basic polling thread""" """start basic polling thread"""
if self.pollerClass and issubclass(self.pollerClass, BasicPoller): if self.pollerClass and issubclass(self.pollerClass, BasicPoller):
# use basic poller for legacy code # use basic poller for legacy code
mkthread(self.__pollThread, started_callback) mkthread(self.__pollThread, start_events.get_trigger(timeout=30))
else: else:
super().startModule(started_callback) super().startModule(start_events)
def __pollThread(self, started_callback): def __pollThread(self, started_callback):
while True: while True:

View File

@ -218,7 +218,7 @@ class Poller(PollerBase):
"""start poll loop """start poll loop
To be called as a thread. After all parameters are polled once first, To be called as a thread. After all parameters are polled once first,
started_callback is called. To be called in Module.start_module. started_callback is called. To be called in Module.startModule.
poll strategy: poll strategy:
Slow polls are performed with lower priority than regular and dynamic polls. Slow polls are performed with lower priority than regular and dynamic polls.

View File

@ -144,10 +144,12 @@ class SecNode(Module):
uri = Property('uri of a SEC node', datatype=StringType()) uri = Property('uri of a SEC node', datatype=StringType())
def earlyInit(self): def earlyInit(self):
super().earlyInit()
self.secnode = SecopClient(self.uri, self.log) self.secnode = SecopClient(self.uri, self.log)
def startModule(self, started_callback): def startModule(self, start_events):
self.secnode.spawn_connect(started_callback) super().startModule(start_events)
self.secnode.spawn_connect(start_events.get_trigger())
@Command(StringType(), result=StringType()) @Command(StringType(), result=StringType())
def request(self, msg): def request(self, msg):

View File

@ -27,13 +27,12 @@ import ast
import configparser import configparser
import os import os
import sys import sys
import threading
import time
import traceback import traceback
from collections import OrderedDict from collections import OrderedDict
from secop.errors import ConfigError, SECoPError from secop.errors import ConfigError, SECoPError
from secop.lib import formatException, get_class, generalConfig from secop.lib import formatException, get_class, generalConfig
from secop.lib.multievent import MultiEvent
from secop.modules import Attached from secop.modules import Attached
from secop.params import PREDEFINED_ACCESSIBLES from secop.params import PREDEFINED_ACCESSIBLES
@ -267,6 +266,7 @@ class Server:
errors.append('error creating %s' % modname) errors.append('error creating %s' % modname)
poll_table = dict() poll_table = dict()
missing_super = set()
# all objs created, now start them up and interconnect # all objs created, now start them up and interconnect
for modname, modobj in self.modules.items(): for modname, modobj in self.modules.items():
self.log.info('registering module %r' % modname) self.log.info('registering module %r' % modname)
@ -276,6 +276,9 @@ class Server:
modobj.pollerClass.add_to_table(poll_table, modobj) modobj.pollerClass.add_to_table(poll_table, modobj)
# also call earlyInit on the modules # also call earlyInit on the modules
modobj.earlyInit() modobj.earlyInit()
if not modobj.earlyInitDone:
missing_super.add('%s was not called, probably missing super call'
% modobj.earlyInit.__qualname__)
# handle attached modules # handle attached modules
for modname, modobj in self.modules.items(): for modname, modobj in self.modules.items():
@ -291,11 +294,26 @@ class Server:
for modname, modobj in self.modules.items(): for modname, modobj in self.modules.items():
try: try:
modobj.initModule() modobj.initModule()
if not modobj.initModuleDone:
missing_super.add('%s was not called, probably missing super call'
% modobj.initModule.__qualname__)
except Exception as e: except Exception as e:
if failure_traceback is None: if failure_traceback is None:
failure_traceback = traceback.format_exc() failure_traceback = traceback.format_exc()
errors.append('error initializing %s: %r' % (modname, e)) errors.append('error initializing %s: %r' % (modname, e))
if self._testonly:
return
start_events = MultiEvent(default_timeout=30)
for modname, modobj in self.modules.items():
# startModule must return either a timeout value or None (default 30 sec)
start_events.name = 'module %s' % modname
modobj.startModule(start_events)
if not modobj.startModuleDone:
missing_super.add('%s was not called, probably missing super call'
% modobj.startModule.__qualname__)
errors.extend(missing_super)
if errors: if errors:
for errtxt in errors: for errtxt in errors:
for line in errtxt.split('\n'): for line in errtxt.split('\n'):
@ -307,23 +325,16 @@ class Server:
sys.stderr.write(failure_traceback) sys.stderr.write(failure_traceback)
sys.exit(1) sys.exit(1)
if self._testonly: for (_, pollname) , poller in poll_table.items():
return start_events.name = 'poller %s' % pollname
start_events = []
for modname, modobj in self.modules.items():
event = threading.Event()
# startModule must return either a timeout value or None (default 30 sec)
timeout = modobj.startModule(started_callback=event.set) or 30
start_events.append((time.time() + timeout, 'module %s' % modname, event))
for poller in poll_table.values():
event = threading.Event()
# poller.start must return either a timeout value or None (default 30 sec) # poller.start must return either a timeout value or None (default 30 sec)
timeout = poller.start(started_callback=event.set) or 30 poller.start(start_events.get_trigger())
start_events.append((time.time() + timeout, repr(poller), event))
self.log.info('waiting for modules and pollers being started') self.log.info('waiting for modules and pollers being started')
for deadline, name, event in sorted(start_events): start_events.name = None
if not event.wait(timeout=max(0, deadline - time.time())): if not start_events.wait():
self.log.info('WARNING: timeout when starting %s' % name) # some timeout happened
for name in start_events.waiting_for():
self.log.warning('timeout when starting %s' % name)
self.log.info('all modules and pollers started') self.log.info('all modules and pollers started')
history_path = os.environ.get('FRAPPY_HISTORY') history_path = os.environ.get('FRAPPY_HISTORY')
if history_path: if history_path:

View File

@ -60,6 +60,7 @@ class SimBase:
return object.__new__(type('SimBase_%s' % devname, (cls,), attrs)) return object.__new__(type('SimBase_%s' % devname, (cls,), attrs))
def initModule(self): def initModule(self):
super().initModule()
self._sim_thread = mkthread(self._sim) self._sim_thread = mkthread(self._sim)
def _sim(self): def _sim(self):

View File

@ -111,6 +111,7 @@ class Cryostat(CryoBase):
group='stability') group='stability')
def initModule(self): def initModule(self):
super().initModule()
self._stopflag = False self._stopflag = False
self._thread = mkthread(self.thread) self._thread = mkthread(self.thread)

View File

@ -133,6 +133,7 @@ class MagneticField(Drivable):
status = Parameter(datatype=TupleOf(EnumType(Status), StringType())) status = Parameter(datatype=TupleOf(EnumType(Status), StringType()))
def initModule(self): def initModule(self):
super().initModule()
self._state = Enum('state', idle=1, switch_on=2, switch_off=3, ramp=4).idle self._state = Enum('state', idle=1, switch_on=2, switch_off=3, ramp=4).idle
self._heatswitch = self.DISPATCHER.get_module(self.heatswitch) self._heatswitch = self.DISPATCHER.get_module(self.heatswitch)
_thread = threading.Thread(target=self._thread) _thread = threading.Thread(target=self._thread)
@ -235,6 +236,7 @@ class SampleTemp(Drivable):
) )
def initModule(self): def initModule(self):
super().initModule()
_thread = threading.Thread(target=self._thread) _thread = threading.Thread(target=self._thread)
_thread.daemon = True _thread.daemon = True
_thread.start() _thread.start()

View File

@ -376,8 +376,8 @@ class AnalogInput(PyTangoDevice, Readable):
The AnalogInput handles all devices only delivering an analogue value. The AnalogInput handles all devices only delivering an analogue value.
""" """
def startModule(self, started_callback): def startModule(self, start_events):
super().startModule(started_callback) super().startModule(start_events)
try: try:
# query unit from tango and update value property # query unit from tango and update value property
attrInfo = self._dev.attribute_query('value') attrInfo = self._dev.attribute_query('value')
@ -454,8 +454,8 @@ class AnalogOutput(PyTangoDevice, Drivable):
self._history = [] # will keep (timestamp, value) tuple self._history = [] # will keep (timestamp, value) tuple
self._timeout = None # keeps the time at which we will timeout, or None self._timeout = None # keeps the time at which we will timeout, or None
def startModule(self, started_callback): def startModule(self, start_events):
super().startModule(started_callback) super().startModule(start_events)
# query unit from tango and update value property # query unit from tango and update value property
attrInfo = self._dev.attribute_query('value') attrInfo = self._dev.attribute_query('value')
# prefer configured unit if nothing is set on the Tango device, else # prefer configured unit if nothing is set on the Tango device, else

View File

@ -76,8 +76,8 @@ class Main(HasIodev, Drivable):
def register_channel(self, modobj): def register_channel(self, modobj):
self._channels[modobj.channel] = modobj self._channels[modobj.channel] = modobj
def startModule(self, started_callback): def startModule(self, start_events):
started_callback() super().startModule(start_events)
for ch in range(1, 16): for ch in range(1, 16):
if ch not in self._channels: if ch not in self._channels:
self.sendRecv('INSET %d,0,0,0,0,0;INSET?%d' % (ch, ch)) self.sendRecv('INSET %d,0,0,0,0,0;INSET?%d' % (ch, ch))

View File

@ -89,6 +89,7 @@ class Main(Communicator):
pollerClass = Poller pollerClass = Poller
def earlyInit(self): def earlyInit(self):
super().earlyInit()
self.modules = {} self.modules = {}
self._ppms_device = ppmshw.QDevice(self.class_id) self._ppms_device = ppmshw.QDevice(self.class_id)
self.lock = threading.Lock() self.lock = threading.Lock()
@ -132,6 +133,11 @@ class PpmsBase(HasIodev, Readable):
"""common base for all ppms modules""" """common base for all ppms modules"""
iodev = Attached() iodev = Attached()
# polling is done by the main module
# and PPMS does not deliver really more fresh values when polled more often
value = Parameter(poll=False, needscfg=False)
status = Parameter(poll=False, needscfg=False)
pollerClass = Poller pollerClass = Poller
enabled = True # default, if no parameter enable is defined enabled = True # default, if no parameter enable is defined
_last_settings = None # used by several modules _last_settings = None # used by several modules
@ -142,23 +148,9 @@ class PpmsBase(HasIodev, Readable):
pollinterval = Parameter(export=False) pollinterval = Parameter(export=False)
def initModule(self): def initModule(self):
super().initModule()
self._iodev.register(self) self._iodev.register(self)
def startModule(self, started_callback):
# no polls except on main module
started_callback()
def read_value(self):
# polling is done by the main module
# and PPMS does not deliver really more fresh values when polled more often
return Done
def read_status(self):
# polling is done by the main module
# and PPMS does not deliver really fresh status values anyway: the status is not
# changed immediately after a target change!
return Done
def update_value_status(self, value, packed_status): def update_value_status(self, value, packed_status):
# update value and status # update value and status
# to be reimplemented for modules looking at packed_status # to be reimplemented for modules looking at packed_status
@ -175,7 +167,7 @@ class PpmsBase(HasIodev, Readable):
class Channel(PpmsBase): class Channel(PpmsBase):
"""channel base class""" """channel base class"""
value = Parameter('main value of channels', poll=True) value = Parameter('main value of channels')
enabled = Parameter('is this channel used?', readonly=False, poll=False, enabled = Parameter('is this channel used?', readonly=False, poll=False,
datatype=BoolType(), default=False) datatype=BoolType(), default=False)
@ -380,8 +372,8 @@ class Temp(PpmsBase, Drivable):
# pylint: disable=invalid-name # pylint: disable=invalid-name
ApproachMode = Enum('ApproachMode', fast_settle=0, no_overshoot=1) ApproachMode = Enum('ApproachMode', fast_settle=0, no_overshoot=1)
value = Parameter(datatype=FloatRange(unit='K'), poll=True) value = Parameter(datatype=FloatRange(unit='K'))
status = Parameter(datatype=StatusType(Status), poll=True) status = Parameter(datatype=StatusType(Status))
target = Parameter(datatype=FloatRange(1.7, 402.0, unit='K'), poll=False, needscfg=False) target = Parameter(datatype=FloatRange(1.7, 402.0, unit='K'), poll=False, needscfg=False)
setpoint = Parameter('intermediate set point', setpoint = Parameter('intermediate set point',
datatype=FloatRange(1.7, 402.0, unit='K'), handler=temp) datatype=FloatRange(1.7, 402.0, unit='K'), handler=temp)
@ -568,8 +560,8 @@ class Field(PpmsBase, Drivable):
PersistentMode = Enum('PersistentMode', persistent=0, driven=1) PersistentMode = Enum('PersistentMode', persistent=0, driven=1)
ApproachMode = Enum('ApproachMode', linear=0, no_overshoot=1, oscillate=2) ApproachMode = Enum('ApproachMode', linear=0, no_overshoot=1, oscillate=2)
value = Parameter(datatype=FloatRange(unit='T'), poll=True) value = Parameter(datatype=FloatRange(unit='T'))
status = Parameter(datatype=StatusType(Status), poll=True) status = Parameter(datatype=StatusType(Status))
target = Parameter(datatype=FloatRange(-15, 15, unit='T'), handler=field) target = Parameter(datatype=FloatRange(-15, 15, unit='T'), handler=field)
ramp = Parameter('ramping speed', readonly=False, handler=field, ramp = Parameter('ramping speed', readonly=False, handler=field,
datatype=FloatRange(0.064, 1.19, unit='T/min')) datatype=FloatRange(0.064, 1.19, unit='T/min'))
@ -696,7 +688,7 @@ class Position(PpmsBase, Drivable):
move = IOHandler('move', 'MOVE?', '%g,%g,%g') move = IOHandler('move', 'MOVE?', '%g,%g,%g')
Status = Drivable.Status Status = Drivable.Status
value = Parameter(datatype=FloatRange(unit='deg'), poll=True) value = Parameter(datatype=FloatRange(unit='deg'))
target = Parameter(datatype=FloatRange(-720., 720., unit='deg'), handler=move) target = Parameter(datatype=FloatRange(-720., 720., unit='deg'), handler=move)
enabled = Parameter('is this channel used?', readonly=False, poll=False, enabled = Parameter('is this channel used?', readonly=False, poll=False,
datatype=BoolType(), default=True) datatype=BoolType(), default=True)

View File

@ -185,12 +185,12 @@ class Motor(PersistentMixin, HasIodev, Drivable):
value = result * scale value = result * scale
return value return value
def startModule(self, started_callback): def startModule(self, start_events):
# get encoder value from motor. at this stage self.encoder contains the persistent value # get encoder value from motor. at this stage self.encoder contains the persistent value
encoder = self.get('encoder') encoder = self.get('encoder')
encoder += self.zero encoder += self.zero
self.fix_encoder(encoder) self.fix_encoder(encoder)
super().startModule(started_callback) super().startModule(start_events)
def fix_encoder(self, encoder_from_hw): def fix_encoder(self, encoder_from_hw):
"""fix encoder value """fix encoder value

View File

@ -22,8 +22,6 @@
# ***************************************************************************** # *****************************************************************************
"""test data types.""" """test data types."""
import threading
import pytest import pytest
from secop.datatypes import BoolType, FloatRange, StringType, IntRange from secop.datatypes import BoolType, FloatRange, StringType, IntRange
@ -31,6 +29,7 @@ from secop.errors import ProgrammingError, ConfigError
from secop.modules import Communicator, Drivable, Readable, Module from secop.modules import Communicator, Drivable, Readable, Module
from secop.params import Command, Parameter from secop.params import Command, Parameter
from secop.poller import BasicPoller from secop.poller import BasicPoller
from secop.lib.multievent import MultiEvent
class DispatcherStub: class DispatcherStub:
@ -69,8 +68,8 @@ def test_Communicator():
o = Communicator('communicator', LoggerStub(), {'.description':''}, ServerStub({})) o = Communicator('communicator', LoggerStub(), {'.description':''}, ServerStub({}))
o.earlyInit() o.earlyInit()
o.initModule() o.initModule()
event = threading.Event() event = MultiEvent()
o.startModule(event.set) o.startModule(event)
assert event.is_set() # event should be set immediately assert event.is_set() # event should be set immediately
@ -175,8 +174,8 @@ def test_ModuleMagic():
'value': 'first'} 'value': 'first'}
assert updates.pop('o1') == expectedBeforeStart assert updates.pop('o1') == expectedBeforeStart
o1.earlyInit() o1.earlyInit()
event = threading.Event() event = MultiEvent()
o1.startModule(event.set) o1.startModule(event)
event.wait() event.wait()
# should contain polled values # should contain polled values
expectedAfterStart = {'status': (Drivable.Status.IDLE, ''), expectedAfterStart = {'status': (Drivable.Status.IDLE, ''),
@ -189,8 +188,8 @@ def test_ModuleMagic():
expectedBeforeStart['a1'] = 2.7 expectedBeforeStart['a1'] = 2.7
assert updates.pop('o2') == expectedBeforeStart assert updates.pop('o2') == expectedBeforeStart
o2.earlyInit() o2.earlyInit()
event = threading.Event() event = MultiEvent()
o2.startModule(event.set) o2.startModule(event)
event.wait() event.wait()
# value has changed type, b2 and a1 are written # value has changed type, b2 and a1 are written
expectedAfterStart.update(value=0, b2=True, a1=2.7) expectedAfterStart.update(value=0, b2=True, a1=2.7)

View File

@ -26,8 +26,8 @@ from secop.lib.multievent import MultiEvent
def test_without_timeout(): def test_without_timeout():
m = MultiEvent() m = MultiEvent()
s1 = m.setfunc(name='s1') s1 = m.get_trigger(name='s1')
s2 = m.setfunc(name='s2') s2 = m.get_trigger(name='s2')
assert not m.wait(0) assert not m.wait(0)
assert m.deadline() is None assert m.deadline() is None
assert m.waiting_for() == {'s1', 's2'} assert m.waiting_for() == {'s1', 's2'}
@ -45,10 +45,10 @@ def test_with_timeout(monkeypatch):
m = MultiEvent() m = MultiEvent()
assert m.deadline() == 0 assert m.deadline() == 0
m.name = 's1' m.name = 's1'
s1 = m.setfunc(10) s1 = m.get_trigger(10)
assert m.deadline() == 1010 assert m.deadline() == 1010
m.name = 's2' m.name = 's2'
s2 = m.setfunc(20) s2 = m.get_trigger(20)
assert m.deadline() == 1020 assert m.deadline() == 1020
current_time += 21 current_time += 21
assert not m.wait(0) assert not m.wait(0)