diff --git a/csaxs_bec/devices/npoint/npoint.py b/csaxs_bec/devices/npoint/npoint.py index 9a8c63e..58f9c50 100644 --- a/csaxs_bec/devices/npoint/npoint.py +++ b/csaxs_bec/devices/npoint/npoint.py @@ -1,17 +1,15 @@ import functools +import threading import time -from ophyd_devices.utils.controller import Controller, threadlocked -from ophyd_devices.utils.socket import raise_if_disconnected -from prettytable import PrettyTable -from typeguard import typechecked -from ophyd.utils import LimitError, ReadOnlyError -from ophyd import Device, PositionerBase, Signal, SignalRO -from ophyd_devices.utils.socket import SocketIO, SocketSignal -from ophyd import Component as Cpt import numpy as np -import threading +from ophyd import Component as Cpt +from ophyd import Device, PositionerBase, Signal, SignalRO from ophyd.status import wait as status_wait +from ophyd.utils import LimitError, ReadOnlyError +from ophyd_devices.utils.controller import Controller, threadlocked +from ophyd_devices.utils.socket import SocketIO, SocketSignal, raise_if_disconnected +from prettytable import PrettyTable def channel_checked(fcn): @@ -19,15 +17,24 @@ def channel_checked(fcn): @functools.wraps(fcn) def wrapper(self, *args, **kwargs): + # pylint: disable=protected-access self._check_channel(args[0]) return fcn(self, *args, **kwargs) return wrapper + class NpointError(Exception): - pass + """ + Base class for Npoint errors. + """ + class NPointController(Controller): + """ + Controller for nPoint piezo stages. This class inherits from the Controller class + and provides a singleton interface to the nPoint controller. + """ _axes_per_controller = 3 _read_single_loc_bit = "A0" @@ -49,9 +56,7 @@ class NPointController(Controller): t = PrettyTable() t.field_names = ["Channel", "Range", "Position", "Target"] for ii in range(self._axes_per_controller): - t.add_row( - [ii, self._get_range(ii), self._get_current_pos(ii), self._get_target_pos(ii)] - ) + t.add_row([ii, self._get_range(ii), self.get_current_pos(ii), self.get_target_pos(ii)]) print(t) @channel_checked @@ -82,7 +87,7 @@ class NPointController(Controller): return device_range @channel_checked - def _get_current_pos(self, channel: int) -> float: + def get_current_pos(self, channel: int) -> float: # for first channel: 0x11 83 13 34 addr = self._channel_base.copy() addr.extend([f"{19 + 16 * channel:x}", "34"]) @@ -95,7 +100,7 @@ class NPointController(Controller): return pos @channel_checked - def _set_target_pos(self, channel: int, pos: float) -> None: + def set_target_pos(self, channel: int, pos: float) -> None: # for first channel: 0x11 83 12 18 00 00 00 00 addr = self._channel_base.copy() addr.extend([f"{18 + channel * 16:x}", "18"]) @@ -107,7 +112,7 @@ class NPointController(Controller): self._put(send_buffer) @channel_checked - def _get_target_pos(self, channel: int) -> float: + def get_target_pos(self, channel: int) -> float: # for first channel: 0x11 83 12 18 addr = self._channel_base.copy() addr.extend([f"{18 + channel * 16:x}", "18"]) @@ -122,17 +127,17 @@ class NPointController(Controller): def _set_servo(self, channel: int, enable: bool) -> None: print("Not tested") return - # for first channel: 0x11 83 10 84 00 00 00 00 - addr = self._channel_base.copy() - addr.extend([f"{16 + channel * 16:x}", "84"]) + # # for first channel: 0x11 83 10 84 00 00 00 00 + # addr = self._channel_base.copy() + # addr.extend([f"{16 + channel * 16:x}", "84"]) - if enable: - data = ["00"] * 3 + ["01"] - else: - data = ["00"] * 4 - send_buffer = self.__write_single_location_buffer(addr, data) + # if enable: + # data = ["00"] * 3 + ["01"] + # else: + # data = ["00"] * 4 + # send_buffer = self.__write_single_location_buffer(addr, data) - self._put(send_buffer) + # self._put(send_buffer) @channel_checked def _get_servo(self, channel: int) -> int: @@ -299,16 +304,23 @@ class NPointController(Controller): self.off() - class NpointSignalBase(SocketSignal): + """ + Base class for nPoint signals. + """ + def __init__(self, signal_name, **kwargs): self.signal_name = signal_name super().__init__(**kwargs) - self.controller:NPointController = self.parent.controller + self.controller: NPointController = self.parent.controller self.sock = self.parent.controller.sock class NpointSignalRO(NpointSignalBase): + """ + Base class for read-only signals. + """ + def __init__(self, signal_name, **kwargs): super().__init__(signal_name, **kwargs) self._metadata["write_access"] = False @@ -319,33 +331,58 @@ class NpointSignalRO(NpointSignalBase): class NpointReadbackSignal(NpointSignalRO): + """ + Signal to read the current position of an nPoint piezo stage. + """ + @threadlocked def _socket_get(self): - return self.controller._get_current_pos(self.parent.axis_Id_numeric) * self.parent.sign + + return self.controller.get_current_pos(self.parent.axis_Id_numeric) * self.parent.sign class NpointSetpointSignal(NpointSignalBase): + """ + Signal to set the target position of an nPoint piezo stage. + """ + setpoint = 0 @threadlocked def _socket_get(self): - return self.controller._get_target_pos(self.parent.axis_Id_numeric) * self.parent.sign + return self.controller.get_target_pos(self.parent.axis_Id_numeric) * self.parent.sign @threadlocked def _socket_set(self, val): target_val = val * self.parent.sign self.setpoint = target_val - return self.controller._set_target_pos(self.parent.axis_Id_numeric, target_val * self.parent.sign) - + return self.controller.set_target_pos( + self.parent.axis_Id_numeric, target_val * self.parent.sign + ) class NpointMotorIsMoving(SignalRO): + """ + Signal to indicate whether the motor is currently moving or not. + """ - def set_motor_is_moving(self, value:int) -> None: + def set_motor_is_moving(self, value: int) -> None: + """ + Set the motor_is_moving signal to the specified value. + + Args: + value (int): 1 if the motor is moving, 0 otherwise. + """ self._readback = value class NPointAxis(Device, PositionerBase): + """ + NPointAxis class, which inherits from Device and PositionerBase. This class + represents an axis of an nPoint piezo stage and provides the necessary + functionality to move the axis and read its current position. + """ + USER_ACCESS = ["controller"] readback = Cpt(NpointReadbackSignal, signal_name="readback", kind="hinted") user_setpoint = Cpt(NpointSetpointSignal, signal_name="setpoint") @@ -374,7 +411,7 @@ class NPointAxis(Device, PositionerBase): limits=None, sign=1, socket_cls=SocketIO, - tolerance:float=0.05, + tolerance: float = 0.05, **kwargs, ): self.controller = NPointController( @@ -485,13 +522,22 @@ class NPointAxis(Device, PositionerBase): @property def axis_Id(self): + """ + Return the axis_Id_alpha. + """ return self._axis_Id_alpha @axis_Id.setter - def axis_Id(self, val): + def axis_Id(self, val: str): + """ + Set the axis_Id_alpha and axis_Id_numeric based on the alpha value. + + Args: + val (str): Single-character axis identifier. + """ if isinstance(val, str): if len(val) != 1: - raise ValueError(f"Only single-character axis_Ids are supported.") + raise ValueError("Only single-character axis_Ids are supported.") self._axis_Id_alpha = val self._axis_Id_numeric = ord(val.lower()) - 97 else: @@ -499,13 +545,22 @@ class NPointAxis(Device, PositionerBase): @property def axis_Id_numeric(self): + """ + Return the numeric value of the axis_Id. + """ return self._axis_Id_numeric @axis_Id_numeric.setter - def axis_Id_numeric(self, val): + def axis_Id_numeric(self, val: int): + """ + Set the axis_Id_numeric and axis_Id_alpha based on the numeric value. + + Args: + val (int): Numeric axis identifier. + """ if isinstance(val, int): if val > 26: - raise ValueError(f"Numeric value exceeds supported range.") + raise ValueError("Numeric value exceeds supported range.") self._axis_Id_alpha = val self._axis_Id_numeric = (chr(val + 97)).capitalize() else: @@ -531,4 +586,3 @@ if __name__ == "__main__": npx.move(10) print(npx.read()) npx.controller.off() - diff --git a/tests/tests_devices/test_npoint_piezo.py b/tests/tests_devices/test_npoint_piezo.py index d53f8f7..a8157d7 100644 --- a/tests/tests_devices/test_npoint_piezo.py +++ b/tests/tests_devices/test_npoint_piezo.py @@ -52,7 +52,7 @@ def test_axis_put(npointx, pos, msg): """ Test that the set target position sends the correct message to the controller. """ - npointx.controller._set_target_pos(npointx.axis_Id_numeric, pos) + npointx.controller.set_target_pos(npointx.axis_Id_numeric, pos) npointx.controller.sock.put.assert_called_with(msg) @@ -103,7 +103,7 @@ def test_axis_get_in(npointx, axis, msg_in, msg_out): controller's method. """ npointx.controller.sock.receive.side_effect = [msg_out] - npointx.controller._get_current_pos(axis) + npointx.controller.get_current_pos(axis) npointx.controller.sock.put.assert_called_once_with(msg_in) @@ -122,7 +122,7 @@ def test_get_axis_out_of_range(controller): Test that an error is raised when trying to get the current position of an invalid axis. """ with pytest.raises(ValueError): - controller._get_current_pos(3) + controller.get_current_pos(3) def test_set_axis_out_of_range(controller): @@ -130,7 +130,7 @@ def test_set_axis_out_of_range(controller): Test that an error is raised when trying to set the target position of an invalid axis. """ with pytest.raises(ValueError): - controller._set_target_pos(3, 5) + controller.set_target_pos(3, 5) @pytest.mark.parametrize(