refactor(npoint): cleanup

This commit is contained in:
2024-10-02 18:27:02 +02:00
parent 8c2d705a89
commit baafa982e3
2 changed files with 96 additions and 42 deletions

View File

@@ -1,17 +1,15 @@
import functools
import threading
import time
from ophyd_devices.utils.controller import Controller, threadlocked
from ophyd_devices.utils.socket import raise_if_disconnected
from prettytable import PrettyTable
from typeguard import typechecked
from ophyd.utils import LimitError, ReadOnlyError
from ophyd import Device, PositionerBase, Signal, SignalRO
from ophyd_devices.utils.socket import SocketIO, SocketSignal
from ophyd import Component as Cpt
import numpy as np
import threading
from ophyd import Component as Cpt
from ophyd import Device, PositionerBase, Signal, SignalRO
from ophyd.status import wait as status_wait
from ophyd.utils import LimitError, ReadOnlyError
from ophyd_devices.utils.controller import Controller, threadlocked
from ophyd_devices.utils.socket import SocketIO, SocketSignal, raise_if_disconnected
from prettytable import PrettyTable
def channel_checked(fcn):
@@ -19,15 +17,24 @@ def channel_checked(fcn):
@functools.wraps(fcn)
def wrapper(self, *args, **kwargs):
# pylint: disable=protected-access
self._check_channel(args[0])
return fcn(self, *args, **kwargs)
return wrapper
class NpointError(Exception):
pass
"""
Base class for Npoint errors.
"""
class NPointController(Controller):
"""
Controller for nPoint piezo stages. This class inherits from the Controller class
and provides a singleton interface to the nPoint controller.
"""
_axes_per_controller = 3
_read_single_loc_bit = "A0"
@@ -49,9 +56,7 @@ class NPointController(Controller):
t = PrettyTable()
t.field_names = ["Channel", "Range", "Position", "Target"]
for ii in range(self._axes_per_controller):
t.add_row(
[ii, self._get_range(ii), self._get_current_pos(ii), self._get_target_pos(ii)]
)
t.add_row([ii, self._get_range(ii), self.get_current_pos(ii), self.get_target_pos(ii)])
print(t)
@channel_checked
@@ -82,7 +87,7 @@ class NPointController(Controller):
return device_range
@channel_checked
def _get_current_pos(self, channel: int) -> float:
def get_current_pos(self, channel: int) -> float:
# for first channel: 0x11 83 13 34
addr = self._channel_base.copy()
addr.extend([f"{19 + 16 * channel:x}", "34"])
@@ -95,7 +100,7 @@ class NPointController(Controller):
return pos
@channel_checked
def _set_target_pos(self, channel: int, pos: float) -> None:
def set_target_pos(self, channel: int, pos: float) -> None:
# for first channel: 0x11 83 12 18 00 00 00 00
addr = self._channel_base.copy()
addr.extend([f"{18 + channel * 16:x}", "18"])
@@ -107,7 +112,7 @@ class NPointController(Controller):
self._put(send_buffer)
@channel_checked
def _get_target_pos(self, channel: int) -> float:
def get_target_pos(self, channel: int) -> float:
# for first channel: 0x11 83 12 18
addr = self._channel_base.copy()
addr.extend([f"{18 + channel * 16:x}", "18"])
@@ -122,17 +127,17 @@ class NPointController(Controller):
def _set_servo(self, channel: int, enable: bool) -> None:
print("Not tested")
return
# for first channel: 0x11 83 10 84 00 00 00 00
addr = self._channel_base.copy()
addr.extend([f"{16 + channel * 16:x}", "84"])
# # for first channel: 0x11 83 10 84 00 00 00 00
# addr = self._channel_base.copy()
# addr.extend([f"{16 + channel * 16:x}", "84"])
if enable:
data = ["00"] * 3 + ["01"]
else:
data = ["00"] * 4
send_buffer = self.__write_single_location_buffer(addr, data)
# if enable:
# data = ["00"] * 3 + ["01"]
# else:
# data = ["00"] * 4
# send_buffer = self.__write_single_location_buffer(addr, data)
self._put(send_buffer)
# self._put(send_buffer)
@channel_checked
def _get_servo(self, channel: int) -> int:
@@ -299,16 +304,23 @@ class NPointController(Controller):
self.off()
class NpointSignalBase(SocketSignal):
"""
Base class for nPoint signals.
"""
def __init__(self, signal_name, **kwargs):
self.signal_name = signal_name
super().__init__(**kwargs)
self.controller:NPointController = self.parent.controller
self.controller: NPointController = self.parent.controller
self.sock = self.parent.controller.sock
class NpointSignalRO(NpointSignalBase):
"""
Base class for read-only signals.
"""
def __init__(self, signal_name, **kwargs):
super().__init__(signal_name, **kwargs)
self._metadata["write_access"] = False
@@ -319,33 +331,58 @@ class NpointSignalRO(NpointSignalBase):
class NpointReadbackSignal(NpointSignalRO):
"""
Signal to read the current position of an nPoint piezo stage.
"""
@threadlocked
def _socket_get(self):
return self.controller._get_current_pos(self.parent.axis_Id_numeric) * self.parent.sign
return self.controller.get_current_pos(self.parent.axis_Id_numeric) * self.parent.sign
class NpointSetpointSignal(NpointSignalBase):
"""
Signal to set the target position of an nPoint piezo stage.
"""
setpoint = 0
@threadlocked
def _socket_get(self):
return self.controller._get_target_pos(self.parent.axis_Id_numeric) * self.parent.sign
return self.controller.get_target_pos(self.parent.axis_Id_numeric) * self.parent.sign
@threadlocked
def _socket_set(self, val):
target_val = val * self.parent.sign
self.setpoint = target_val
return self.controller._set_target_pos(self.parent.axis_Id_numeric, target_val * self.parent.sign)
return self.controller.set_target_pos(
self.parent.axis_Id_numeric, target_val * self.parent.sign
)
class NpointMotorIsMoving(SignalRO):
"""
Signal to indicate whether the motor is currently moving or not.
"""
def set_motor_is_moving(self, value:int) -> None:
def set_motor_is_moving(self, value: int) -> None:
"""
Set the motor_is_moving signal to the specified value.
Args:
value (int): 1 if the motor is moving, 0 otherwise.
"""
self._readback = value
class NPointAxis(Device, PositionerBase):
"""
NPointAxis class, which inherits from Device and PositionerBase. This class
represents an axis of an nPoint piezo stage and provides the necessary
functionality to move the axis and read its current position.
"""
USER_ACCESS = ["controller"]
readback = Cpt(NpointReadbackSignal, signal_name="readback", kind="hinted")
user_setpoint = Cpt(NpointSetpointSignal, signal_name="setpoint")
@@ -374,7 +411,7 @@ class NPointAxis(Device, PositionerBase):
limits=None,
sign=1,
socket_cls=SocketIO,
tolerance:float=0.05,
tolerance: float = 0.05,
**kwargs,
):
self.controller = NPointController(
@@ -485,13 +522,22 @@ class NPointAxis(Device, PositionerBase):
@property
def axis_Id(self):
"""
Return the axis_Id_alpha.
"""
return self._axis_Id_alpha
@axis_Id.setter
def axis_Id(self, val):
def axis_Id(self, val: str):
"""
Set the axis_Id_alpha and axis_Id_numeric based on the alpha value.
Args:
val (str): Single-character axis identifier.
"""
if isinstance(val, str):
if len(val) != 1:
raise ValueError(f"Only single-character axis_Ids are supported.")
raise ValueError("Only single-character axis_Ids are supported.")
self._axis_Id_alpha = val
self._axis_Id_numeric = ord(val.lower()) - 97
else:
@@ -499,13 +545,22 @@ class NPointAxis(Device, PositionerBase):
@property
def axis_Id_numeric(self):
"""
Return the numeric value of the axis_Id.
"""
return self._axis_Id_numeric
@axis_Id_numeric.setter
def axis_Id_numeric(self, val):
def axis_Id_numeric(self, val: int):
"""
Set the axis_Id_numeric and axis_Id_alpha based on the numeric value.
Args:
val (int): Numeric axis identifier.
"""
if isinstance(val, int):
if val > 26:
raise ValueError(f"Numeric value exceeds supported range.")
raise ValueError("Numeric value exceeds supported range.")
self._axis_Id_alpha = val
self._axis_Id_numeric = (chr(val + 97)).capitalize()
else:
@@ -531,4 +586,3 @@ if __name__ == "__main__":
npx.move(10)
print(npx.read())
npx.controller.off()

View File

@@ -52,7 +52,7 @@ def test_axis_put(npointx, pos, msg):
"""
Test that the set target position sends the correct message to the controller.
"""
npointx.controller._set_target_pos(npointx.axis_Id_numeric, pos)
npointx.controller.set_target_pos(npointx.axis_Id_numeric, pos)
npointx.controller.sock.put.assert_called_with(msg)
@@ -103,7 +103,7 @@ def test_axis_get_in(npointx, axis, msg_in, msg_out):
controller's method.
"""
npointx.controller.sock.receive.side_effect = [msg_out]
npointx.controller._get_current_pos(axis)
npointx.controller.get_current_pos(axis)
npointx.controller.sock.put.assert_called_once_with(msg_in)
@@ -122,7 +122,7 @@ def test_get_axis_out_of_range(controller):
Test that an error is raised when trying to get the current position of an invalid axis.
"""
with pytest.raises(ValueError):
controller._get_current_pos(3)
controller.get_current_pos(3)
def test_set_axis_out_of_range(controller):
@@ -130,7 +130,7 @@ def test_set_axis_out_of_range(controller):
Test that an error is raised when trying to set the target position of an invalid axis.
"""
with pytest.raises(ValueError):
controller._set_target_pos(3, 5)
controller.set_target_pos(3, 5)
@pytest.mark.parametrize(