fix: changed dependency injection for controller classes; closes #13

This commit is contained in:
wakonig_k 2023-11-08 10:02:47 +01:00
parent 9080d45075
commit fb9a17c5e3
11 changed files with 167 additions and 167 deletions

View File

@ -100,10 +100,12 @@ class FlomniGalilMotor(Device, PositionerBase):
device_manager=None,
**kwargs,
):
self.controller = FlomniGalilController(
socket_cls=socket_cls, socket_host=host, socket_port=port
)
self.axis_Id = axis_Id
self.sign = sign
self.controller = FlomniGalilController(socket=socket_cls(host=host, port=port))
self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric)
self.sign = sign
self.tolerance = kwargs.pop("tolerance", 0.5)
self.device_mapping = kwargs.pop("device_mapping", {})
self.device_manager = device_manager
@ -273,29 +275,29 @@ class FlomniGalilMotor(Device, PositionerBase):
return super().stop(success=success)
if __name__ == "__main__":
mock = False
if not mock:
leyey = GalilMotor("H", name="leyey", host="mpc2680.psi.ch", port=8081, sign=-1)
leyey.stage()
status = leyey.move(0, wait=True)
status = leyey.move(10, wait=True)
leyey.read()
# if __name__ == "__main__":
# mock = False
# if not mock:
# leyey = GalilMotor("H", name="leyey", host="mpc2680.psi.ch", port=8081, sign=-1)
# leyey.stage()
# status = leyey.move(0, wait=True)
# status = leyey.move(10, wait=True)
# leyey.read()
leyey.get()
leyey.describe()
# leyey.get()
# leyey.describe()
leyey.unstage()
else:
from ophyd_devices.utils.socket import SocketMock
# leyey.unstage()
# else:
# from ophyd_devices.utils.socket import SocketMock
leyex = GalilMotor(
"G", name="leyex", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock
)
leyey = GalilMotor(
"H", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock
)
leyex.stage()
# leyey.stage()
# leyex = GalilMotor(
# "G", name="leyex", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock
# )
# leyey = GalilMotor(
# "H", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock
# )
# leyex.stage()
# # leyey.stage()
leyex.controller.galil_show_all()
# leyex.controller.galil_show_all()

View File

@ -36,12 +36,21 @@ class FuprGalilController(GalilController):
name="GalilController",
kind=None,
parent=None,
socket=None,
socket_cls=None,
socket_host=None,
socket_port=None,
attr_name="",
labels=None,
):
super().__init__(
name=name, kind=kind, parent=parent, socket=socket, attr_name=attr_name, labels=labels
name=name,
kind=kind,
parent=parent,
socket_cls=socket_cls,
socket_host=socket_host,
socket_port=socket_port,
attr_name=attr_name,
labels=labels,
)
self._galil_axis_per_controller = 1
@ -157,10 +166,12 @@ class FuprGalilMotor(Device, PositionerBase):
device_manager=None,
**kwargs,
):
self.controller = FuprGalilController(
socket_cls=socket_cls, socket_host=host, socket_port=port
)
self.axis_Id = axis_Id
self.sign = sign
self.controller = FuprGalilController(socket=socket_cls(host=host, port=port))
self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric)
self.sign = sign
self.tolerance = kwargs.pop("tolerance", 0.5)
self.device_mapping = kwargs.pop("device_mapping", {})
self.device_manager = device_manager
@ -330,29 +341,29 @@ class FuprGalilMotor(Device, PositionerBase):
return super().stop(success=success)
if __name__ == "__main__":
mock = False
if not mock:
leyey = GalilMotor("H", name="leyey", host="mpc2680.psi.ch", port=8081, sign=-1)
leyey.stage()
status = leyey.move(0, wait=True)
status = leyey.move(10, wait=True)
leyey.read()
# if __name__ == "__main__":
# mock = False
# if not mock:
# leyey = GalilMotor("H", name="leyey", host="mpc2680.psi.ch", port=8081, sign=-1)
# leyey.stage()
# status = leyey.move(0, wait=True)
# status = leyey.move(10, wait=True)
# leyey.read()
leyey.get()
leyey.describe()
# leyey.get()
# leyey.describe()
leyey.unstage()
else:
from ophyd_devices.utils.socket import SocketMock
# leyey.unstage()
# else:
# from ophyd_devices.utils.socket import SocketMock
leyex = GalilMotor(
"G", name="leyex", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock
)
leyey = GalilMotor(
"H", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock
)
leyex.stage()
# leyey.stage()
# leyex = GalilMotor(
# "G", name="leyex", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock
# )
# leyey = GalilMotor(
# "H", name="leyey", host="mpc2680.psi.ch", port=8081, socket_cls=SocketMock
# )
# leyex.stage()
# # leyey.stage()
leyex.controller.galil_show_all()
# leyex.controller.galil_show_all()

View File

@ -59,7 +59,9 @@ class GalilController(Controller):
name="GalilController",
kind=None,
parent=None,
socket=None,
socket_cls=None,
socket_host=None,
socket_port=None,
attr_name="",
labels=None,
):
@ -68,30 +70,15 @@ class GalilController(Controller):
self._axis = [None for axis_num in range(self._galil_axis_per_controller)]
super().__init__(
name=name,
socket=socket,
socket_cls=socket_cls,
socket_host=socket_host,
socket_port=socket_port,
attr_name=attr_name,
parent=parent,
labels=labels,
kind=kind,
)
def on(self, controller_num=0) -> None:
"""Open a new socket connection to the controller"""
if not self.connected:
self.sock.open()
self.connected = True
else:
logger.info("The connection has already been established.")
# warnings.warn(f"The connection has already been established.", stacklevel=2)
def off(self) -> None:
"""Close the socket connection to the controller"""
if self.connected:
self.sock.close()
self.connected = False
else:
logger.info("The connection is already closed.")
def set_axis(self, axis: Device, axis_nr: int) -> None:
"""Assign an axis to a device instance.
@ -462,10 +449,10 @@ class GalilMotor(Device, PositionerBase):
device_manager=None,
**kwargs,
):
self.controller = GalilController(socket_cls=socket_cls, socket_host=host, socket_port=port)
self.axis_Id = axis_Id
self.sign = sign
self.controller = GalilController(socket=socket_cls(host=host, port=port))
self.controller.set_axis(axis=self, axis_nr=self.axis_Id_numeric)
self.sign = sign
self.tolerance = kwargs.pop("tolerance", 0.5)
self.device_mapping = kwargs.pop("device_mapping", {})
self.device_manager = device_manager

View File

@ -1,12 +1,14 @@
import functools
import socket
import threading
import time
from ophyd_devices.utils.controller import SingletonController, threadlocked
from ophyd_devices.utils.socket import raise_if_disconnected
from prettytable import PrettyTable
from typeguard import typechecked
from ophyd_devices.utils.controller import threadlocked
from ophyd_devices.utils.socket import raise_if_disconnected
def channel_checked(fcn):
"""Decorator to catch attempted access to channels that are not available."""
@ -60,7 +62,9 @@ class SocketIO:
self.is_open = False
class NPointController(SingletonController):
class NPointController:
_controller_instance = None
NUM_CHANNELS = 3
_read_single_loc_bit = "A0"
_write_single_loc_bit = "A2"
@ -74,11 +78,17 @@ class NPointController(SingletonController):
server_ip: str = "129.129.99.87",
server_port: int = 23,
) -> None:
self._lock = threading.RLock()
super().__init__()
self._server_and_port_name = (server_ip, server_port)
self.socket = comm_socket
self.connected = False
def __new__(cls, *args, **kwargs):
if not NPointController._controller_instance:
NPointController._controller_instance = object.__new__(cls)
return NPointController._controller_instance
@classmethod
def create(cls):
return cls(SocketIO())

View File

@ -1,15 +1,17 @@
import abc
import socket
import functools
import time
import socket
import threading
import time
from typeguard import typechecked
from ophyd import PositionerBase, Signal
from ophyd.device import Device, Component as Cpt
from ophyd.device import Component as Cpt
from ophyd.device import Device
from prettytable import PrettyTable
from typeguard import typechecked
from ophyd_devices.utils.controller import threadlocked
from ophyd_devices.utils.socket import raise_if_disconnected
from ophyd_devices.utils.controller import SingletonController, threadlocked
def channel_checked(fcn):
@ -23,7 +25,9 @@ def channel_checked(fcn):
return wrapper
class NPointController(SingletonController):
class NPointController:
_controller_instance = None
NUM_CHANNELS = 3
_read_single_loc_bit = "A0"
_write_single_loc_bit = "A2"
@ -37,11 +41,17 @@ class NPointController(SingletonController):
server_ip: str = "129.129.99.87",
server_port: int = 23,
) -> None:
self._lock = threading.RLock()
super().__init__()
self._server_and_port_name = (server_ip, server_port)
self.socket = comm_socket
self.connected = False
def __new__(cls, *args, **kwargs):
if not NPointController._controller_instance:
NPointController._controller_instance = object.__new__(cls)
return NPointController._controller_instance
@classmethod
def create(cls):
return cls(SocketIO())

View File

@ -5,12 +5,13 @@ import logging
import os
import numpy as np
from typeguard import typechecked
from ophyd_devices.smaract.smaract_errors import (
SmaractCommunicationError,
SmaractErrorCode,
)
from ophyd_devices.utils.controller import Controller, threadlocked
from typeguard import typechecked
logger = logging.getLogger("smaract_controller")
@ -89,7 +90,9 @@ class SmaractController(Controller):
name="SmaractController",
kind=None,
parent=None,
socket=None,
socket_cls=None,
socket_host=None,
socket_port=None,
attr_name="",
labels=None,
):
@ -98,7 +101,9 @@ class SmaractController(Controller):
self._axis = [None for axis_num in range(self._Smaract_axis_per_controller)]
super().__init__(
name=name,
socket=socket,
socket_cls=socket_cls,
socket_host=socket_host,
socket_port=socket_port,
attr_name=attr_name,
parent=parent,
labels=labels,
@ -106,23 +111,6 @@ class SmaractController(Controller):
)
self._sensors = SmaractSensors()
def on(self, controller_num=0):
"""Open a new socket connection to the controller"""
if not self.connected:
self.sock.open()
self.connected = True
else:
logger.info("The connection has already been established.")
# warnings.warn(f"The connection has already been established.", stacklevel=2)
def off(self):
"""Close the socket connection to the controller"""
if self.connected:
self.sock.close()
self.connected = False
else:
logger.info("The connection is already closed.")
@axis_checked
def set_axis(self, axis_nr, axis):
self._axis[axis_nr] = axis

View File

@ -132,7 +132,9 @@ class SmaractMotor(Device, PositionerBase):
):
self.axis_Id = axis_Id
self.sign = sign
self.controller = SmaractController(socket=socket_cls(host=host, port=port))
self.controller = SmaractController(
socket_cls=socket_cls, socket_host=host, socket_port=port
)
self.controller.set_axis(self.axis_Id_numeric, axis=self)
self.tolerance = kwargs.pop("tolerance", 0.5)

View File

@ -1,9 +1,10 @@
import functools
import threading
import warnings
from bec_lib.core import bec_logger
from ophyd.ophydobj import OphydObject
from ophyd_devices.utils.socket import SocketIO
logger = bec_logger.logger
def threadlocked(fcn):
@ -18,24 +19,6 @@ def threadlocked(fcn):
return wrapper
class SingletonController:
_controller_instance = None
def __init__(self) -> None:
self._lock = threading.RLock()
def on(self):
pass
def off(self):
pass
def __new__(cls, *args, **kwargs):
if not SingletonController._controller_instance:
SingletonController._controller_instance = object.__new__(cls)
return SingletonController._controller_instance
class Controller(OphydObject):
_controller_instances = {}
@ -45,24 +28,29 @@ class Controller(OphydObject):
self,
*,
name=None,
socket=None,
socket_cls=None,
socket_host=None,
socket_port=None,
attr_name="",
parent=None,
labels=None,
kind=None,
):
self.sock = None
self._socket_cls = socket_cls
self._socket_host = socket_host
self._socket_port = socket_port
if not hasattr(self, "_initialized"):
super().__init__(
name=name, attr_name=attr_name, parent=parent, labels=labels, kind=kind
)
self._lock = threading.RLock()
self._initialize(socket)
self._initialize()
self._initialized = True
def _initialize(self, socket):
def _initialize(self):
self._connected = False
self._set_default_values()
self.sock = socket if socket is not None else SocketIO()
def _set_default_values(self):
# no. of axes controlled by each controller
@ -86,25 +74,35 @@ class Controller(OphydObject):
"""Get motor instance for a specified controller axis."""
return self._motors[axis]
def on(self, controller_num=0):
def on(self) -> None:
"""Open a new socket connection to the controller"""
if not self.connected:
if not self.connected or self.sock is None:
self.sock = self._socket_cls(host=self._socket_host, port=self._socket_port)
self.sock.open()
self.connected = True
else:
warnings.warn(f"The connection has already been established.", stacklevel=2)
logger.info("The connection has already been established.")
def off(self):
def off(self) -> None:
"""Close the socket connection to the controller"""
self.sock.close()
self.connected = False
if self.connected or self.sock is not None:
self.sock.close()
self.connected = False
self.sock = None
else:
logger.info("The connection is already closed.")
def __new__(cls, *args, **kwargs):
socket = kwargs.get("socket")
if not hasattr(socket, "host"):
raise RuntimeError("Socket must specify a host.")
if not hasattr(socket, "port"):
raise RuntimeError("Socket must specify a port.")
host_port = f"{socket.host}:{socket.port}"
socket_cls = kwargs.get("socket_cls")
socket_host = kwargs.get("socket_host")
socket_port = kwargs.get("socket_port")
if not socket_cls:
raise RuntimeError("Socket class must be specified.")
if not socket_host:
raise RuntimeError("Socket host must be specified.")
if not socket_port:
raise RuntimeError("Socket port must be specified.")
host_port = f"{socket_host}:{socket_port}"
if host_port not in Controller._controller_instances:
Controller._controller_instances[host_port] = object.__new__(cls)
return Controller._controller_instances[host_port]

View File

@ -1,17 +0,0 @@
# from bluesky import RunEngine
# from bluesky.plans import grid_scan
# from bluesky.callbacks.best_effort import BestEffortCallback
# from bluesky.callbacks.mpl_plotting import LivePlot
# RE = RunEngine({})
# from bluesky.callbacks.best_effort import BestEffortCallback
# bec = BestEffortCallback()
# # Send all metadata/data captured to the BestEffortCallback.
# RE.subscribe(bec)
# # RE.subscribe(dummy)
# # RE(grid_scan(dets, motor1, -10, 10, 10, motor2, -10, 10, 10))

View File

@ -1,8 +1,8 @@
import pytest
from ophyd_devices.galil.galil_ophyd import GalilMotor
from utils import SocketMock
from ophyd_devices.galil.galil_ophyd import GalilMotor
@pytest.mark.parametrize(
"pos,msg,sign",
@ -20,8 +20,8 @@ def test_axis_get(pos, msg, sign):
sign=sign,
socket_cls=SocketMock,
)
leyey.controller.sock.flush_buffer()
leyey.controller.on()
leyey.controller.sock.flush_buffer()
leyey.controller.sock.buffer_recv = msg
val = leyey.read()
assert val["leyey"]["value"] == pos

View File

@ -1,4 +1,6 @@
import pytest
from utils import SocketMock
from ophyd_devices.smaract import SmaractController
from ophyd_devices.smaract.smaract_controller import SmaractCommunicationMode
from ophyd_devices.smaract.smaract_errors import (
@ -7,8 +9,6 @@ from ophyd_devices.smaract.smaract_errors import (
)
from ophyd_devices.smaract.smaract_ophyd import SmaractMotor
from utils import SocketMock
@pytest.mark.parametrize(
"axis,position,get_message,return_msg",
@ -20,7 +20,8 @@ from utils import SocketMock
],
)
def test_get_position(axis, position, get_message, return_msg):
controller = SmaractController(socket=SocketMock(host="dummy", port=123))
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
controller.sock.buffer_recv = return_msg
val = controller.get_position(axis)
@ -38,7 +39,8 @@ def test_get_position(axis, position, get_message, return_msg):
],
)
def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, exception):
controller = SmaractController(socket=SocketMock(host="dummy", port=123))
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
controller.sock.buffer_recv = return_msg
if exception is not None:
@ -60,7 +62,8 @@ def test_axis_is_referenced(axis, is_referenced, get_message, return_msg, except
],
)
def test_socket_put_and_receive_raises_exception(return_msg, exception, raised):
controller = SmaractController(socket=SocketMock(host="dummy", port=123))
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
controller.sock.buffer_recv = return_msg
with pytest.raises(exception):
@ -84,7 +87,8 @@ def test_socket_put_and_receive_raises_exception(return_msg, exception, raised):
],
)
def test_communication_mode(mode, get_message, return_msg):
controller = SmaractController(socket=SocketMock(host="dummy", port=123))
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
controller.sock.buffer_recv = return_msg
val = controller.get_communication_mode()
@ -108,7 +112,8 @@ def test_communication_mode(mode, get_message, return_msg):
],
)
def test_axis_is_moving(is_moving, get_message, return_msg):
controller = SmaractController(socket=SocketMock(host="dummy", port=123))
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
controller.sock.buffer_recv = return_msg
val = controller.is_axis_moving(0)
@ -127,7 +132,8 @@ def test_axis_is_moving(is_moving, get_message, return_msg):
],
)
def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg):
controller = SmaractController(socket=SocketMock(host="dummy", port=123))
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
controller.sock.buffer_recv = return_msg
sensor = controller.get_sensor_type(axis)
@ -143,7 +149,8 @@ def test_get_sensor_definition(sensor_id, axis, get_msg, return_msg):
],
)
def test_set_move_speed(move_speed, axis, get_msg, return_msg):
controller = SmaractController(socket=SocketMock(host="dummy", port=123))
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
controller.sock.buffer_recv = return_msg
controller.set_closed_loop_move_speed(axis, move_speed)
@ -159,7 +166,8 @@ def test_set_move_speed(move_speed, axis, get_msg, return_msg):
],
)
def test_move_axis_to_absolute_position(pos, axis, hold_time, get_msg, return_msg):
controller = SmaractController(socket=SocketMock(host="dummy", port=123))
controller = SmaractController(socket_cls=SocketMock, socket_host="dummy", socket_port=123)
controller.on()
controller.sock.flush_buffer()
controller.sock.buffer_recv = return_msg
if hold_time is not None:
@ -232,6 +240,7 @@ def test_stop_axis(num_axes, get_msg, return_msg):
)
lsmarA.stage()
controller = lsmarA.controller
controller.on()
controller.sock.flush_buffer()
controller.sock.buffer_recv = return_msg
controller.stop_all_axes()