test(ids-camera): add tests for the IDSCamera integration

This commit is contained in:
2025-08-05 10:33:29 +02:00
parent 8f7ada2f92
commit 9e45e927a0
2 changed files with 120 additions and 91 deletions

View File

@@ -12,7 +12,6 @@ from bec_lib.logger import bec_logger
from ophyd import Component as Cpt
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from ophyd_devices.utils.bec_signals import AsyncSignal, PreviewSignal
from pydantic import BaseModel, field_validator
from csaxs_bec.devices.ids_cameras.base_integration.camera import Camera
@@ -24,39 +23,6 @@ if TYPE_CHECKING:
logger = bec_logger.logger
AnyDimShape = Tuple[int, ...]
class ROISpec(TypedDict):
"""
Typed dictionary representing the specification for a Region of Interest (ROI).
"""
x: int
y: int
width: int
height: int
img_shape: AnyDimShape
mask: np.ndarray
def validate_roi(roi: ROISpec) -> ROISpec:
"""
Validate the ROI specification to ensure it matches the image shape.
Args:
roi (ROISpec): The ROI specification to validate.
Returns:
ROISpec: The validated ROI specification.
"""
if roi["mask"].shape != roi["img_shape"]:
raise ValueError(
f"Mask shape {roi['mask'].shape} does not match image shape {roi['img_shape']}."
)
return roi
class IDSCamera(PSIDeviceBase):
"""IDS Camera class for cSAXS.
@@ -74,7 +40,7 @@ class IDSCamera(PSIDeviceBase):
async_update={"type": "add", "max_shape": [None]},
)
USER_ACCESS = ["live_mode", "roi", "get_last_image"]
USER_ACCESS = ["live_mode", "roi", "set_rect_roi", "get_last_image"]
def __init__(
self,
@@ -86,7 +52,6 @@ class IDSCamera(PSIDeviceBase):
m_n_colormode: Literal[0, 1, 2, 3] = 1,
bits_per_pixel: Literal[8, 24] = 24,
live_mode: bool = False,
roi: tuple[int, int, int, int] | None = None,
**kwargs,
):
"""Initialize the IDS Camera.
@@ -110,62 +75,32 @@ class IDSCamera(PSIDeviceBase):
connect=False,
)
self._live_mode = False
self._inputs = {"roi": roi if roi else (0, 0, 1, 1), "live_mode": live_mode}
self._roi: ROISpec = validate_roi(
ROISpec(
{
"x": 0,
"y": 0,
"width": 1,
"height": 1,
"img_shape": (1, 1),
"mask": np.zeros((1, 1), dtype=np.uint8),
}
)
)
self._inputs = {"live_mode": live_mode}
self._mask = np.zeros((1, 1), dtype=np.uint8)
############## Live Mode Methods ##############
@property
def roi(self) -> ROISpec:
def mask(self) -> np.ndarray:
"""Return the current region of interest (ROI) for the camera."""
return self._roi
return self._mask
@roi.setter
def roi(self, value: ROISpec | tuple[int, int, int, int] | list[int, int, int, int]):
@mask.setter
def mask(self, value: np.ndarray):
"""
Set the region of interest (ROI) for the camera.
Args:
value (ROI | tuple[int, int, int, int] | list[int, int, int, int]): Either an ROI object, or a tuple or list with x, y, width, and height.
value (np.ndarray): The mask to set as the ROI.
"""
if isinstance(value, (tuple, list)) and len(value) == 4:
x = value[0]
y = value[1]
width = value[2]
height = value[3]
if x + width > self.cam.cam.width.value or y + height > self.cam.cam.height.value:
raise ValueError("ROI exceeds camera dimensions.")
img_shape = (self.cam.cam.height.value, self.cam.cam.width.value)
mask = np.zeros(img_shape, dtype=np.uint8)
mask[y : y + height, x : x + width] = 1
value = validate_roi(
ROISpec(
{
"x": x,
"y": y,
"width": width,
"height": height,
"img_shape": img_shape,
"mask": mask,
}
)
if value.ndim != 2:
raise ValueError("ROI mask must be a 2D array.")
img_shape = (self.cam.cam.height.value, self.cam.cam.width.value)
if value.shape[0] != img_shape[0] or value.shape[1] != img_shape[1]:
raise ValueError(
f"ROI mask shape {value.shape} does not match image shape {img_shape}."
)
if not isinstance(value, dict) or not all(
key in value for key in ["x", "y", "width", "height", "img_shape", "mask"]
):
raise TypeError(f"ROI must be an instance of ROISpec {value}.")
self._roi = value
self._mask = value
@property
def live_mode(self) -> bool:
@@ -184,6 +119,17 @@ class IDSCamera(PSIDeviceBase):
else:
self._stop_live()
def set_rect_roi(self, x: int, y: int, width: int, height: int):
"""Set the rectangular region of interest (ROI) for the camera."""
if x < 0 or y < 0 or width <= 0 or height <= 0:
raise ValueError("ROI coordinates and dimensions must be positive integers.")
img_shape = (self.cam.cam.height.value, self.cam.cam.width.value)
if x + width > img_shape[0] or y + height > img_shape[1]:
raise ValueError("ROI exceeds camera dimensions.")
mask = np.zeros(img_shape, dtype=np.uint8)
mask[y : y + height, x : x + width] = 1
self.mask = mask
def _start_live(self):
"""Start the live mode for the camera."""
if self._live_mode_thread is not None:
@@ -225,6 +171,7 @@ class IDSCamera(PSIDeviceBase):
self.image.put(image)
def get_last_image(self) -> np.ndarray:
"""Get the last captured image from the camera."""
image = self.image.get()
if image:
return image.data
@@ -235,11 +182,7 @@ class IDSCamera(PSIDeviceBase):
"""Connect to the camera."""
self.cam.on_connect()
self.live_mode = self._inputs.get("live_mode", None)
roi = self._inputs.get("roi", None)
if roi is None or not isinstance(roi, (tuple, list)) or not len(roi) == 4:
# If ROI is not set, use the full camera resolution
roi = (0, 0, self.cam.cam.width.value, self.cam.cam.height.value)
self.roi = roi
self.set_rect_roi(0, 0, self.cam.cam.width.value, self.cam.cam.height.value)
def on_destroy(self):
"""Clean up resources when the device is destroyed."""
@@ -253,7 +196,7 @@ class IDSCamera(PSIDeviceBase):
image = self.image.get()
if image is not None:
image: messages.DevicePreviewMessage
if self.roi["img_shape"][0:2] != image.data.shape[0:2]:
if self.mask.shape[0:2] != image.data.shape[0:2]:
logger.info(
f"ROI shape does not match image shape, skipping ROI application for device {self.name}."
)
@@ -261,18 +204,16 @@ class IDSCamera(PSIDeviceBase):
if len(image.data.shape) == 3:
# If the image has multiple channels, apply the mask to each channel
data = (
image.data * self.roi["mask"][:, :, np.newaxis]
) # Apply mask to the image data
data = image.data * self.mask[:, :, np.newaxis] # Apply mask to the image data
n_channels = 3
else:
data = image.data * self.roi["mask"]
data = image.data * self.mask
n_channels = 1
self.roi_signal.put(
{
self.roi_signal.name: {
"value": np.sum(data)
/ (np.sum(self.roi["mask"]) * n_channels), # TODO can be optimized
/ (np.sum(self.mask) * n_channels), # TODO could be optimized
"timestamp": time.time(),
}
}

View File

@@ -0,0 +1,88 @@
"""Unit tests for the IDS Camera device."""
from unittest import mock
import numpy as np
import pytest
from csaxs_bec.devices.ids_cameras.ids_camera_new import IDSCamera
@pytest.fixture(scope="function")
def ids_camera():
"""Fixture for creating an instance of the IDSCamera."""
camera = IDSCamera(
name="test_camera",
camera_id=1,
prefix="test:",
scan_info=None,
m_n_colormode=1,
bits_per_pixel=24,
live_mode=False,
)
# Mock camera connection and attributes
camera.cam = mock.Mock()
camera.cam._connected = True
camera.cam.cam = mock.Mock()
camera.cam.cam.width.value = 2
camera.cam.cam.height.value = 2
yield camera
def test_mask_setter_getter(ids_camera):
"""Test the mask setter and getter methods."""
mask = np.zeros((2, 2), dtype=np.uint8)
mask[0, 0] = 1
ids_camera.mask = mask
assert np.array_equal(ids_camera.mask, mask)
def test_mask_setter_invalid_shape(ids_camera):
"""Test the mask setter with an invalid shape."""
with pytest.raises(ValueError):
ids_camera.mask = np.zeros((3, 3), dtype=np.uint8) # Exceeds mocked camera dimensions
def test_on_connected_sets_mask_and_live_mode(ids_camera):
"""Test the on_connected method to ensure it sets the mask and live mode."""
ids_camera.cam.on_connect = mock.Mock()
ids_camera.on_connected()
ids_camera.cam.on_connect.assert_called_once()
expected_mask = np.ones((2, 2), dtype=np.uint8)
assert np.array_equal(ids_camera.mask, expected_mask)
def test_on_trigger_roi_signal(ids_camera):
"""Test the on_trigger method to ensure it processes the ROI signal correctly."""
ids_camera.live_mode = True
test_image = np.array([[2, 4], [6, 8]])
test_mask = np.array([[1, 0], [0, 1]], dtype=np.uint8)
ids_camera.mask = test_mask
mock_image = mock.Mock()
mock_image.data = test_image
ids_camera.image.get = mock.Mock(return_value=mock_image)
ids_camera.roi_signal.put = mock.Mock(side_effect=ids_camera.roi_signal.put)
ids_camera.on_trigger()
expected_value = (2 * 1 + 4 * 0 + 6 * 0 + 8 * 1) / (np.sum(test_mask) * 1)
result = ids_camera.roi_signal.get()
assert np.isclose(
result.content["signals"][ids_camera.roi_signal.name]["value"], expected_value, atol=1e-6
)
def test_get_last_image(ids_camera):
"""Test the get_last_image method to ensure it returns the last captured image."""
test_image = np.array([[1, 2], [3, 4]], dtype=np.uint8)
mock_image = mock.Mock()
mock_image.data = test_image
ids_camera.image.get = mock.Mock(return_value=mock_image)
result = ids_camera.get_last_image()
assert np.array_equal(result, test_image)
def test_on_destroy(ids_camera):
"""Test the on_destroy method to ensure it cleans up resources."""
ids_camera.cam.on_disconnect = mock.Mock()
ids_camera.on_destroy()
ids_camera.cam.on_disconnect.assert_called_once()