diff --git a/secop/poller.py b/secop/poller.py index 4102b76..0a61f42 100644 --- a/secop/poller.py +++ b/secop/poller.py @@ -113,7 +113,8 @@ class Poller(PollerBase): def __init__(self, name): '''create a poller''' self.queues = {polltype: [] for polltype in self.DEFAULT_FACTORS} - self._stopped = Event() + self._event = Event() + self._stopped = False self.maxwait = 3600 self.name = name @@ -142,6 +143,11 @@ class Poller(PollerBase): if not hasattr(module, 'pollinterval'): raise ProgrammingError("module %s must have a pollinterval" % module.name) + if pname == 'is_connected': + if hasattr(module, 'registerReconnectCallback'): + module.registerReconnectCallback(self.name, self.trigger_all) + else: + module.log.warning("%r has 'is_connected' but no 'registerReconnectCallback'" % module) if polltype == AUTO: # covers also pobj.poll == True if pname in ('value', 'status'): polltype = DYNAMIC @@ -184,7 +190,10 @@ class Poller(PollerBase): else: interval = module.pollinterval * factor mininterval = interval - due = max(lastdue + interval, pobj.timestamp + interval * 0.5) + if due == 0: + due = now # do not look at timestamp after trigger_all + else: + due = max(lastdue + interval, pobj.timestamp + interval * 0.5) if now >= due: module.pollOneParam(pname) done = True @@ -194,6 +203,13 @@ class Poller(PollerBase): heapreplace(queue, (due, lastdue, pollitem)) return 0 + def trigger_all(self): + for _, queue in sorted(self.queues.items()): + for idx, (_, lastdue, pollitem) in enumerate(queue): + queue[idx] = (0, lastdue, pollitem) + self._event.set() + return True + def run(self, started_callback): '''start poll loop @@ -222,7 +238,7 @@ class Poller(PollerBase): heapify(queue) started_callback() # signal end of startup nregular = len(self.queues[REGULAR]) - while not self._stopped.is_set(): + while not self._stopped: due = float('inf') for _ in range(nregular): due = min(self.poll_next(DYNAMIC), self.poll_next(REGULAR)) @@ -231,10 +247,12 @@ class Poller(PollerBase): due = min(due, self.poll_next(DYNAMIC), self.poll_next(SLOW)) delay = due - time.time() if delay > 0: - self._stopped.wait(delay) + self._event.wait(delay) + self._event.clear() def stop(self): - self._stopped.set() + self._event.set() + self._stopped = True def __bool__(self): '''is there any poll item?''' diff --git a/secop/stringio.py b/secop/stringio.py index 73ab5a2..4ac9860 100644 --- a/secop/stringio.py +++ b/secop/stringio.py @@ -73,12 +73,14 @@ class StringIO(Communicator): argument=ArrayOf(StringType()), result= ArrayOf(StringType())) } + _reconnectCallbacks = None + def earlyInit(self): self._stream = None self._lock = threading.RLock() self._end_of_line = self.end_of_line.encode(self.encoding) self._connect_error = None - self._last_error = 'not connected' + self._last_error = None def createConnection(self): """create connection @@ -111,7 +113,7 @@ class StringIO(Communicator): if timeout is None or timeout < 0: raise ValueError('illegal timeout %r' % timeout) if not self.is_connected: - raise CommunicationSilentError(self._last_error) + raise CommunicationSilentError(self._last_error or 'not connected') self._stream.settimeout(timeout) try: reply = self._stream.recv(4096) @@ -164,6 +166,7 @@ class StringIO(Communicator): if self._last_error: self.log.info('connected') self._last_error = 'connected' + self.callCallbacks() return Done except Exception as e: if str(e) == self._last_error: @@ -192,6 +195,26 @@ class StringIO(Communicator): raise CommunicationFailedError('bad response: %s does not match %s' % (reply, regexp)) + def registerReconnectCallback(self, name, func): + """register reconnect callback + + if the callback fails or returns False, it is cleared + """ + if self._reconnectCallbacks is None: + self._reconnectCallbacks = {name: func} + else: + self._reconnectCallbacks[name] = func + + def callCallbacks(self): + for key, cb in list(self._reconnectCallbacks.items()): + try: + removeme = not cb() + except Exception as e: + self.log.error('callback: %s' % e) + removeme = True + if removeme: + self._reconnectCallbacks.pop(key) + def do_communicate(self, command): '''send a command and receive a reply diff --git a/test/test_poller.py b/test/test_poller.py index 7b985b3..e23a0b7 100644 --- a/test/test_poller.py +++ b/test/test_poller.py @@ -73,7 +73,10 @@ class Event: artime.sleep(max(0,timeout)) def set(self): - self.flag=True + self.flag = True + + def clear(self): + self.flag = False def is_set(self): return self.flag @@ -198,7 +201,7 @@ def test_Poller(modules): assert len(pollTable) == 1 poller = pollTable[(Poller, 'common_iodev')] artime.stop = poller.stop - poller._stopped = Event() # patch Event.wait + poller._event = Event() # patch Event.wait assert (sum(count.values()) > 0) == bool(poller) @@ -233,6 +236,7 @@ def test_Poller(modules): for module in modules: for pobj in module.parameters.values(): if pobj.poll: + assert pobj.cnt > 0 assert pobj.maxspan <= maxspan[pobj.polltype] * 1.1 assert (pobj.cnt + 1) * pobj.interval >= total * 0.99 assert abs(pobj.span - pobj.interval) < 0.01