diff --git a/ophyd_devices/utils/controller.py b/ophyd_devices/utils/controller.py index 67f293b..18e9364 100644 --- a/ophyd_devices/utils/controller.py +++ b/ophyd_devices/utils/controller.py @@ -93,34 +93,48 @@ class Controller(OphydObject): self._socket_port = socket_port @threadlocked - def socket_put(self, val: str): + def socket_put(self, val: str, socket=None) -> None: """ Send a command to the controller through the socket. Args: val (str): Command to send + socket (socket.socket): Socket object to use. If None, the default socket is used. """ - self.sock.put(f"{val}\n".encode()) + sock = socket or self.sock + sock.put(f"{val}\n".encode()) @threadlocked - def socket_get(self): + def socket_get(self, socket=None) -> str: """ Receive a response from the controller through the socket. + + Args: + socket (socket.socket): Socket object to use. If None, the default socket is used. + + Returns: + str: Response from the controller """ - return self.sock.receive().decode() + sock = socket or self.sock + return sock.receive().decode() @retry_once @threadlocked - def socket_put_and_receive(self, val: str, remove_trailing_chars=True) -> str: + def socket_put_and_receive(self, val: str, remove_trailing_chars=True, socket=None) -> str: """ Send a command to the controller and receive the response. Override this method in the derived class if necessary, especially if the response needs to be parsed differently. + + Args: + val (str): Command to send + remove_trailing_chars (bool): Remove trailing characters from the response + socket (socket.socket): Socket object to use. If None, the default socket is used. """ - self.socket_put(val) + self.socket_put(val, socket) if remove_trailing_chars: return self._remove_trailing_characters(self.sock.receive().decode()) - return self.socket_get() + return self.socket_get(socket) def _remove_trailing_characters(self, var) -> str: if len(var) > 1: