diff --git a/ophyd_devices/utils/http_signal.py b/ophyd_devices/utils/http_signal.py new file mode 100644 index 0000000..10933dd --- /dev/null +++ b/ophyd_devices/utils/http_signal.py @@ -0,0 +1,59 @@ +from typing import Any + +from ophyd.utils.errors import ReadOnlyError +from requests import Response, get, put + +from ophyd_devices.utils.socket import SocketSignal + + +class HttpRestError(Exception): + """Error for rest calls from a HttpRestSignal.""" + + def __init__(self, resp: Response, *args: object, value: Any | None = None) -> None: + method, url = resp.request.method, resp.request.url + data = f"{str(value)} to " if value is not None else "" + super().__init__( + f"Could not {method} {data}{url}. Code: {resp.status_code}. Reason: {resp.reason}.", + *args, + ) + + +class HttpRestSignal(SocketSignal): + """Ophyd signal which gets and puts to a REST API rather than EPICS PVs.""" + + def __init__(self, *args, get_uri: str = "", put_uri: str | None = None, **kwargs): + self._get_uri = get_uri + self._put_uri = put_uri or get_uri + super().__init__(*args, **kwargs) + + def _get_uri_transform(self, uri: str): + """Hook to apply to the GET uri before creating the request""" + return uri + + def _put_transform(self, uri: str, val: Any): + """Hook to apply to the PUT uri and data before creating the request""" + return uri, val + + def _socket_get(self): + resp = get(self._get_uri) + if not resp.ok: + raise HttpRestError(resp) + self._readback = resp.text + return self._readback + + def _socket_set(self, val: Any): + uri, data = self._put_transform(self._put_uri, val) + resp = put(uri, data=data) + if not resp.ok: + raise HttpRestError(resp, value=data) + + +class HttpRestSignalRO(HttpRestSignal): + """Read-only version of HttpRestSignal""" + + def __init__(self, *args, get_uri: str = "", **kwargs): + self._get_uri = get_uri + super().__init__(*args, **kwargs) + + def _socket_set(self, val): + raise ReadOnlyError(f"HttpRestSignalRO {self.name} is read-only!") diff --git a/ophyd_devices/utils/socket.py b/ophyd_devices/utils/socket.py index ec57a77..5f13fa1 100644 --- a/ophyd_devices/utils/socket.py +++ b/ophyd_devices/utils/socket.py @@ -113,7 +113,7 @@ class SocketSignal(abc.ABC, Signal): self._last_readback = 0 @abc.abstractmethod - def _socket_get(self): ... + def _socket_get(self) -> typing.Any: ... @abc.abstractmethod def _socket_set(self, val): ... diff --git a/pyproject.toml b/pyproject.toml index 5fc3f5a..b3f7c64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dev = [ "coverage~=7.0", "pylint~=3.0", "pytest-random-order~=1.1", + "requests-mock", ] [project.scripts] diff --git a/tests/test_http_signals.py b/tests/test_http_signals.py new file mode 100644 index 0000000..72535ec --- /dev/null +++ b/tests/test_http_signals.py @@ -0,0 +1,70 @@ +from unittest.mock import ANY + +import pytest +import requests_mock + +from ophyd_devices.utils.http_signal import HttpRestError, HttpRestSignal + + +@pytest.fixture(autouse=True) +def mock_server(): + with requests_mock.Mocker() as m: + mock_data = "data" + + def get_cb(request, context): + nonlocal mock_data + return mock_data + + def put_cb(request, context): + nonlocal mock_data + mock_data = request.text + + def put_req_valid(request): + try: + val = int(request.text) + except: + return False + return -50 < val < 50 + + def put_can_fail_cb(request, context): + context.reason = "" if put_req_valid(request) else "out of range" + context.status_code = 202 if put_req_valid(request) else 422 + + m.get("http://test.psi.ch/get_data", text=get_cb) + m.put("http://test.psi.ch/put_data", text=put_cb) + + m.get("http://test.psi.ch/bad_get_endpoint", status_code=404, reason="test not found") + m.put("http://test.psi.ch/put_can_fail", text=put_can_fail_cb) + + yield requests_mock + + +def test_signal_get(): + sig = HttpRestSignal(name="get", get_uri="http://test.psi.ch/get_data") + assert sig.read() == {"get": {"timestamp": ANY, "value": "data"}} + + +def test_signal_put(): + sig = HttpRestSignal( + name="put_get", get_uri="http://test.psi.ch/get_data", put_uri="http://test.psi.ch/put_data" + ) + assert sig.read() == {"put_get": {"timestamp": ANY, "value": "data"}} + sig.put("test_value") + assert sig.read() == {"put_get": {"timestamp": ANY, "value": "test_value"}} + + +def test_bad_signal_get(): + sig = HttpRestSignal(name="get", get_uri="http://test.psi.ch/bad_get_endpoint") + with pytest.raises(HttpRestError) as e: + sig.read() + assert e.match("test not found") + + +def test_bad_signal_put(): + sig = HttpRestSignal(name="get", get_uri="http://test.psi.ch/put_can_fail") + sig.put("20") + + with pytest.raises(HttpRestError) as e: + sig.put("50") + assert e.match("Could not PUT 50") + assert e.match("Code: 422. Reason: out of range.")