From 981b87703884299bf193ad73bf65a0e716091cd3 Mon Sep 17 00:00:00 2001 From: wakonig_k Date: Wed, 8 Nov 2023 14:01:20 +0100 Subject: [PATCH] refactor: cleanup and unifying galil classes --- ophyd_devices/galil/fupr_ophyd.py | 24 +- ophyd_devices/galil/galil_ophyd.py | 45 +- ophyd_devices/npoint/npoint_ophyd.py | 428 -------------------- ophyd_devices/smaract/smaract_controller.py | 29 +- ophyd_devices/smaract/smaract_ophyd.py | 6 +- ophyd_devices/utils/controller.py | 72 +++- setup.py | 2 +- tests/test_controller.py | 30 ++ tests/test_galil.py | 128 ++++++ tests/test_smaract.py | 10 + 10 files changed, 242 insertions(+), 532 deletions(-) delete mode 100644 ophyd_devices/npoint/npoint_ophyd.py create mode 100644 tests/test_controller.py diff --git a/ophyd_devices/galil/fupr_ophyd.py b/ophyd_devices/galil/fupr_ophyd.py index 375aada..db5ed83 100644 --- a/ophyd_devices/galil/fupr_ophyd.py +++ b/ophyd_devices/galil/fupr_ophyd.py @@ -30,29 +30,7 @@ logger = bec_logger.logger class FuprGalilController(GalilController): - def __init__( - self, - *, - name="GalilController", - kind=None, - parent=None, - socket_cls=None, - socket_host=None, - socket_port=None, - attr_name="", - labels=None, - ): - super().__init__( - name=name, - kind=kind, - parent=parent, - socket_cls=socket_cls, - socket_host=socket_host, - socket_port=socket_port, - attr_name=attr_name, - labels=labels, - ) - self._galil_axis_per_controller = 1 + _axes_per_controller = 1 def is_axis_moving(self, axis_Id, axis_Id_numeric) -> bool: if axis_Id is None and axis_Id_numeric is not None: diff --git a/ophyd_devices/galil/galil_ophyd.py b/ophyd_devices/galil/galil_ophyd.py index 717871a..6be59de 100644 --- a/ophyd_devices/galil/galil_ophyd.py +++ b/ophyd_devices/galil/galil_ophyd.py @@ -44,6 +44,7 @@ def retry_once(fcn): class GalilController(Controller): + _axes_per_controller = 8 USER_ACCESS = [ "describe", "show_running_threads", @@ -53,42 +54,6 @@ class GalilController(Controller): "lgalil_is_air_off_and_orchestra_enabled", ] - def __init__( - self, - *, - name="GalilController", - kind=None, - parent=None, - socket_cls=None, - socket_host=None, - socket_port=None, - attr_name="", - labels=None, - ): - if not hasattr(self, "_initialized") or not self._initialized: - self._galil_axis_per_controller = 8 - self._axis = [None for axis_num in range(self._galil_axis_per_controller)] - super().__init__( - name=name, - socket_cls=socket_cls, - socket_host=socket_host, - socket_port=socket_port, - attr_name=attr_name, - parent=parent, - labels=labels, - kind=kind, - ) - - def set_axis(self, axis: Device, axis_nr: int) -> None: - """Assign an axis to a device instance. - - Args: - axis (Device): Device instance (e.g. GalilMotor) - axis_nr (int): Controller axis number - - """ - self._axis[axis_nr] = axis - @threadlocked def socket_put(self, val: str) -> None: self.sock.put(f"{val}\r".encode()) @@ -160,7 +125,7 @@ class GalilController(Controller): """ return bool(float(self.socket_put_and_receive("MG allaxref").strip())) - def drive_axis_to_limit(self, axis_Id_numeric, direction: str) -> None: + def drive_axis_to_limit(self, axis_Id_numeric: int, direction: str) -> None: """ Drive an axis to the limit in a specified direction. @@ -215,11 +180,11 @@ class GalilController(Controller): def show_running_threads(self) -> None: t = PrettyTable() t.title = f"Threads on {self.sock.host}:{self.sock.port}" - t.field_names = [str(ax) for ax in range(self._galil_axis_per_controller)] + t.field_names = [str(ax) for ax in range(self._axes_per_controller)] t.add_row( [ "active" if self.is_thread_active(t) else "inactive" - for t in range(self._galil_axis_per_controller) + for t in range(self._axes_per_controller) ] ) print(t) @@ -253,7 +218,7 @@ class GalilController(Controller): "Limits", "Position", ] - for ax in range(self._galil_axis_per_controller): + for ax in range(self._axes_per_controller): axis = self._axis[ax] if axis is not None: t.add_row( diff --git a/ophyd_devices/npoint/npoint_ophyd.py b/ophyd_devices/npoint/npoint_ophyd.py deleted file mode 100644 index 841efe0..0000000 --- a/ophyd_devices/npoint/npoint_ophyd.py +++ /dev/null @@ -1,428 +0,0 @@ -import abc -import functools -import socket -import threading -import time - -from ophyd import PositionerBase, Signal -from ophyd.device import Component as Cpt -from ophyd.device import Device -from prettytable import PrettyTable -from typeguard import typechecked - -from ophyd_devices.utils.controller import threadlocked -from ophyd_devices.utils.socket import raise_if_disconnected - - -def channel_checked(fcn): - """Decorator to catch attempted access to channels that are not available.""" - - @functools.wraps(fcn) - def wrapper(self, *args, **kwargs): - self._check_channel(args[0]) - return fcn(self, *args, **kwargs) - - return wrapper - - -class NPointController: - _controller_instance = None - - NUM_CHANNELS = 3 - _read_single_loc_bit = "A0" - _write_single_loc_bit = "A2" - _trailing_bit = "55" - _range_offset = "78" - _channel_base = ["11", "83"] - - def __init__( - self, - comm_socket: SocketIO, - server_ip: str = "129.129.99.87", - server_port: int = 23, - ) -> None: - self._lock = threading.RLock() - super().__init__() - self._server_and_port_name = (server_ip, server_port) - self.socket = comm_socket - self.connected = False - - def __new__(cls, *args, **kwargs): - if not NPointController._controller_instance: - NPointController._controller_instance = object.__new__(cls) - return NPointController._controller_instance - - @classmethod - def create(cls): - return cls(SocketIO()) - - def show_all(self) -> None: - """Display current status of all channels - - Returns: - None - """ - if not self.connected: - print("npoint controller is currently disabled.") - return - print(f"Connected to controller at {self._server_and_port_name}") - t = PrettyTable() - t.field_names = ["Channel", "Range", "Position", "Target"] - for ii in range(self.NUM_CHANNELS): - t.add_row( - [ - ii, - self._get_range(ii), - self._get_current_pos(ii), - self._get_target_pos(ii), - ] - ) - print(t) - - @threadlocked - def on(self) -> None: - """Enable the NPoint controller and open a new socket. - - Raises: - TimeoutError: Raised if the socket connection raises a timeout. - - Returns: - None - """ - if self.connected: - print("You are already connected to the NPoint controller.") - return - if not self.socket.is_open: - self.socket.open() - try: - self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1]) - except socket.timeout: - raise TimeoutError( - f"Failed to connect to the specified server and port {self._server_and_port_name}." - ) - except OSError: - print("ERROR while connecting. Let's try again") - self.socket.close() - time.sleep(0.5) - self.socket.open() - self.socket.connect(self._server_and_port_name[0], self._server_and_port_name[1]) - self.connected = True - - @threadlocked - def off(self) -> None: - """Disable the controller and close the socket. - - Returns: - None - """ - self.socket.close() - self.connected = False - - @channel_checked - def _get_range(self, channel: int) -> int: - """Get the range of the specified channel axis. - - Args: - channel (int): Channel for which the range should be requested. - - Raises: - RuntimeError: Raised if the received message doesn't have the expected number of bytes (10). - - Returns: - int: Range - """ - - # for first channel: 0x11 83 10 78 - addr = self._channel_base.copy() - addr.extend([f"{16 + 16 * channel:x}", self._range_offset]) - send_buffer = self.__read_single_location_buffer(addr) - - recvd = self._put_and_receive(send_buffer) - if len(recvd) != 10: - raise RuntimeError( - f"Received buffer is corrupted. Expected 10 bytes and instead got {len(recvd)}" - ) - device_range = self._hex_list_to_int(recvd[5:-1], signed=False) - return device_range - - @channel_checked - def _get_current_pos(self, channel: int) -> float: - # for first channel: 0x11 83 13 34 - addr = self._channel_base.copy() - addr.extend([f"{19 + 16 * channel:x}", "34"]) - send_buffer = self.__read_single_location_buffer(addr) - - recvd = self._put_and_receive(send_buffer) - - pos_buffer = recvd[5:-1] - pos = self._hex_list_to_int(pos_buffer) / 1048574 * 100 - return pos - - @channel_checked - def _set_target_pos(self, channel: int, pos: float) -> None: - # for first channel: 0x11 83 12 18 00 00 00 00 - addr = self._channel_base.copy() - addr.extend([f"{18 + channel * 16:x}", "18"]) - - target = int(round(1048574 / 100 * pos)) - data = [f"{m:02x}" for m in target.to_bytes(4, byteorder="big", signed=True)] - - send_buffer = self.__write_single_location_buffer(addr, data) - self._put(send_buffer) - - @channel_checked - def _get_target_pos(self, channel: int) -> float: - # for first channel: 0x11 83 12 18 - addr = self._channel_base.copy() - addr.extend([f"{18 + channel * 16:x}", "18"]) - send_buffer = self.__read_single_location_buffer(addr) - - recvd = self._put_and_receive(send_buffer) - pos_buffer = recvd[5:-1] - pos = self._hex_list_to_int(pos_buffer) / 1048574 * 100 - return pos - - @channel_checked - def _set_servo(self, channel: int, enable: bool) -> None: - print("Not tested") - return - # for first channel: 0x11 83 10 84 00 00 00 00 - addr = self._channel_base.copy() - addr.extend([f"{16 + channel * 16:x}", "84"]) - - if enable: - data = ["00"] * 3 + ["01"] - else: - data = ["00"] * 4 - send_buffer = self.__write_single_location_buffer(addr, data) - - self._put(send_buffer) - - @channel_checked - def _get_servo(self, channel: int) -> int: - # for first channel: 0x11 83 10 84 00 00 00 00 - addr = self._channel_base.copy() - addr.extend([f"{16 + channel * 16:x}", "84"]) - send_buffer = self.__read_single_location_buffer(addr) - - recvd = self._put_and_receive(send_buffer) - buffer = recvd[5:-1] - status = self._hex_list_to_int(buffer) - return status - - @threadlocked - def _put(self, buffer: list) -> None: - """Translates a list of hex values to bytes and sends them to the socket. - - Args: - buffer (list): List of hex values without leading 0x - - Returns: - None - """ - - buffer = b"".join([bytes.fromhex(m) for m in buffer]) - self.socket.put(buffer) - - @threadlocked - def _put_and_receive(self, msg_hex_list: list) -> list: - """Send msg to socket and wait for a reply. - - Args: - msg_hex_list (list): List of hex values without leading 0x. - - Returns: - list: Received message as a list of hex values - """ - - buffer = b"".join([bytes.fromhex(m) for m in msg_hex_list]) - self.socket.put(buffer) - recv_msg = self.socket.receive() - recv_hex_list = [hex(m) for m in recv_msg] - self._verify_received_msg(msg_hex_list, recv_hex_list) - return recv_hex_list - - def _verify_received_msg(self, in_list: list, out_list: list) -> None: - """Ensure that the first address bits of sent and received messages are the same. - - Args: - in_list (list): list containing the sent message - out_list (list): list containing the received message - - Raises: - RuntimeError: Raised if first two address bits of 'in' and 'out' are not identical - - Returns: - None - """ - - # first, translate hex (str) values to int - in_list_int = [int(val, 16) for val in in_list] - out_list_int = [int(val, 16) for val in out_list] - - # first ints of the reply should be the same. Otherwise something went wrong - if not in_list_int[:2] == out_list_int[:2]: - raise RuntimeError("Connection failure. Please restart the controller.") - - def _check_channel(self, channel: int) -> None: - if channel >= self.NUM_CHANNELS: - raise ValueError( - f"Channel {channel+1} exceeds the available number of channels ({self.NUM_CHANNELS})" - ) - - @staticmethod - def _hex_list_to_int(in_buffer: list, byteorder="little", signed=True) -> int: - """Translate hex list to int. - - Args: - in_buffer (list): Input buffer; received as list of hex values - byteorder (str, optional): Byteorder of in_buffer. Defaults to "little". - signed (bool, optional): Whether the hex list represents a signed int. Defaults to True. - - Returns: - int: Translated integer. - """ - if byteorder == "little": - in_buffer.reverse() - - # make sure that all hex strings have the same format ("FF") - val_hex = [f"{int(m, 16):02x}" for m in in_buffer] - - val_bytes = [bytes.fromhex(m) for m in val_hex] - val = int.from_bytes(b"".join(val_bytes), byteorder="big", signed=signed) - return val - - @staticmethod - def __read_single_location_buffer(addr) -> list: - """Prepare buffer for reading from a single memory location (hex address). - Number of bytes: 6 - Format: 0xA0 [addr] 0x55 - Return Value: 0xA0 [addr] [data] 0x55 - Sample Hex Transmission from PC to LC.400: A0 18 12 83 11 55 - Sample Hex Return Transmission from LC.400 to PC: A0 18 12 83 11 64 00 00 00 55 - - Args: - addr (list): Hex address to read from - - Returns: - list: List of hex values representing the read instruction. - """ - buffer = [] - buffer.append(NPointController._read_single_loc_bit) - if isinstance(addr, list): - addr.reverse() - buffer.extend(addr) - else: - buffer.append(addr) - buffer.append(NPointController._trailing_bit) - - return buffer - - @staticmethod - def __write_single_location_buffer(addr: list, data: list) -> list: - """Prepare buffer for writing to a single memory location (hex address). - Number of bytes: 10 - Format: 0xA2 [addr] [data] 0x55 - Return Value: none - Sample Hex Transmission from PC to C.400: A2 18 12 83 11 E8 03 00 00 55 - - Args: - addr (list): List of hex values representing the address to write to. - data (list): List of hex values representing the data that should be written. - - Returns: - list: List of hex values representing the write instruction. - """ - buffer = [] - buffer.append(NPointController._write_single_loc_bit) - if isinstance(addr, list): - addr.reverse() - buffer.extend(addr) - else: - buffer.append(addr) - - if isinstance(data, list): - data.reverse() - buffer.extend(data) - else: - buffer.append(data) - buffer.append(NPointController._trailing_bit) - return buffer - - @staticmethod - def __read_array(): - raise NotImplementedError - - @staticmethod - def __write_next_command(): - raise NotImplementedError - - def __del__(self): - if self.connected: - print("Closing npoint socket") - self.off() - - -class SocketSignal(abc.ABC, Signal): - def __init__(self, *, name, **kwargs): - super().__init__(**kwargs) - - @abc.abstractmethod - def _socket_get(self): - ... - - @abc.abstractmethod - def _socket_set(self, val): - ... - - -class NPointSignalBase(SocketSignal): - def __init__(self, controller, signal_name, **kwargs): - self.controller = controller - self.signal_name = signal_name - super().__init__(**kwargs) - - -class NPointReadbackSignal(NPointSignalBase): - def _socket_get(self): - pass - - def _socket_set(self, val): - pass - - -class NPointAxis(Device, PositionerBase): - def __init__( - self, - prefix="", - *, - name, - channel=None, - kind=None, - read_attrs=None, - configuration_attrs=None, - parent=None, - **kwargs, - ): - self.channel = channel - self.controller = self._get_controller() - - self.readback = Cpt( - NPointSignal, controller=self.controller, signal_name="RBV", kind="hinted" - ) - self.user_setpoint = Cpt( - NPointSignal, controller=self.controller, signal_name="VAL", kind="normal" - ) - - self.motor_resolution = Cpt( - NPointSignal, controller=self.controller, signal_name="RNGE", kind="config" - ) - self.motor_is_moving = Cpt( - NPointSignal, controller=self.controller, signal_name="MOVN", kind="config" - ) - self.axes_referenced = Cpt( - NPointSignal, controller=self.controller, signal_name="XREF", kind="config" - ) - - def _get_controller(self): - return NPointController() diff --git a/ophyd_devices/smaract/smaract_controller.py b/ophyd_devices/smaract/smaract_controller.py index 8763425..410b1fd 100644 --- a/ophyd_devices/smaract/smaract_controller.py +++ b/ophyd_devices/smaract/smaract_controller.py @@ -11,7 +11,7 @@ from ophyd_devices.smaract.smaract_errors import ( SmaractCommunicationError, SmaractErrorCode, ) -from ophyd_devices.utils.controller import Controller, threadlocked +from ophyd_devices.utils.controller import Controller, axis_checked, threadlocked logger = logging.getLogger("smaract_controller") @@ -21,17 +21,6 @@ class SmaractCommunicationMode(enum.Enum): ASYNC = 1 -def axis_checked(fcn): - """Decorator to catch attempted access to channels that are not available.""" - - @functools.wraps(fcn) - def wrapper(self, *args, **kwargs): - self._check_axis_number(args[0]) - return fcn(self, *args, **kwargs) - - return wrapper - - def retry_once(fcn): """Decorator to rerun a function in case a SmaractCommunicationError was raised. This may happen if the buffer was not empty.""" @@ -82,6 +71,8 @@ class SmaractSensors: class SmaractController(Controller): + _axes_per_controller = 6 + _initialized = False USER_ACCESS = ["socket_put_and_receive", "smaract_show_all", "move_open_loop_steps"] def __init__( @@ -96,9 +87,7 @@ class SmaractController(Controller): attr_name="", labels=None, ): - if not hasattr(self, "_initialized") or not self._initialized: - self._Smaract_axis_per_controller = 6 - self._axis = [None for axis_num in range(self._Smaract_axis_per_controller)] + if not self._initialized: super().__init__( name=name, socket_cls=socket_cls, @@ -111,10 +100,6 @@ class SmaractController(Controller): ) self._sensors = SmaractSensors() - @axis_checked - def set_axis(self, axis_nr, axis): - self._axis[axis_nr] = axis - @threadlocked def socket_put(self, val: str): self.sock.put(f":{val}\n".encode()) @@ -451,12 +436,6 @@ class SmaractController(Controller): t.add_row([None for t in t.field_names]) print(t) - def _check_axis_number(self, axis_Id_numeric: int) -> None: - if axis_Id_numeric >= self._Smaract_axis_per_controller: - raise ValueError( - f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._Smaract_axis_per_controller})" - ) - @axis_checked def _error_str(self, axis_Id_numeric: int, error_number: int): return f":E{axis_Id_numeric},{error_number}" diff --git a/ophyd_devices/smaract/smaract_ophyd.py b/ophyd_devices/smaract/smaract_ophyd.py index 6eff58b..af977ec 100644 --- a/ophyd_devices/smaract/smaract_ophyd.py +++ b/ophyd_devices/smaract/smaract_ophyd.py @@ -130,12 +130,12 @@ class SmaractMotor(Device, PositionerBase): socket_cls=SocketIO, **kwargs, ): - self.axis_Id = axis_Id - self.sign = sign self.controller = SmaractController( socket_cls=socket_cls, socket_host=host, socket_port=port ) - self.controller.set_axis(self.axis_Id_numeric, axis=self) + self.axis_Id = axis_Id + self.sign = sign + self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric) self.tolerance = kwargs.pop("tolerance", 0.5) super().__init__( diff --git a/ophyd_devices/utils/controller.py b/ophyd_devices/utils/controller.py index f7ce192..0373fdf 100644 --- a/ophyd_devices/utils/controller.py +++ b/ophyd_devices/utils/controller.py @@ -2,6 +2,7 @@ import functools import threading from bec_lib.core import bec_logger +from ophyd import Device from ophyd.ophydobj import OphydObject logger = bec_logger.logger @@ -19,8 +20,28 @@ def threadlocked(fcn): return wrapper +def axis_checked(fcn): + """Decorator to catch attempted access to channels that are not available.""" + + @functools.wraps(fcn) + def wrapper(self, *args, **kwargs): + if "axis_nr" in kwargs: + self._check_axis_number(kwargs["axis_nr"]) + elif "axis_Id_numeric" in kwargs: + self._check_axis_number(kwargs["axis_Id_numeric"]) + elif args: + self._check_axis_number(args[0]) + return fcn(self, *args, **kwargs) + + return wrapper + + class Controller(OphydObject): + """Base class for all socker-based controllers.""" + _controller_instances = {} + _initialized = False + _axes_per_controller = 1 SUB_CONNECTION_CHANGE = "connection_change" @@ -40,11 +61,12 @@ class Controller(OphydObject): self._socket_cls = socket_cls self._socket_host = socket_host self._socket_port = socket_port - if not hasattr(self, "_initialized"): + if not self._initialized: super().__init__( name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind ) self._lock = threading.RLock() + self._axis = [] self._initialize() self._initialized = True @@ -54,8 +76,12 @@ class Controller(OphydObject): def _set_default_values(self): # no. of axes controlled by each controller - self._axis_per_controller = 8 - self._motors = [None for axis_num in range(self._axis_per_controller)] + self._axis = [None for axis_num in range(self._axes_per_controller)] + + @classmethod + def _reset_controller(cls): + cls._controller_instances = {} + cls._initialized = False @property def connected(self): @@ -66,13 +92,35 @@ class Controller(OphydObject): self._connected = value self._run_subs(sub_type=self.SUB_CONNECTION_CHANGE) - def set_motor(self, motor, axis): - """Set the motor instance for a specified controller axis.""" - self._motors[axis] = motor + @axis_checked + def set_axis(self, *, axis: Device, axis_nr: int) -> None: + """Assign an axis to a device instance. - def get_motor(self, axis): - """Get motor instance for a specified controller axis.""" - return self._motors[axis] + Args: + axis (Device): Device instance (e.g. GalilMotor) + axis_nr (int): Controller axis number + + """ + self._axis[axis_nr] = axis + + @axis_checked + def get_axis(self, axis_nr: int) -> Device: + """Get device instance for a specified controller axis. + + Args: + axis_nr (int): Controller axis number + + Returns: + Device: Device instance (e.g. GalilMotor) + + """ + return self._axis[axis_nr] + + def _check_axis_number(self, axis_Id_numeric: int) -> None: + if axis_Id_numeric >= self._axes_per_controller: + raise ValueError( + f"Axis {axis_Id_numeric} exceeds the available number of axes ({self._axes_per_controller})" + ) def on(self) -> None: """Open a new socket connection to the controller""" @@ -103,6 +151,6 @@ class Controller(OphydObject): if not socket_port: raise RuntimeError("Socket port must be specified.") host_port = f"{socket_host}:{socket_port}" - if host_port not in Controller._controller_instances: - Controller._controller_instances[host_port] = object.__new__(cls) - return Controller._controller_instances[host_port] + if host_port not in cls._controller_instances: + cls._controller_instances[host_port] = object.__new__(cls) + return cls._controller_instances[host_port] diff --git a/setup.py b/setup.py index 50dde6e..1b7ebb8 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,6 @@ if __name__ == "__main__": "std_daq_client", "pyepics", ], - extras_require={"dev": ["pytest", "pytest-random-order", "black"]}, + extras_require={"dev": ["pytest", "pytest-random-order", "black", "coverage"]}, version=__version__, ) diff --git a/tests/test_controller.py b/tests/test_controller.py new file mode 100644 index 0000000..6d69c13 --- /dev/null +++ b/tests/test_controller.py @@ -0,0 +1,30 @@ +from unittest import mock + +from ophyd_devices.utils.controller import Controller + + +def test_controller_off(): + controller = Controller(socket_cls=mock.MagicMock(), socket_host="dummy", socket_port=123) + controller.on() + with mock.patch.object(controller.sock, "close") as mock_close: + controller.off() + assert controller.sock is None + assert controller.connected is False + mock_close.assert_called_once() + + # make sure it is indempotent + controller.off() + + +def test_controller_on(): + socket_cls = mock.MagicMock() + Controller._controller_instances = {} + controller = Controller(socket_cls=socket_cls, socket_host="dummy", socket_port=123) + controller.on() + assert controller.sock is not None + assert controller.connected is True + socket_cls().open.assert_called_once() + + # make sure it is indempotent + controller.on() + socket_cls().open.assert_called_once() diff --git a/tests/test_galil.py b/tests/test_galil.py index fdc6ef3..5628ef8 100644 --- a/tests/test_galil.py +++ b/tests/test_galil.py @@ -1,3 +1,5 @@ +from unittest import mock + import pytest from utils import SocketMock @@ -61,3 +63,129 @@ def test_axis_put(target_pos, socket_put_messages, socket_get_messages): leyey.controller.sock.buffer_recv = socket_get_messages leyey.user_setpoint.put(target_pos) assert leyey.controller.sock.buffer_put == socket_put_messages + + +@pytest.mark.parametrize( + "axis_nr,direction,socket_put_messages,socket_get_messages", + [ + ( + 0, + "forward", + [ + b"naxis=0\r", + b"ndir=1\r", + b"XQ#NEWPAR\r", + b"XQ#FES\r", + b"MG_BGA\r", + b"MGbcklact[axis]\r", + b"MG_XQ0\r", + b"MG_XQ2\r", + b"MG _LRA, _LFA\r", + ], + [ + b":", + b":", + b":", + b":", + b"0", + b"0", + b"-1", + b"-1", + b"1.000 0.000", + ], + ), + ( + 1, + "reverse", + [ + b"naxis=1\r", + b"ndir=-1\r", + b"XQ#NEWPAR\r", + b"XQ#FES\r", + b"MG_BGB\r", + b"MGbcklact[axis]\r", + b"MG_XQ0\r", + b"MG_XQ2\r", + b"MG _LRB, _LFB\r", + ], + [ + b":", + b":", + b":", + b":", + b"0", + b"0", + b"-1", + b"-1", + b"0.000 1.000", + ], + ), + ], +) +def test_drive_axis_to_limit(axis_nr, direction, socket_put_messages, socket_get_messages): + leyey = GalilMotor("A", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock) + leyey.controller.on() + leyey.controller.sock.flush_buffer() + leyey.controller.sock.buffer_recv = socket_get_messages + leyey.controller.drive_axis_to_limit(axis_nr, direction) + assert leyey.controller.sock.buffer_put == socket_put_messages + + +@pytest.mark.parametrize( + "axis_nr,socket_put_messages,socket_get_messages", + [ + ( + 0, + [ + b"naxis=0\r", + b"XQ#NEWPAR\r", + b"XQ#FRM\r", + b"MG_BGA\r", + b"MGbcklact[axis]\r", + b"MG_XQ0\r", + b"MG_XQ2\r", + b"MG axisref[0]\r", + ], + [ + b":", + b":", + b":", + b"0", + b"0", + b"-1", + b"-1", + b"1.00", + ], + ), + ( + 1, + [ + b"naxis=1\r", + b"XQ#NEWPAR\r", + b"XQ#FRM\r", + b"MG_BGB\r", + b"MGbcklact[axis]\r", + b"MG_XQ0\r", + b"MG_XQ2\r", + b"MG axisref[1]\r", + ], + [ + b":", + b":", + b":", + b"0", + b"0", + b"-1", + b"-1", + b"1.00", + ], + ), + ], +) +def test_find_reference(axis_nr, socket_put_messages, socket_get_messages): + leyey = GalilMotor("A", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock) + leyey.controller.on() + leyey.controller.sock.flush_buffer() + leyey.controller.sock.buffer_recv = socket_get_messages + leyey.controller.find_reference(axis_nr) + assert leyey.controller.sock.buffer_put == socket_put_messages diff --git a/tests/test_smaract.py b/tests/test_smaract.py index 2d13584..6c2d733 100644 --- a/tests/test_smaract.py +++ b/tests/test_smaract.py @@ -20,6 +20,7 @@ from ophyd_devices.smaract.smaract_ophyd import SmaractMotor ], ) def test_get_position(axis, position, get_message, return_msg): + SmaractController._reset_controller() controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) controller.on() controller.sock.flush_buffer() @@ -39,6 +40,7 @@ def test_get_position(axis, position, get_message, return_msg): ], ) def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, exception): + SmaractController._reset_controller() controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) controller.on() controller.sock.flush_buffer() @@ -62,6 +64,7 @@ def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, except ], ) def test_socket_put_and_receive_raises_exception(return_msg, exception, raised): + SmaractController._reset_controller() controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) controller.on() controller.sock.flush_buffer() @@ -87,6 +90,7 @@ def test_socket_put_and_receive_raises_exception(return_msg, exception, raised): ], ) def test_communication_mode(mode, get_message, return_msg): + SmaractController._reset_controller() controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) controller.on() controller.sock.flush_buffer() @@ -112,6 +116,7 @@ def test_communication_mode(mode, get_message, return_msg): ], ) def test_axis_is_moving(is_moving, get_message, return_msg): + SmaractController._reset_controller() controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) controller.on() controller.sock.flush_buffer() @@ -132,6 +137,7 @@ def test_axis_is_moving(is_moving, get_message, return_msg): ], ) def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg): + SmaractController._reset_controller() controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) controller.on() controller.sock.flush_buffer() @@ -149,6 +155,7 @@ def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg): ], ) def test_set_move_speed(move_speed, axis, get_msg, return_msg): + SmaractController._reset_controller() controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) controller.on() controller.sock.flush_buffer() @@ -166,6 +173,7 @@ def test_set_move_speed(move_speed, axis, get_msg, return_msg): ], ) def test_move_axis_to_absolute_position(pos, axis, hold_time, get_msg, return_msg): + SmaractController._reset_controller() controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123) controller.on() controller.sock.flush_buffer() @@ -203,6 +211,7 @@ def test_move_axis_to_absolute_position(pos, axis, hold_time, get_msg, return_ms ], ) def test_move_axis(pos, get_msg, return_msg): + SmaractController._reset_controller() lsmarA = SmaractMotor( "A", name="lsmarA", @@ -230,6 +239,7 @@ def test_move_axis(pos, get_msg, return_msg): ], ) def test_stop_axis(num_axes, get_msg, return_msg): + SmaractController._reset_controller() lsmarA = SmaractMotor( "A", name="lsmarA",