test(ids-camera): add tests for the IDSCamera integration
This commit is contained in:
@@ -12,7 +12,6 @@ from bec_lib.logger import bec_logger
|
||||
from ophyd import Component as Cpt
|
||||
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
|
||||
from ophyd_devices.utils.bec_signals import AsyncSignal, PreviewSignal
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
from csaxs_bec.devices.ids_cameras.base_integration.camera import Camera
|
||||
|
||||
@@ -24,39 +23,6 @@ if TYPE_CHECKING:
|
||||
logger = bec_logger.logger
|
||||
|
||||
|
||||
AnyDimShape = Tuple[int, ...]
|
||||
|
||||
|
||||
class ROISpec(TypedDict):
|
||||
"""
|
||||
Typed dictionary representing the specification for a Region of Interest (ROI).
|
||||
"""
|
||||
|
||||
x: int
|
||||
y: int
|
||||
width: int
|
||||
height: int
|
||||
img_shape: AnyDimShape
|
||||
mask: np.ndarray
|
||||
|
||||
|
||||
def validate_roi(roi: ROISpec) -> ROISpec:
|
||||
"""
|
||||
Validate the ROI specification to ensure it matches the image shape.
|
||||
|
||||
Args:
|
||||
roi (ROISpec): The ROI specification to validate.
|
||||
|
||||
Returns:
|
||||
ROISpec: The validated ROI specification.
|
||||
"""
|
||||
if roi["mask"].shape != roi["img_shape"]:
|
||||
raise ValueError(
|
||||
f"Mask shape {roi['mask'].shape} does not match image shape {roi['img_shape']}."
|
||||
)
|
||||
return roi
|
||||
|
||||
|
||||
class IDSCamera(PSIDeviceBase):
|
||||
"""IDS Camera class for cSAXS.
|
||||
|
||||
@@ -74,7 +40,7 @@ class IDSCamera(PSIDeviceBase):
|
||||
async_update={"type": "add", "max_shape": [None]},
|
||||
)
|
||||
|
||||
USER_ACCESS = ["live_mode", "roi", "get_last_image"]
|
||||
USER_ACCESS = ["live_mode", "roi", "set_rect_roi", "get_last_image"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -86,7 +52,6 @@ class IDSCamera(PSIDeviceBase):
|
||||
m_n_colormode: Literal[0, 1, 2, 3] = 1,
|
||||
bits_per_pixel: Literal[8, 24] = 24,
|
||||
live_mode: bool = False,
|
||||
roi: tuple[int, int, int, int] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize the IDS Camera.
|
||||
@@ -110,62 +75,32 @@ class IDSCamera(PSIDeviceBase):
|
||||
connect=False,
|
||||
)
|
||||
self._live_mode = False
|
||||
self._inputs = {"roi": roi if roi else (0, 0, 1, 1), "live_mode": live_mode}
|
||||
self._roi: ROISpec = validate_roi(
|
||||
ROISpec(
|
||||
{
|
||||
"x": 0,
|
||||
"y": 0,
|
||||
"width": 1,
|
||||
"height": 1,
|
||||
"img_shape": (1, 1),
|
||||
"mask": np.zeros((1, 1), dtype=np.uint8),
|
||||
}
|
||||
)
|
||||
)
|
||||
self._inputs = {"live_mode": live_mode}
|
||||
self._mask = np.zeros((1, 1), dtype=np.uint8)
|
||||
|
||||
############## Live Mode Methods ##############
|
||||
|
||||
@property
|
||||
def roi(self) -> ROISpec:
|
||||
def mask(self) -> np.ndarray:
|
||||
"""Return the current region of interest (ROI) for the camera."""
|
||||
return self._roi
|
||||
return self._mask
|
||||
|
||||
@roi.setter
|
||||
def roi(self, value: ROISpec | tuple[int, int, int, int] | list[int, int, int, int]):
|
||||
@mask.setter
|
||||
def mask(self, value: np.ndarray):
|
||||
"""
|
||||
Set the region of interest (ROI) for the camera.
|
||||
|
||||
Args:
|
||||
value (ROI | tuple[int, int, int, int] | list[int, int, int, int]): Either an ROI object, or a tuple or list with x, y, width, and height.
|
||||
value (np.ndarray): The mask to set as the ROI.
|
||||
"""
|
||||
if isinstance(value, (tuple, list)) and len(value) == 4:
|
||||
x = value[0]
|
||||
y = value[1]
|
||||
width = value[2]
|
||||
height = value[3]
|
||||
if x + width > self.cam.cam.width.value or y + height > self.cam.cam.height.value:
|
||||
raise ValueError("ROI exceeds camera dimensions.")
|
||||
img_shape = (self.cam.cam.height.value, self.cam.cam.width.value)
|
||||
mask = np.zeros(img_shape, dtype=np.uint8)
|
||||
mask[y : y + height, x : x + width] = 1
|
||||
value = validate_roi(
|
||||
ROISpec(
|
||||
{
|
||||
"x": x,
|
||||
"y": y,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"img_shape": img_shape,
|
||||
"mask": mask,
|
||||
}
|
||||
)
|
||||
if value.ndim != 2:
|
||||
raise ValueError("ROI mask must be a 2D array.")
|
||||
img_shape = (self.cam.cam.height.value, self.cam.cam.width.value)
|
||||
if value.shape[0] != img_shape[0] or value.shape[1] != img_shape[1]:
|
||||
raise ValueError(
|
||||
f"ROI mask shape {value.shape} does not match image shape {img_shape}."
|
||||
)
|
||||
if not isinstance(value, dict) or not all(
|
||||
key in value for key in ["x", "y", "width", "height", "img_shape", "mask"]
|
||||
):
|
||||
raise TypeError(f"ROI must be an instance of ROISpec {value}.")
|
||||
self._roi = value
|
||||
self._mask = value
|
||||
|
||||
@property
|
||||
def live_mode(self) -> bool:
|
||||
@@ -184,6 +119,17 @@ class IDSCamera(PSIDeviceBase):
|
||||
else:
|
||||
self._stop_live()
|
||||
|
||||
def set_rect_roi(self, x: int, y: int, width: int, height: int):
|
||||
"""Set the rectangular region of interest (ROI) for the camera."""
|
||||
if x < 0 or y < 0 or width <= 0 or height <= 0:
|
||||
raise ValueError("ROI coordinates and dimensions must be positive integers.")
|
||||
img_shape = (self.cam.cam.height.value, self.cam.cam.width.value)
|
||||
if x + width > img_shape[0] or y + height > img_shape[1]:
|
||||
raise ValueError("ROI exceeds camera dimensions.")
|
||||
mask = np.zeros(img_shape, dtype=np.uint8)
|
||||
mask[y : y + height, x : x + width] = 1
|
||||
self.mask = mask
|
||||
|
||||
def _start_live(self):
|
||||
"""Start the live mode for the camera."""
|
||||
if self._live_mode_thread is not None:
|
||||
@@ -225,6 +171,7 @@ class IDSCamera(PSIDeviceBase):
|
||||
self.image.put(image)
|
||||
|
||||
def get_last_image(self) -> np.ndarray:
|
||||
"""Get the last captured image from the camera."""
|
||||
image = self.image.get()
|
||||
if image:
|
||||
return image.data
|
||||
@@ -235,11 +182,7 @@ class IDSCamera(PSIDeviceBase):
|
||||
"""Connect to the camera."""
|
||||
self.cam.on_connect()
|
||||
self.live_mode = self._inputs.get("live_mode", None)
|
||||
roi = self._inputs.get("roi", None)
|
||||
if roi is None or not isinstance(roi, (tuple, list)) or not len(roi) == 4:
|
||||
# If ROI is not set, use the full camera resolution
|
||||
roi = (0, 0, self.cam.cam.width.value, self.cam.cam.height.value)
|
||||
self.roi = roi
|
||||
self.set_rect_roi(0, 0, self.cam.cam.width.value, self.cam.cam.height.value)
|
||||
|
||||
def on_destroy(self):
|
||||
"""Clean up resources when the device is destroyed."""
|
||||
@@ -253,7 +196,7 @@ class IDSCamera(PSIDeviceBase):
|
||||
image = self.image.get()
|
||||
if image is not None:
|
||||
image: messages.DevicePreviewMessage
|
||||
if self.roi["img_shape"][0:2] != image.data.shape[0:2]:
|
||||
if self.mask.shape[0:2] != image.data.shape[0:2]:
|
||||
logger.info(
|
||||
f"ROI shape does not match image shape, skipping ROI application for device {self.name}."
|
||||
)
|
||||
@@ -261,18 +204,16 @@ class IDSCamera(PSIDeviceBase):
|
||||
|
||||
if len(image.data.shape) == 3:
|
||||
# If the image has multiple channels, apply the mask to each channel
|
||||
data = (
|
||||
image.data * self.roi["mask"][:, :, np.newaxis]
|
||||
) # Apply mask to the image data
|
||||
data = image.data * self.mask[:, :, np.newaxis] # Apply mask to the image data
|
||||
n_channels = 3
|
||||
else:
|
||||
data = image.data * self.roi["mask"]
|
||||
data = image.data * self.mask
|
||||
n_channels = 1
|
||||
self.roi_signal.put(
|
||||
{
|
||||
self.roi_signal.name: {
|
||||
"value": np.sum(data)
|
||||
/ (np.sum(self.roi["mask"]) * n_channels), # TODO can be optimized
|
||||
/ (np.sum(self.mask) * n_channels), # TODO could be optimized
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
}
|
||||
|
||||
88
tests/tests_devices/test_ids_camera.py
Normal file
88
tests/tests_devices/test_ids_camera.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Unit tests for the IDS Camera device."""
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from csaxs_bec.devices.ids_cameras.ids_camera_new import IDSCamera
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def ids_camera():
|
||||
"""Fixture for creating an instance of the IDSCamera."""
|
||||
camera = IDSCamera(
|
||||
name="test_camera",
|
||||
camera_id=1,
|
||||
prefix="test:",
|
||||
scan_info=None,
|
||||
m_n_colormode=1,
|
||||
bits_per_pixel=24,
|
||||
live_mode=False,
|
||||
)
|
||||
# Mock camera connection and attributes
|
||||
camera.cam = mock.Mock()
|
||||
camera.cam._connected = True
|
||||
camera.cam.cam = mock.Mock()
|
||||
camera.cam.cam.width.value = 2
|
||||
camera.cam.cam.height.value = 2
|
||||
yield camera
|
||||
|
||||
|
||||
def test_mask_setter_getter(ids_camera):
|
||||
"""Test the mask setter and getter methods."""
|
||||
mask = np.zeros((2, 2), dtype=np.uint8)
|
||||
mask[0, 0] = 1
|
||||
ids_camera.mask = mask
|
||||
assert np.array_equal(ids_camera.mask, mask)
|
||||
|
||||
|
||||
def test_mask_setter_invalid_shape(ids_camera):
|
||||
"""Test the mask setter with an invalid shape."""
|
||||
with pytest.raises(ValueError):
|
||||
ids_camera.mask = np.zeros((3, 3), dtype=np.uint8) # Exceeds mocked camera dimensions
|
||||
|
||||
|
||||
def test_on_connected_sets_mask_and_live_mode(ids_camera):
|
||||
"""Test the on_connected method to ensure it sets the mask and live mode."""
|
||||
ids_camera.cam.on_connect = mock.Mock()
|
||||
ids_camera.on_connected()
|
||||
ids_camera.cam.on_connect.assert_called_once()
|
||||
expected_mask = np.ones((2, 2), dtype=np.uint8)
|
||||
assert np.array_equal(ids_camera.mask, expected_mask)
|
||||
|
||||
|
||||
def test_on_trigger_roi_signal(ids_camera):
|
||||
"""Test the on_trigger method to ensure it processes the ROI signal correctly."""
|
||||
ids_camera.live_mode = True
|
||||
test_image = np.array([[2, 4], [6, 8]])
|
||||
test_mask = np.array([[1, 0], [0, 1]], dtype=np.uint8)
|
||||
ids_camera.mask = test_mask
|
||||
mock_image = mock.Mock()
|
||||
mock_image.data = test_image
|
||||
ids_camera.image.get = mock.Mock(return_value=mock_image)
|
||||
ids_camera.roi_signal.put = mock.Mock(side_effect=ids_camera.roi_signal.put)
|
||||
ids_camera.on_trigger()
|
||||
expected_value = (2 * 1 + 4 * 0 + 6 * 0 + 8 * 1) / (np.sum(test_mask) * 1)
|
||||
result = ids_camera.roi_signal.get()
|
||||
assert np.isclose(
|
||||
result.content["signals"][ids_camera.roi_signal.name]["value"], expected_value, atol=1e-6
|
||||
)
|
||||
|
||||
|
||||
def test_get_last_image(ids_camera):
|
||||
"""Test the get_last_image method to ensure it returns the last captured image."""
|
||||
test_image = np.array([[1, 2], [3, 4]], dtype=np.uint8)
|
||||
mock_image = mock.Mock()
|
||||
mock_image.data = test_image
|
||||
ids_camera.image.get = mock.Mock(return_value=mock_image)
|
||||
|
||||
result = ids_camera.get_last_image()
|
||||
assert np.array_equal(result, test_image)
|
||||
|
||||
|
||||
def test_on_destroy(ids_camera):
|
||||
"""Test the on_destroy method to ensure it cleans up resources."""
|
||||
ids_camera.cam.on_disconnect = mock.Mock()
|
||||
ids_camera.on_destroy()
|
||||
ids_camera.cam.on_disconnect.assert_called_once()
|
||||
Reference in New Issue
Block a user