diff --git a/ophyd_devices/utils/socket.py b/ophyd_devices/utils/socket.py index ad23c01..b984907 100644 --- a/ophyd_devices/utils/socket.py +++ b/ophyd_devices/utils/socket.py @@ -82,7 +82,35 @@ def data_type(val): class SocketSignal(abc.ABC, Signal): + """ + Base class for signals that interact with a socket connection. Subclasses + must implement the '_socket_get'and 'socket_set' methods to define how to + read from and write to the socket respectively. The signal also implements + caching of the last read values and a timeout mechanism at 10Hz to avoid + excessive socket reads. The 'get' method implements this caching and timeout + logic, while the 'put' method handles writing to the socket. Both implement + the necessary subscription notifications ('value' for get, 'setpoint' for put) + for value changes. Please note children should only overwrite these methods + if necessary and with care, as they handle caching and subscription notifications. + + Args: + name (str): The name of the signal. + notify_bec (bool): Whether to notify the BEC (Bluesky Event Collector) of value changes. + readback_timeout (float): Time in seconds to wait between socket read attempts before + returning cached value. + """ + SUB_SETPOINT = "setpoint" + SUB_VALUE = "value" + READBACK_TIMEOUT = 0.1 # time to wait in between two readback attemps in seconds, otherwise return cached value + + def __init__( + self, name: str, notify_bec: bool = True, readback_timeout: float = None, **kwargs + ): + super().__init__(name=name, **kwargs) + self.notify_bec = notify_bec + self._readback_timeout = readback_timeout or self.READBACK_TIMEOUT + self._last_readback = 0 @abc.abstractmethod def _socket_get(self): ... @@ -91,7 +119,17 @@ class SocketSignal(abc.ABC, Signal): def _socket_set(self, val): ... def get(self): - self._readback = self._socket_get() + current_time = time.monotonic() + if current_time - self._last_readback > self._readback_timeout: + old_value = self._readback + self._readback = self._socket_get() + self._last_readback = current_time + self._run_subs( + sub_type=self.SUB_VALUE, + old_value=old_value, + value=self._readback, + timestamp=current_time, + ) return self._readback def put( diff --git a/tests/test_socket.py b/tests/test_socket.py index be40a91..05007a1 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,9 +1,51 @@ import socket +import time from unittest import mock import pytest +from bec_server.device_server.tests.utils import DMMock -from ophyd_devices.utils.socket import SocketIO +from ophyd_devices.tests.utils import SocketMock +from ophyd_devices.utils.controller import Controller +from ophyd_devices.utils.socket import SocketIO, SocketSignal + + +class DummySocketSignal(SocketSignal): + """Dummy SocketSignal class for testing the SocketSignal interface.""" + + def __init__( + self, name, controller: Controller, notify_bec=True, readback_timeout=None, **kwargs + ): + super().__init__( + name=name, notify_bec=notify_bec, readback_timeout=readback_timeout, **kwargs + ) + self.controller = controller + + def _socket_get(self) -> str: + self._metadata["timestamp"] = time.monotonic() + return self.controller.socket_put_and_receive("get") + + def _socket_set(self, value: str): + self.controller.socket_put_and_receive(value) + + +@pytest.fixture +def controller(): + dm = DMMock() + controller = Controller( + name="controller", + socket_cls=SocketMock, + socket_host="localhost", + socket_port=8080, + device_manager=dm, + ) + controller.on() + return controller + + +@pytest.fixture +def signal(controller): + return DummySocketSignal(name="signal", controller=controller, readback_timeout=0.1) class DummySocket: @@ -85,3 +127,63 @@ def test_close(): socketio.close() assert socketio.sock == None assert socketio.is_open == False + + +def test_socket_signal_get(signal): + """ + Test that the get method of the SocketSignal class correctly retrives values from the socket, + and that it implements the caching and timeout mechanism to avoid excessive socket reads and/or recursions. + """ + # First get should call the socket and cache the value + controller = signal.controller + + controller.sock: SocketMock + controller.sock.buffer_recv = [b"value2", b"value1"] + signal._readback_timeout = 0 + readback = signal.read() + assert readback[signal.name]["value"] == "value2" + readback2 = signal.read() + assert readback2[signal.name]["value"] == "value1" + assert readback[signal.name]["timestamp"] != readback2[signal.name]["timestamp"] + controller.sock.buffer_recv = [b"value2"] + + cb_bucket = [] + read_value = None + signal._readback_timeout = 10 + + def _test_cb(value, old_value, **kwargs): + cb_bucket.append((value, old_value)) + read_value = signal.read() + + signal.subscribe(_test_cb, event_type=signal.SUB_VALUE, run=False) + signal._readback_timeout = 10 + signal._last_readback = 0 # reset the last readback time to force a socket read + readback1 = signal.read() + assert readback1[signal.name]["value"] == "value2" + # The value should be cached, so it should not change + assert cb_bucket == [("value2", "value1")] + readback2 = signal.read() + for entry in ("value", "timestamp"): + assert readback1[signal.name][entry] == readback2[signal.name][entry] + + +def test_socket_signal_put(signal): + """ + Test that the put method of the SocketSignal class correctly sends values to the socket, + and that it implements the necessary subscription notifications for value changes. + """ + controller = signal.controller + controller.sock: SocketMock + cb_bucket = [] + initial_value = signal._readback + + def _test_cb(value, old_value, **kwargs): + cb_bucket.append((value, old_value)) + + signal.subscribe(_test_cb, event_type=signal.SUB_SETPOINT, run=False) + signal.put("new_value") + assert controller.sock.buffer_put == [b"new_value\n"] + assert cb_bucket == [("new_value", initial_value)] + signal.put("another_value") + assert controller.sock.buffer_put == [b"new_value\n", b"another_value\n"] + assert cb_bucket == [("new_value", initial_value), ("another_value", "new_value")]