diff --git a/csaxs_bec/devices/ids_cameras/ids_camera_new.py b/csaxs_bec/devices/ids_cameras/ids_camera_new.py index 0fe2e5e..17abdab 100644 --- a/csaxs_bec/devices/ids_cameras/ids_camera_new.py +++ b/csaxs_bec/devices/ids_cameras/ids_camera_new.py @@ -12,7 +12,6 @@ from bec_lib.logger import bec_logger from ophyd import Component as Cpt from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase from ophyd_devices.utils.bec_signals import AsyncSignal, PreviewSignal -from pydantic import BaseModel, field_validator from csaxs_bec.devices.ids_cameras.base_integration.camera import Camera @@ -24,39 +23,6 @@ if TYPE_CHECKING: logger = bec_logger.logger -AnyDimShape = Tuple[int, ...] - - -class ROISpec(TypedDict): - """ - Typed dictionary representing the specification for a Region of Interest (ROI). - """ - - x: int - y: int - width: int - height: int - img_shape: AnyDimShape - mask: np.ndarray - - -def validate_roi(roi: ROISpec) -> ROISpec: - """ - Validate the ROI specification to ensure it matches the image shape. - - Args: - roi (ROISpec): The ROI specification to validate. - - Returns: - ROISpec: The validated ROI specification. - """ - if roi["mask"].shape != roi["img_shape"]: - raise ValueError( - f"Mask shape {roi['mask'].shape} does not match image shape {roi['img_shape']}." - ) - return roi - - class IDSCamera(PSIDeviceBase): """IDS Camera class for cSAXS. @@ -74,7 +40,7 @@ class IDSCamera(PSIDeviceBase): async_update={"type": "add", "max_shape": [None]}, ) - USER_ACCESS = ["live_mode", "roi", "get_last_image"] + USER_ACCESS = ["live_mode", "roi", "set_rect_roi", "get_last_image"] def __init__( self, @@ -86,7 +52,6 @@ class IDSCamera(PSIDeviceBase): m_n_colormode: Literal[0, 1, 2, 3] = 1, bits_per_pixel: Literal[8, 24] = 24, live_mode: bool = False, - roi: tuple[int, int, int, int] | None = None, **kwargs, ): """Initialize the IDS Camera. @@ -110,62 +75,32 @@ class IDSCamera(PSIDeviceBase): connect=False, ) self._live_mode = False - self._inputs = {"roi": roi if roi else (0, 0, 1, 1), "live_mode": live_mode} - self._roi: ROISpec = validate_roi( - ROISpec( - { - "x": 0, - "y": 0, - "width": 1, - "height": 1, - "img_shape": (1, 1), - "mask": np.zeros((1, 1), dtype=np.uint8), - } - ) - ) + self._inputs = {"live_mode": live_mode} + self._mask = np.zeros((1, 1), dtype=np.uint8) ############## Live Mode Methods ############## @property - def roi(self) -> ROISpec: + def mask(self) -> np.ndarray: """Return the current region of interest (ROI) for the camera.""" - return self._roi + return self._mask - @roi.setter - def roi(self, value: ROISpec | tuple[int, int, int, int] | list[int, int, int, int]): + @mask.setter + def mask(self, value: np.ndarray): """ Set the region of interest (ROI) for the camera. Args: - value (ROI | tuple[int, int, int, int] | list[int, int, int, int]): Either an ROI object, or a tuple or list with x, y, width, and height. + value (np.ndarray): The mask to set as the ROI. """ - if isinstance(value, (tuple, list)) and len(value) == 4: - x = value[0] - y = value[1] - width = value[2] - height = value[3] - if x + width > self.cam.cam.width.value or y + height > self.cam.cam.height.value: - raise ValueError("ROI exceeds camera dimensions.") - img_shape = (self.cam.cam.height.value, self.cam.cam.width.value) - mask = np.zeros(img_shape, dtype=np.uint8) - mask[y : y + height, x : x + width] = 1 - value = validate_roi( - ROISpec( - { - "x": x, - "y": y, - "width": width, - "height": height, - "img_shape": img_shape, - "mask": mask, - } - ) + if value.ndim != 2: + raise ValueError("ROI mask must be a 2D array.") + img_shape = (self.cam.cam.height.value, self.cam.cam.width.value) + if value.shape[0] != img_shape[0] or value.shape[1] != img_shape[1]: + raise ValueError( + f"ROI mask shape {value.shape} does not match image shape {img_shape}." ) - if not isinstance(value, dict) or not all( - key in value for key in ["x", "y", "width", "height", "img_shape", "mask"] - ): - raise TypeError(f"ROI must be an instance of ROISpec {value}.") - self._roi = value + self._mask = value @property def live_mode(self) -> bool: @@ -184,6 +119,17 @@ class IDSCamera(PSIDeviceBase): else: self._stop_live() + def set_rect_roi(self, x: int, y: int, width: int, height: int): + """Set the rectangular region of interest (ROI) for the camera.""" + if x < 0 or y < 0 or width <= 0 or height <= 0: + raise ValueError("ROI coordinates and dimensions must be positive integers.") + img_shape = (self.cam.cam.height.value, self.cam.cam.width.value) + if x + width > img_shape[0] or y + height > img_shape[1]: + raise ValueError("ROI exceeds camera dimensions.") + mask = np.zeros(img_shape, dtype=np.uint8) + mask[y : y + height, x : x + width] = 1 + self.mask = mask + def _start_live(self): """Start the live mode for the camera.""" if self._live_mode_thread is not None: @@ -225,6 +171,7 @@ class IDSCamera(PSIDeviceBase): self.image.put(image) def get_last_image(self) -> np.ndarray: + """Get the last captured image from the camera.""" image = self.image.get() if image: return image.data @@ -235,11 +182,7 @@ class IDSCamera(PSIDeviceBase): """Connect to the camera.""" self.cam.on_connect() self.live_mode = self._inputs.get("live_mode", None) - roi = self._inputs.get("roi", None) - if roi is None or not isinstance(roi, (tuple, list)) or not len(roi) == 4: - # If ROI is not set, use the full camera resolution - roi = (0, 0, self.cam.cam.width.value, self.cam.cam.height.value) - self.roi = roi + self.set_rect_roi(0, 0, self.cam.cam.width.value, self.cam.cam.height.value) def on_destroy(self): """Clean up resources when the device is destroyed.""" @@ -253,7 +196,7 @@ class IDSCamera(PSIDeviceBase): image = self.image.get() if image is not None: image: messages.DevicePreviewMessage - if self.roi["img_shape"][0:2] != image.data.shape[0:2]: + if self.mask.shape[0:2] != image.data.shape[0:2]: logger.info( f"ROI shape does not match image shape, skipping ROI application for device {self.name}." ) @@ -261,18 +204,16 @@ class IDSCamera(PSIDeviceBase): if len(image.data.shape) == 3: # If the image has multiple channels, apply the mask to each channel - data = ( - image.data * self.roi["mask"][:, :, np.newaxis] - ) # Apply mask to the image data + data = image.data * self.mask[:, :, np.newaxis] # Apply mask to the image data n_channels = 3 else: - data = image.data * self.roi["mask"] + data = image.data * self.mask n_channels = 1 self.roi_signal.put( { self.roi_signal.name: { "value": np.sum(data) - / (np.sum(self.roi["mask"]) * n_channels), # TODO can be optimized + / (np.sum(self.mask) * n_channels), # TODO could be optimized "timestamp": time.time(), } } diff --git a/tests/tests_devices/test_ids_camera.py b/tests/tests_devices/test_ids_camera.py new file mode 100644 index 0000000..a1b51dc --- /dev/null +++ b/tests/tests_devices/test_ids_camera.py @@ -0,0 +1,88 @@ +"""Unit tests for the IDS Camera device.""" + +from unittest import mock + +import numpy as np +import pytest + +from csaxs_bec.devices.ids_cameras.ids_camera_new import IDSCamera + + +@pytest.fixture(scope="function") +def ids_camera(): + """Fixture for creating an instance of the IDSCamera.""" + camera = IDSCamera( + name="test_camera", + camera_id=1, + prefix="test:", + scan_info=None, + m_n_colormode=1, + bits_per_pixel=24, + live_mode=False, + ) + # Mock camera connection and attributes + camera.cam = mock.Mock() + camera.cam._connected = True + camera.cam.cam = mock.Mock() + camera.cam.cam.width.value = 2 + camera.cam.cam.height.value = 2 + yield camera + + +def test_mask_setter_getter(ids_camera): + """Test the mask setter and getter methods.""" + mask = np.zeros((2, 2), dtype=np.uint8) + mask[0, 0] = 1 + ids_camera.mask = mask + assert np.array_equal(ids_camera.mask, mask) + + +def test_mask_setter_invalid_shape(ids_camera): + """Test the mask setter with an invalid shape.""" + with pytest.raises(ValueError): + ids_camera.mask = np.zeros((3, 3), dtype=np.uint8) # Exceeds mocked camera dimensions + + +def test_on_connected_sets_mask_and_live_mode(ids_camera): + """Test the on_connected method to ensure it sets the mask and live mode.""" + ids_camera.cam.on_connect = mock.Mock() + ids_camera.on_connected() + ids_camera.cam.on_connect.assert_called_once() + expected_mask = np.ones((2, 2), dtype=np.uint8) + assert np.array_equal(ids_camera.mask, expected_mask) + + +def test_on_trigger_roi_signal(ids_camera): + """Test the on_trigger method to ensure it processes the ROI signal correctly.""" + ids_camera.live_mode = True + test_image = np.array([[2, 4], [6, 8]]) + test_mask = np.array([[1, 0], [0, 1]], dtype=np.uint8) + ids_camera.mask = test_mask + mock_image = mock.Mock() + mock_image.data = test_image + ids_camera.image.get = mock.Mock(return_value=mock_image) + ids_camera.roi_signal.put = mock.Mock(side_effect=ids_camera.roi_signal.put) + ids_camera.on_trigger() + expected_value = (2 * 1 + 4 * 0 + 6 * 0 + 8 * 1) / (np.sum(test_mask) * 1) + result = ids_camera.roi_signal.get() + assert np.isclose( + result.content["signals"][ids_camera.roi_signal.name]["value"], expected_value, atol=1e-6 + ) + + +def test_get_last_image(ids_camera): + """Test the get_last_image method to ensure it returns the last captured image.""" + test_image = np.array([[1, 2], [3, 4]], dtype=np.uint8) + mock_image = mock.Mock() + mock_image.data = test_image + ids_camera.image.get = mock.Mock(return_value=mock_image) + + result = ids_camera.get_last_image() + assert np.array_equal(result, test_image) + + +def test_on_destroy(ids_camera): + """Test the on_destroy method to ensure it cleans up resources.""" + ids_camera.cam.on_disconnect = mock.Mock() + ids_camera.on_destroy() + ids_camera.cam.on_disconnect.assert_called_once()