refactor(npoint): cleanup
This commit is contained in:
@@ -1,17 +1,15 @@
|
||||
import functools
|
||||
import threading
|
||||
import time
|
||||
|
||||
from ophyd_devices.utils.controller import Controller, threadlocked
|
||||
from ophyd_devices.utils.socket import raise_if_disconnected
|
||||
from prettytable import PrettyTable
|
||||
from typeguard import typechecked
|
||||
from ophyd.utils import LimitError, ReadOnlyError
|
||||
from ophyd import Device, PositionerBase, Signal, SignalRO
|
||||
from ophyd_devices.utils.socket import SocketIO, SocketSignal
|
||||
from ophyd import Component as Cpt
|
||||
import numpy as np
|
||||
import threading
|
||||
from ophyd import Component as Cpt
|
||||
from ophyd import Device, PositionerBase, Signal, SignalRO
|
||||
from ophyd.status import wait as status_wait
|
||||
from ophyd.utils import LimitError, ReadOnlyError
|
||||
from ophyd_devices.utils.controller import Controller, threadlocked
|
||||
from ophyd_devices.utils.socket import SocketIO, SocketSignal, raise_if_disconnected
|
||||
from prettytable import PrettyTable
|
||||
|
||||
|
||||
def channel_checked(fcn):
|
||||
@@ -19,15 +17,24 @@ def channel_checked(fcn):
|
||||
|
||||
@functools.wraps(fcn)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
# pylint: disable=protected-access
|
||||
self._check_channel(args[0])
|
||||
return fcn(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class NpointError(Exception):
|
||||
pass
|
||||
"""
|
||||
Base class for Npoint errors.
|
||||
"""
|
||||
|
||||
|
||||
class NPointController(Controller):
|
||||
"""
|
||||
Controller for nPoint piezo stages. This class inherits from the Controller class
|
||||
and provides a singleton interface to the nPoint controller.
|
||||
"""
|
||||
|
||||
_axes_per_controller = 3
|
||||
_read_single_loc_bit = "A0"
|
||||
@@ -49,9 +56,7 @@ class NPointController(Controller):
|
||||
t = PrettyTable()
|
||||
t.field_names = ["Channel", "Range", "Position", "Target"]
|
||||
for ii in range(self._axes_per_controller):
|
||||
t.add_row(
|
||||
[ii, self._get_range(ii), self._get_current_pos(ii), self._get_target_pos(ii)]
|
||||
)
|
||||
t.add_row([ii, self._get_range(ii), self.get_current_pos(ii), self.get_target_pos(ii)])
|
||||
print(t)
|
||||
|
||||
@channel_checked
|
||||
@@ -82,7 +87,7 @@ class NPointController(Controller):
|
||||
return device_range
|
||||
|
||||
@channel_checked
|
||||
def _get_current_pos(self, channel: int) -> float:
|
||||
def get_current_pos(self, channel: int) -> float:
|
||||
# for first channel: 0x11 83 13 34
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{19 + 16 * channel:x}", "34"])
|
||||
@@ -95,7 +100,7 @@ class NPointController(Controller):
|
||||
return pos
|
||||
|
||||
@channel_checked
|
||||
def _set_target_pos(self, channel: int, pos: float) -> None:
|
||||
def set_target_pos(self, channel: int, pos: float) -> None:
|
||||
# for first channel: 0x11 83 12 18 00 00 00 00
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{18 + channel * 16:x}", "18"])
|
||||
@@ -107,7 +112,7 @@ class NPointController(Controller):
|
||||
self._put(send_buffer)
|
||||
|
||||
@channel_checked
|
||||
def _get_target_pos(self, channel: int) -> float:
|
||||
def get_target_pos(self, channel: int) -> float:
|
||||
# for first channel: 0x11 83 12 18
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{18 + channel * 16:x}", "18"])
|
||||
@@ -122,17 +127,17 @@ class NPointController(Controller):
|
||||
def _set_servo(self, channel: int, enable: bool) -> None:
|
||||
print("Not tested")
|
||||
return
|
||||
# for first channel: 0x11 83 10 84 00 00 00 00
|
||||
addr = self._channel_base.copy()
|
||||
addr.extend([f"{16 + channel * 16:x}", "84"])
|
||||
# # for first channel: 0x11 83 10 84 00 00 00 00
|
||||
# addr = self._channel_base.copy()
|
||||
# addr.extend([f"{16 + channel * 16:x}", "84"])
|
||||
|
||||
if enable:
|
||||
data = ["00"] * 3 + ["01"]
|
||||
else:
|
||||
data = ["00"] * 4
|
||||
send_buffer = self.__write_single_location_buffer(addr, data)
|
||||
# if enable:
|
||||
# data = ["00"] * 3 + ["01"]
|
||||
# else:
|
||||
# data = ["00"] * 4
|
||||
# send_buffer = self.__write_single_location_buffer(addr, data)
|
||||
|
||||
self._put(send_buffer)
|
||||
# self._put(send_buffer)
|
||||
|
||||
@channel_checked
|
||||
def _get_servo(self, channel: int) -> int:
|
||||
@@ -299,16 +304,23 @@ class NPointController(Controller):
|
||||
self.off()
|
||||
|
||||
|
||||
|
||||
class NpointSignalBase(SocketSignal):
|
||||
"""
|
||||
Base class for nPoint signals.
|
||||
"""
|
||||
|
||||
def __init__(self, signal_name, **kwargs):
|
||||
self.signal_name = signal_name
|
||||
super().__init__(**kwargs)
|
||||
self.controller:NPointController = self.parent.controller
|
||||
self.controller: NPointController = self.parent.controller
|
||||
self.sock = self.parent.controller.sock
|
||||
|
||||
|
||||
class NpointSignalRO(NpointSignalBase):
|
||||
"""
|
||||
Base class for read-only signals.
|
||||
"""
|
||||
|
||||
def __init__(self, signal_name, **kwargs):
|
||||
super().__init__(signal_name, **kwargs)
|
||||
self._metadata["write_access"] = False
|
||||
@@ -319,33 +331,58 @@ class NpointSignalRO(NpointSignalBase):
|
||||
|
||||
|
||||
class NpointReadbackSignal(NpointSignalRO):
|
||||
"""
|
||||
Signal to read the current position of an nPoint piezo stage.
|
||||
"""
|
||||
|
||||
@threadlocked
|
||||
def _socket_get(self):
|
||||
return self.controller._get_current_pos(self.parent.axis_Id_numeric) * self.parent.sign
|
||||
|
||||
return self.controller.get_current_pos(self.parent.axis_Id_numeric) * self.parent.sign
|
||||
|
||||
|
||||
class NpointSetpointSignal(NpointSignalBase):
|
||||
"""
|
||||
Signal to set the target position of an nPoint piezo stage.
|
||||
"""
|
||||
|
||||
setpoint = 0
|
||||
|
||||
@threadlocked
|
||||
def _socket_get(self):
|
||||
return self.controller._get_target_pos(self.parent.axis_Id_numeric) * self.parent.sign
|
||||
return self.controller.get_target_pos(self.parent.axis_Id_numeric) * self.parent.sign
|
||||
|
||||
@threadlocked
|
||||
def _socket_set(self, val):
|
||||
target_val = val * self.parent.sign
|
||||
self.setpoint = target_val
|
||||
return self.controller._set_target_pos(self.parent.axis_Id_numeric, target_val * self.parent.sign)
|
||||
|
||||
return self.controller.set_target_pos(
|
||||
self.parent.axis_Id_numeric, target_val * self.parent.sign
|
||||
)
|
||||
|
||||
|
||||
class NpointMotorIsMoving(SignalRO):
|
||||
"""
|
||||
Signal to indicate whether the motor is currently moving or not.
|
||||
"""
|
||||
|
||||
def set_motor_is_moving(self, value:int) -> None:
|
||||
def set_motor_is_moving(self, value: int) -> None:
|
||||
"""
|
||||
Set the motor_is_moving signal to the specified value.
|
||||
|
||||
Args:
|
||||
value (int): 1 if the motor is moving, 0 otherwise.
|
||||
"""
|
||||
self._readback = value
|
||||
|
||||
|
||||
class NPointAxis(Device, PositionerBase):
|
||||
"""
|
||||
NPointAxis class, which inherits from Device and PositionerBase. This class
|
||||
represents an axis of an nPoint piezo stage and provides the necessary
|
||||
functionality to move the axis and read its current position.
|
||||
"""
|
||||
|
||||
USER_ACCESS = ["controller"]
|
||||
readback = Cpt(NpointReadbackSignal, signal_name="readback", kind="hinted")
|
||||
user_setpoint = Cpt(NpointSetpointSignal, signal_name="setpoint")
|
||||
@@ -374,7 +411,7 @@ class NPointAxis(Device, PositionerBase):
|
||||
limits=None,
|
||||
sign=1,
|
||||
socket_cls=SocketIO,
|
||||
tolerance:float=0.05,
|
||||
tolerance: float = 0.05,
|
||||
**kwargs,
|
||||
):
|
||||
self.controller = NPointController(
|
||||
@@ -485,13 +522,22 @@ class NPointAxis(Device, PositionerBase):
|
||||
|
||||
@property
|
||||
def axis_Id(self):
|
||||
"""
|
||||
Return the axis_Id_alpha.
|
||||
"""
|
||||
return self._axis_Id_alpha
|
||||
|
||||
@axis_Id.setter
|
||||
def axis_Id(self, val):
|
||||
def axis_Id(self, val: str):
|
||||
"""
|
||||
Set the axis_Id_alpha and axis_Id_numeric based on the alpha value.
|
||||
|
||||
Args:
|
||||
val (str): Single-character axis identifier.
|
||||
"""
|
||||
if isinstance(val, str):
|
||||
if len(val) != 1:
|
||||
raise ValueError(f"Only single-character axis_Ids are supported.")
|
||||
raise ValueError("Only single-character axis_Ids are supported.")
|
||||
self._axis_Id_alpha = val
|
||||
self._axis_Id_numeric = ord(val.lower()) - 97
|
||||
else:
|
||||
@@ -499,13 +545,22 @@ class NPointAxis(Device, PositionerBase):
|
||||
|
||||
@property
|
||||
def axis_Id_numeric(self):
|
||||
"""
|
||||
Return the numeric value of the axis_Id.
|
||||
"""
|
||||
return self._axis_Id_numeric
|
||||
|
||||
@axis_Id_numeric.setter
|
||||
def axis_Id_numeric(self, val):
|
||||
def axis_Id_numeric(self, val: int):
|
||||
"""
|
||||
Set the axis_Id_numeric and axis_Id_alpha based on the numeric value.
|
||||
|
||||
Args:
|
||||
val (int): Numeric axis identifier.
|
||||
"""
|
||||
if isinstance(val, int):
|
||||
if val > 26:
|
||||
raise ValueError(f"Numeric value exceeds supported range.")
|
||||
raise ValueError("Numeric value exceeds supported range.")
|
||||
self._axis_Id_alpha = val
|
||||
self._axis_Id_numeric = (chr(val + 97)).capitalize()
|
||||
else:
|
||||
@@ -531,4 +586,3 @@ if __name__ == "__main__":
|
||||
npx.move(10)
|
||||
print(npx.read())
|
||||
npx.controller.off()
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ def test_axis_put(npointx, pos, msg):
|
||||
"""
|
||||
Test that the set target position sends the correct message to the controller.
|
||||
"""
|
||||
npointx.controller._set_target_pos(npointx.axis_Id_numeric, pos)
|
||||
npointx.controller.set_target_pos(npointx.axis_Id_numeric, pos)
|
||||
npointx.controller.sock.put.assert_called_with(msg)
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ def test_axis_get_in(npointx, axis, msg_in, msg_out):
|
||||
controller's method.
|
||||
"""
|
||||
npointx.controller.sock.receive.side_effect = [msg_out]
|
||||
npointx.controller._get_current_pos(axis)
|
||||
npointx.controller.get_current_pos(axis)
|
||||
npointx.controller.sock.put.assert_called_once_with(msg_in)
|
||||
|
||||
|
||||
@@ -122,7 +122,7 @@ def test_get_axis_out_of_range(controller):
|
||||
Test that an error is raised when trying to get the current position of an invalid axis.
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
controller._get_current_pos(3)
|
||||
controller.get_current_pos(3)
|
||||
|
||||
|
||||
def test_set_axis_out_of_range(controller):
|
||||
@@ -130,7 +130,7 @@ def test_set_axis_out_of_range(controller):
|
||||
Test that an error is raised when trying to set the target position of an invalid axis.
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
controller._set_target_pos(3, 5)
|
||||
controller.set_target_pos(3, 5)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user