From fb9a17c5e383e2a378d0a3e9cc7cc185dd20c96e Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Wed, 8 Nov 2023 10:02:47 +0100 Subject: [PATCH] fix: changed dependency injection for controller classes; closes #13 --- ophyd_devices/galil/fgalil_ophyd.py | 50 ++++++++------- ophyd_devices/galil/fupr_ophyd.py | 63 +++++++++++-------- ophyd_devices/galil/galil_ophyd.py | 29 +++------ ophyd_devices/npoint/npoint.py | 16 ++++- ophyd_devices/npoint/npoint_ophyd.py | 22 +++++-- ophyd_devices/smaract/smaract_controller.py | 28 +++------ ophyd_devices/smaract/smaract_ophyd.py | 4 +- ophyd_devices/utils/controller.py | 70 ++++++++++----------- ophyd_devices/utils/re_test.py | 17 ----- tests/test_galil.py | 6 +- tests/test_smaract.py | 29 ++++++--- 11 files changed, 167 insertions(+), 167 deletions(-) delete mode 100644 ophyd_devices/utils/re_test.py diff --git a/ophyd_devices/galil/fgalil_ophyd.py b/ophyd_devices/galil/fgalil_ophyd.py index 9272d73..1965ea7 100644 --- a/ophyd_devices/galil/fgalil_ophyd.py +++ b/ophyd_devices/galil/fgalil_ophyd.py @@ -100,10 +100,12 @@ class FlomniGalilMotor(Device, PositionerBase): device_manager=None, **kwargs, ): + self.controller = FlomniGalilController( + socket_cls=socket_cls, socket_host=host, socket_port=port + ) self.axis_Id = axis_Id - self.sign = sign - self.controller = FlomniGalilController(socket=socket_cls(host=host, port=port)) self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric) + self.sign = sign self.tolerance = kwargs.pop("tolerance", 0.5) self.device_mapping = kwargs.pop("device_mapping", {}) self.device_manager = device_manager @@ -273,29 +275,29 @@ class FlomniGalilMotor(Device, PositionerBase): return super().stop(success=success) -if __name__ == "__main__": - mock = False - if not mock: - leyey = GalilMotor("H", name="leyey", host="mpc2680.psi.ch", port=8081, sign=-1) - leyey.stage() - status = leyey.move(0, wait=True) - status = leyey.move(10, wait=True) - leyey.read() +# if __name__ == "__main__": +# mock = False +# if not mock: +# leyey = GalilMotor("H", name="leyey", host="mpc2680.psi.ch", port=8081, sign=-1) +# leyey.stage() +# status = leyey.move(0, wait=True) +# status = leyey.move(10, wait=True) +# leyey.read() - leyey.get() - leyey.describe() +# leyey.get() +# leyey.describe() - leyey.unstage() - else: - from ophyd_devices.utils.socket import SocketMock +# leyey.unstage() +# else: +# from ophyd_devices.utils.socket import SocketMock - leyex = GalilMotor( - "G", name="leyex", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock - ) - leyey = GalilMotor( - "H", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock - ) - leyex.stage() - # leyey.stage() +# leyex = GalilMotor( +# "G", name="leyex", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock +# ) +# leyey = GalilMotor( +# "H", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock +# ) +# leyex.stage() +# # leyey.stage() - leyex.controller.galil_show_all() +# leyex.controller.galil_show_all() diff --git a/ophyd_devices/galil/fupr_ophyd.py b/ophyd_devices/galil/fupr_ophyd.py index 61a1b36..375aada 100644 --- a/ophyd_devices/galil/fupr_ophyd.py +++ b/ophyd_devices/galil/fupr_ophyd.py @@ -36,12 +36,21 @@ class FuprGalilController(GalilController): name="GalilController", kind=None, parent=None, - socket=None, + socket_cls=None, + socket_host=None, + socket_port=None, attr_name="", labels=None, ): super().__init__( - name=name, kind=kind, parent=parent, socket=socket, attr_name=attr_name, labels=labels + name=name, + kind=kind, + parent=parent, + socket_cls=socket_cls, + socket_host=socket_host, + socket_port=socket_port, + attr_name=attr_name, + labels=labels, ) self._galil_axis_per_controller = 1 @@ -157,10 +166,12 @@ class FuprGalilMotor(Device, PositionerBase): device_manager=None, **kwargs, ): + self.controller = FuprGalilController( + socket_cls=socket_cls, socket_host=host, socket_port=port + ) self.axis_Id = axis_Id - self.sign = sign - self.controller = FuprGalilController(socket=socket_cls(host=host, port=port)) self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric) + self.sign = sign self.tolerance = kwargs.pop("tolerance", 0.5) self.device_mapping = kwargs.pop("device_mapping", {}) self.device_manager = device_manager @@ -330,29 +341,29 @@ class FuprGalilMotor(Device, PositionerBase): return super().stop(success=success) -if __name__ == "__main__": - mock = False - if not mock: - leyey = GalilMotor("H", name="leyey", host="mpc2680.psi.ch", port=8081, sign=-1) - leyey.stage() - status = leyey.move(0, wait=True) - status = leyey.move(10, wait=True) - leyey.read() +# if __name__ == "__main__": +# mock = False +# if not mock: +# leyey = GalilMotor("H", name="leyey", host="mpc2680.psi.ch", port=8081, sign=-1) +# leyey.stage() +# status = leyey.move(0, wait=True) +# status = leyey.move(10, wait=True) +# leyey.read() - leyey.get() - leyey.describe() +# leyey.get() +# leyey.describe() - leyey.unstage() - else: - from ophyd_devices.utils.socket import SocketMock +# leyey.unstage() +# else: +# from ophyd_devices.utils.socket import SocketMock - leyex = GalilMotor( - "G", name="leyex", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock - ) - leyey = GalilMotor( - "H", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock - ) - leyex.stage() - # leyey.stage() +# leyex = GalilMotor( +# "G", name="leyex", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock +# ) +# leyey = GalilMotor( +# "H", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock +# ) +# leyex.stage() +# # leyey.stage() - leyex.controller.galil_show_all() +# leyex.controller.galil_show_all() diff --git a/ophyd_devices/galil/galil_ophyd.py b/ophyd_devices/galil/galil_ophyd.py index bebe61e..717871a 100644 --- a/ophyd_devices/galil/galil_ophyd.py +++ b/ophyd_devices/galil/galil_ophyd.py @@ -59,7 +59,9 @@ class GalilController(Controller): name="GalilController", kind=None, parent=None, - socket=None, + socket_cls=None, + socket_host=None, + socket_port=None, attr_name="", labels=None, ): @@ -68,30 +70,15 @@ class GalilController(Controller): self._axis = [None for axis_num in range(self._galil_axis_per_controller)] super().__init__( name=name, - socket=socket, + socket_cls=socket_cls, + socket_host=socket_host, + socket_port=socket_port, attr_name=attr_name, parent=parent, labels=labels, kind=kind, ) - def on(self, controller_num=0) -> None: - """Open a new socket connection to the controller""" - if not self.connected: - self.sock.open() - self.connected = True - else: - logger.info("The connection has already been established.") - # warnings.warn(f"The connection has already been established.", stacklevel=2) - - def off(self) -> None: - """Close the socket connection to the controller""" - if self.connected: - self.sock.close() - self.connected = False - else: - logger.info("The connection is already closed.") - def set_axis(self, axis: Device, axis_nr: int) -> None: """Assign an axis to a device instance. @@ -462,10 +449,10 @@ class GalilMotor(Device, PositionerBase): device_manager=None, **kwargs, ): + self.controller = GalilController(socket_cls=socket_cls, socket_host=host, socket_port=port) self.axis_Id = axis_Id - self.sign = sign - self.controller = GalilController(socket=socket_cls(host=host, port=port)) self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric) + self.sign = sign self.tolerance = kwargs.pop("tolerance", 0.5) self.device_mapping = kwargs.pop("device_mapping", {}) self.device_manager = device_manager diff --git a/ophyd_devices/npoint/npoint.py b/ophyd_devices/npoint/npoint.py index 5121b92..255b021 100644 --- a/ophyd_devices/npoint/npoint.py +++ b/ophyd_devices/npoint/npoint.py @@ -1,12 +1,14 @@ import functools import socket +import threading import time -from ophyd_devices.utils.controller import SingletonController, threadlocked -from ophyd_devices.utils.socket import raise_if_disconnected from prettytable import PrettyTable from typeguard import typechecked +from ophyd_devices.utils.controller import threadlocked +from ophyd_devices.utils.socket import raise_if_disconnected + def channel_checked(fcn): """Decorator to catch attempted access to channels that are not available.""" @@ -60,7 +62,9 @@ class SocketIO: self.is_open = False -class NPointController(SingletonController): +class NPointController: + _controller_instance = None + NUM_CHANNELS = 3 _read_single_loc_bit = "A0" _write_single_loc_bit = "A2" @@ -74,11 +78,17 @@ class NPointController(SingletonController): server_ip: str = "129.129.99.87", server_port: int = 23, ) -> None: + self._lock = threading.RLock() super().__init__() self._server_and_port_name = (server_ip, server_port) self.socket = comm_socket self.connected = False + def __new__(cls, *args, **kwargs): + if not NPointController._controller_instance: + NPointController._controller_instance = object.__new__(cls) + return NPointController._controller_instance + @classmethod def create(cls): return cls(SocketIO()) diff --git a/ophyd_devices/npoint/npoint_ophyd.py b/ophyd_devices/npoint/npoint_ophyd.py index 8bd8353..841efe0 100644 --- a/ophyd_devices/npoint/npoint_ophyd.py +++ b/ophyd_devices/npoint/npoint_ophyd.py @@ -1,15 +1,17 @@ import abc -import socket import functools -import time +import socket import threading +import time -from typeguard import typechecked from ophyd import PositionerBase, Signal -from ophyd.device import Device, Component as Cpt +from ophyd.device import Component as Cpt +from ophyd.device import Device from prettytable import PrettyTable +from typeguard import typechecked + +from ophyd_devices.utils.controller import threadlocked from ophyd_devices.utils.socket import raise_if_disconnected -from ophyd_devices.utils.controller import SingletonController, threadlocked def channel_checked(fcn): @@ -23,7 +25,9 @@ def channel_checked(fcn): return wrapper -class NPointController(SingletonController): +class NPointController: + _controller_instance = None + NUM_CHANNELS = 3 _read_single_loc_bit = "A0" _write_single_loc_bit = "A2" @@ -37,11 +41,17 @@ class NPointController(SingletonController): server_ip: str = "129.129.99.87", server_port: int = 23, ) -> None: + self._lock = threading.RLock() super().__init__() self._server_and_port_name = (server_ip, server_port) self.socket = comm_socket self.connected = False + def __new__(cls, *args, **kwargs): + if not NPointController._controller_instance: + NPointController._controller_instance = object.__new__(cls) + return NPointController._controller_instance + @classmethod def create(cls): return cls(SocketIO()) diff --git a/ophyd_devices/smaract/smaract_controller.py b/ophyd_devices/smaract/smaract_controller.py index 0265003..8763425 100644 --- a/ophyd_devices/smaract/smaract_controller.py +++ b/ophyd_devices/smaract/smaract_controller.py @@ -5,12 +5,13 @@ import logging import os import numpy as np +from typeguard import typechecked + from ophyd_devices.smaract.smaract_errors import ( SmaractCommunicationError, SmaractErrorCode, ) from ophyd_devices.utils.controller import Controller, threadlocked -from typeguard import typechecked logger = logging.getLogger("smaract_controller") @@ -89,7 +90,9 @@ class SmaractController(Controller): name="SmaractController", kind=None, parent=None, - socket=None, + socket_cls=None, + socket_host=None, + socket_port=None, attr_name="", labels=None, ): @@ -98,7 +101,9 @@ class SmaractController(Controller): self._axis = [None for axis_num in range(self._Smaract_axis_per_controller)] super().__init__( name=name, - socket=socket, + socket_cls=socket_cls, + socket_host=socket_host, + socket_port=socket_port, attr_name=attr_name, parent=parent, labels=labels, @@ -106,23 +111,6 @@ class SmaractController(Controller): ) self._sensors = SmaractSensors() - def on(self, controller_num=0): - """Open a new socket connection to the controller""" - if not self.connected: - self.sock.open() - self.connected = True - else: - logger.info("The connection has already been established.") - # warnings.warn(f"The connection has already been established.", stacklevel=2) - - def off(self): - """Close the socket connection to the controller""" - if self.connected: - self.sock.close() - self.connected = False - else: - logger.info("The connection is already closed.") - @axis_checked def set_axis(self, axis_nr, axis): self._axis[axis_nr] = axis diff --git a/ophyd_devices/smaract/smaract_ophyd.py b/ophyd_devices/smaract/smaract_ophyd.py index 638cced..6eff58b 100644 --- a/ophyd_devices/smaract/smaract_ophyd.py +++ b/ophyd_devices/smaract/smaract_ophyd.py @@ -132,7 +132,9 @@ class SmaractMotor(Device, PositionerBase): ): self.axis_Id = axis_Id self.sign = sign - self.controller = SmaractController(socket=socket_cls(host=host, port=port)) + self.controller = SmaractController( + socket_cls=socket_cls, socket_host=host, socket_port=port + ) self.controller.set_axis(self.axis_Id_numeric, axis=self) self.tolerance = kwargs.pop("tolerance", 0.5) diff --git a/ophyd_devices/utils/controller.py b/ophyd_devices/utils/controller.py index 6c0785b..f7ce192 100644 --- a/ophyd_devices/utils/controller.py +++ b/ophyd_devices/utils/controller.py @@ -1,9 +1,10 @@ import functools import threading -import warnings +from bec_lib.core import bec_logger from ophyd.ophydobj import OphydObject -from ophyd_devices.utils.socket import SocketIO + +logger = bec_logger.logger def threadlocked(fcn): @@ -18,24 +19,6 @@ def threadlocked(fcn): return wrapper -class SingletonController: - _controller_instance = None - - def __init__(self) -> None: - self._lock = threading.RLock() - - def on(self): - pass - - def off(self): - pass - - def __new__(cls, *args, **kwargs): - if not SingletonController._controller_instance: - SingletonController._controller_instance = object.__new__(cls) - return SingletonController._controller_instance - - class Controller(OphydObject): _controller_instances = {} @@ -45,24 +28,29 @@ class Controller(OphydObject): self, *, name=None, - socket=None, + socket_cls=None, + socket_host=None, + socket_port=None, attr_name="", parent=None, labels=None, kind=None, ): + self.sock = None + self._socket_cls = socket_cls + self._socket_host = socket_host + self._socket_port = socket_port if not hasattr(self, "_initialized"): super().__init__( name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind ) self._lock = threading.RLock() - self._initialize(socket) + self._initialize() self._initialized = True - def _initialize(self, socket): + def _initialize(self): self._connected = False self._set_default_values() - self.sock = socket if socket is not None else SocketIO() def _set_default_values(self): # no. of axes controlled by each controller @@ -86,25 +74,35 @@ class Controller(OphydObject): """Get motor instance for a specified controller axis.""" return self._motors[axis] - def on(self, controller_num=0): + def on(self) -> None: """Open a new socket connection to the controller""" - if not self.connected: + if not self.connected or self.sock is None: + self.sock = self._socket_cls(host=self._socket_host, port=self._socket_port) + self.sock.open() self.connected = True else: - warnings.warn(f"The connection has already been established.", stacklevel=2) + logger.info("The connection has already been established.") - def off(self): + def off(self) -> None: """Close the socket connection to the controller""" - self.sock.close() - self.connected = False + if self.connected or self.sock is not None: + self.sock.close() + self.connected = False + self.sock = None + else: + logger.info("The connection is already closed.") def __new__(cls, *args, **kwargs): - socket = kwargs.get("socket") - if not hasattr(socket, "host"): - raise RuntimeError("Socket must specify a host.") - if not hasattr(socket, "port"): - raise RuntimeError("Socket must specify a port.") - host_port = f"{socket.host}:{socket.port}" + socket_cls = kwargs.get("socket_cls") + socket_host = kwargs.get("socket_host") + socket_port = kwargs.get("socket_port") + if not socket_cls: + raise RuntimeError("Socket class must be specified.") + if not socket_host: + raise RuntimeError("Socket host must be specified.") + if not socket_port: + raise RuntimeError("Socket port must be specified.") + host_port = f"{socket_host}:{socket_port}" if host_port not in Controller._controller_instances: Controller._controller_instances[host_port] = object.__new__(cls) return Controller._controller_instances[host_port] diff --git a/ophyd_devices/utils/re_test.py b/ophyd_devices/utils/re_test.py deleted file mode 100644 index 1791e4b..0000000 --- a/ophyd_devices/utils/re_test.py +++ /dev/null @@ -1,17 +0,0 @@ -# from bluesky import RunEngine -# from bluesky.plans import grid_scan -# from bluesky.callbacks.best_effort import BestEffortCallback -# from bluesky.callbacks.mpl_plotting import LivePlot - - -# RE = RunEngine({}) - -# from bluesky.callbacks.best_effort import BestEffortCallback - -# bec = BestEffortCallback() - -# # Send all metadata/data captured to the BestEffortCallback. -# RE.subscribe(bec) -# # RE.subscribe(dummy) - -# # RE(grid_scan(dets, motor1, -10, 10, 10, motor2, -10, 10, 10)) diff --git a/tests/test_galil.py b/tests/test_galil.py index 4d60a19..06e8fb2 100644 --- a/tests/test_galil.py +++ b/tests/test_galil.py @@ -1,8 +1,8 @@ import pytest -from ophyd_devices.galil.galil_ophyd import GalilMotor - from utils import SocketMock +from ophyd_devices.galil.galil_ophyd import GalilMotor + @pytest.mark.parametrize( "pos,msg,sign", @@ -20,8 +20,8 @@ def test_axis_get(pos, msg, sign): sign=sign, socket_cls=SocketMock, ) - leyey.controller.sock.flush_buffer() leyey.controller.on() + leyey.controller.sock.flush_buffer() leyey.controller.sock.buffer_recv = msg val = leyey.read() assert val["leyey"]["value"] == pos diff --git a/tests/test_smaract.py b/tests/test_smaract.py index 6ad5588..2d13584 100644 --- a/tests/test_smaract.py +++ b/tests/test_smaract.py @@ -1,4 +1,6 @@ import pytest +from utils import SocketMock + from ophyd_devices.smaract import SmaractController from ophyd_devices.smaract.smaract_controller import SmaractCommunicationMode from ophyd_devices.smaract.smaract_errors import ( @@ -7,8 +9,6 @@ from ophyd_devices.smaract.smaract_errors import ( ) from ophyd_devices.smaract.smaract_ophyd import SmaractMotor -from utils import SocketMock - @pytest.mark.parametrize( "axis,position,get_message,return_msg", @@ -20,7 +20,8 @@ from utils import SocketMock ], ) def test_get_position(axis, position, get_message, return_msg): - controller = SmaractController(socket=SocketMock(host="dummy", port=123)) + controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) + controller.on() controller.sock.flush_buffer() controller.sock.buffer_recv = return_msg val = controller.get_position(axis) @@ -38,7 +39,8 @@ def test_get_position(axis, position, get_message, return_msg): ], ) def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, exception): - controller = SmaractController(socket=SocketMock(host="dummy", port=123)) + controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) + controller.on() controller.sock.flush_buffer() controller.sock.buffer_recv = return_msg if exception is not None: @@ -60,7 +62,8 @@ def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, except ], ) def test_socket_put_and_receive_raises_exception(return_msg, exception, raised): - controller = SmaractController(socket=SocketMock(host="dummy", port=123)) + controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) + controller.on() controller.sock.flush_buffer() controller.sock.buffer_recv = return_msg with pytest.raises(exception): @@ -84,7 +87,8 @@ def test_socket_put_and_receive_raises_exception(return_msg, exception, raised): ], ) def test_communication_mode(mode, get_message, return_msg): - controller = SmaractController(socket=SocketMock(host="dummy", port=123)) + controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) + controller.on() controller.sock.flush_buffer() controller.sock.buffer_recv = return_msg val = controller.get_communication_mode() @@ -108,7 +112,8 @@ def test_communication_mode(mode, get_message, return_msg): ], ) def test_axis_is_moving(is_moving, get_message, return_msg): - controller = SmaractController(socket=SocketMock(host="dummy", port=123)) + controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) + controller.on() controller.sock.flush_buffer() controller.sock.buffer_recv = return_msg val = controller.is_axis_moving(0) @@ -127,7 +132,8 @@ def test_axis_is_moving(is_moving, get_message, return_msg): ], ) def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg): - controller = SmaractController(socket=SocketMock(host="dummy", port=123)) + controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) + controller.on() controller.sock.flush_buffer() controller.sock.buffer_recv = return_msg sensor = controller.get_sensor_type(axis) @@ -143,7 +149,8 @@ def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg): ], ) def test_set_move_speed(move_speed, axis, get_msg, return_msg): - controller = SmaractController(socket=SocketMock(host="dummy", port=123)) + controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) + controller.on() controller.sock.flush_buffer() controller.sock.buffer_recv = return_msg controller.set_closed_loop_move_speed(axis, move_speed) @@ -159,7 +166,8 @@ def test_set_move_speed(move_speed, axis, get_msg, return_msg): ], ) def test_move_axis_to_absolute_position(pos, axis, hold_time, get_msg, return_msg): - controller = SmaractController(socket=SocketMock(host="dummy", port=123)) + controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) + controller.on() controller.sock.flush_buffer() controller.sock.buffer_recv = return_msg if hold_time is not None: @@ -232,6 +240,7 @@ def test_stop_axis(num_axes, get_msg, return_msg): ) lsmarA.stage() controller = lsmarA.controller + controller.on() controller.sock.flush_buffer() controller.sock.buffer_recv = return_msg controller.stop_all_axes()