test: cleanup, refactor and add tests

This commit is contained in:
appel_c 2024-12-06 14:11:31 +01:00
parent 1c225aabfe
commit c48a97d9bc
3 changed files with 158 additions and 21 deletions

View File

@ -502,3 +502,30 @@ def test_interactive_scan(bec_client_lib):
while len(report.scan.live_data) != 10: while len(report.scan.live_data) != 10:
time.sleep(0.1) time.sleep(0.1)
assert len(report.scan.live_data.samx.samx.val) == 10 assert len(report.scan.live_data.samx.samx.val) == 10
def test_image_analysis(bec_client_lib):
bec = bec_client_lib
bec.metadata.update({"unit_test": "test_image_analysis"})
dev = bec.device_manager.devices
scans = bec.scans
dev.eiger.sim.select_model("gaussian")
dev.eiger.sim.params = {
"amplitude": 100,
"center_offset": np.array([0, 0]),
"covariance": np.array([[1, 0], [0, 1]]),
"noise": "uniform",
"noise_multiplier": 10,
"hot_pixel_coords": np.array([[24, 24], [50, 20], [4, 40]]),
"hot_pixel_types": ["fluctuating", "constant", "fluctuating"],
"hot_pixel_values": np.array([1000.0, 10000.0, 1000.0]),
}
res = scans.line_scan(dev.samx, -5, 5, steps=10, relative=False, exp_time=0)
res.wait()
fit_res = bec.dap.image_analysis.run(res.scan.scan_id, "eiger")
assert (fit_res[1]["stats"]["max"] == 10000.0).all()
assert (fit_res[1]["stats"]["min"] == 0.0).all()
assert (np.isclose(fit_res[1]["stats"]["mean"], 3.3, atol=0.5)).all()
# Center of mass is not in the middle due to hot (fluctuating) pixels
assert (np.isclose(fit_res[1]["stats"]["center_of_mass"], [49.5, 40.8], atol=1)).all()

View File

@ -8,6 +8,7 @@ from scipy.ndimage import center_of_mass
from typeguard import typechecked from typeguard import typechecked
from bec_lib.device_monitor_plugin import DeviceMonitorPlugin from bec_lib.device_monitor_plugin import DeviceMonitorPlugin
from bec_lib.logger import bec_logger
from bec_server.data_processing.dap_service import DAPError, DAPServiceBase from bec_server.data_processing.dap_service import DAPError, DAPServiceBase
if TYPE_CHECKING: if TYPE_CHECKING:
@ -15,6 +16,8 @@ if TYPE_CHECKING:
from bec_lib.device import DeviceBase from bec_lib.device import DeviceBase
from bec_lib.scan_items import ScanItem from bec_lib.scan_items import ScanItem
logger = bec_logger.logger
class ReturnType(str, Enum): class ReturnType(str, Enum):
"""The possible return data types for the image analysis service.""" """The possible return data types for the image analysis service."""
@ -37,15 +40,14 @@ class ImageAnalysisService(DAPServiceBase):
self.device = None self.device = None
self.return_type = None self.return_type = None
def configure( def configure(self, scan_item=None, device=None, images=None, return_type=None, **kwargs):
self, *args, scan_item=None, device=None, images=None, return_type=None, **kwargs # TODO Add type hints for np.ndarray and list[np.ndarray] do not work yet in the signature_serializer
): # This will be adressed in a different MR, issue is created #395
# TODO somehow the serialisation of these arguments was noch properly working.
# It crashed on the ScanItem/DeviceBase/np.ndarray types. (I believe)
# scan_item: ScanItem | str = None, # scan_item: ScanItem | str = None,
# device: DeviceBase | str = None, # device: DeviceBase | str = None,
# images: np.ndarray | list[np.ndarray] | None = None, # images: np.ndarray | list[np.ndarray] | None = None,
# **kwargs, # **kwargs,
# ):
"""Configure the image analysis service. Either provide a scan item and a device which """Configure the image analysis service. Either provide a scan item and a device which
has a 2D monitor active, or provide the images directly. If no data is found for the input has a 2D monitor active, or provide the images directly. If no data is found for the input
the service will return an empty stream output. the service will return an empty stream output.
@ -56,31 +58,34 @@ class ImageAnalysisService(DAPServiceBase):
images: Alternatively, you can provide the images directly images: Alternatively, you can provide the images directly
return_type: The type of data to return, can be "min", "max", "mean", "median", "std", "center_of_mass" return_type: The type of data to return, can be "min", "max", "mean", "median", "std", "center_of_mass"
""" """
self.device = str(device)
if return_type is None: if return_type is None:
return_type = ReturnType.CENTER_OF_MASS return_type = ReturnType.CENTER_OF_MASS
else: else:
return_type = ReturnType(return_type) return_type = ReturnType(return_type)
self.return_type = return_type self.return_type = return_type
if images is None: # If images are provided, use them
self.data = self.get_images_for_scan_item(scan_item=scan_item) if images is not None:
if isinstance(images, np.ndarray):
self.data = [images]
elif isinstance(images, list) and all(
isinstance(image, np.ndarray) for image in images
):
self.data = images
return return
if isinstance(images, np.ndarray): # Else if scan item is provided, get the images
self.data = [images] if device is None or scan_item is None:
elif isinstance(images, list) and all(isinstance(image, np.ndarray) for image in images): raise DAPError(
self.data = images f"Either provide a device: {device} and scan_id {scan_item} or images {images}"
else: )
raise DAPError(f"Invalid format for images: {images} provided") self.device = str(device)
self.data = self.get_images_for_scan_item(scan_id=scan_item)
def get_images_for_scan_item(self, scan_item: ScanItem | str) -> list[np.ndarray]: def get_images_for_scan_item(self, scan_id: str) -> list[np.ndarray]:
"""Get the data for the scan item.""" """Get the data for the scan item."""
scan_id = scan_item self.scan_id = scan_id
if scan_id != self.scan_id or not self.current_scan_item:
scan_item = self.client.queue.scan_storage.find_scan_by_ID(scan_id)
self.scan_id = scan_id
else:
scan_item = self.current_scan_item
data = self.device_monitor_plugin.get_data_for_scan(device=self.device, scan=self.scan_id) data = self.device_monitor_plugin.get_data_for_scan(device=self.device, scan=self.scan_id)
if len(data) == 0:
logger.warning(f"No data found for scan {scan_id} and device {self.device}")
return data return data
@typechecked @typechecked

View File

@ -0,0 +1,105 @@
from unittest import mock
import numpy as np
import pytest
from bec_server.data_processing.image_analysis_service import (
DAPError,
ImageAnalysisService,
ReturnType,
)
@pytest.fixture
def image_analysis_service():
yield ImageAnalysisService(client=mock.MagicMock())
def test_image_analysis_configure(image_analysis_service):
"""Test the configure method of the image analysis service."""
# Test with scan item
scan_item = mock.MagicMock()
dummy_data = [np.linspace(0, 1, 100) for _ in range(10)]
scan_item.return_value = "mock_scan_id"
with mock.patch.object(
image_analysis_service, "get_images_for_scan_item", return_value=dummy_data
):
image_analysis_service.configure(scan_item=scan_item, device="eiger")
assert image_analysis_service.data == dummy_data
assert image_analysis_service.device == "eiger"
assert image_analysis_service.return_type == ReturnType.CENTER_OF_MASS
# Reset the imageanalysisService
image_analysis_service.data = []
image_analysis_service.device = None
image_analysis_service.return_type = None
# Missing device argument
with pytest.raises(DAPError):
image_analysis_service.configure(scan_item=scan_item)
# Reset the imageanalysisService
image_analysis_service.data = []
image_analysis_service.device = None
image_analysis_service.return_type = None
# Missing scan item
with pytest.raises(DAPError):
image_analysis_service.configure(device="eiger")
# Reset the imageanalysisService
image_analysis_service.data = []
image_analysis_service.device = None
image_analysis_service.return_type = None
# Test with images
image_analysis_service.configure(images=dummy_data)
assert image_analysis_service.data == dummy_data
def test_get_images_for_scan_item(image_analysis_service):
"""Test the get_images_for_scan_item method of the image analysis service."""
dummy_data = [np.linspace(0, 1, 100) for _ in range(10)]
with mock.patch.object(
image_analysis_service.device_monitor_plugin,
"get_data_for_scan",
side_effect=[dummy_data, []],
):
# Test with existing scan id
scan_id = "mock_scan_id"
image_analysis_service.scan_id = scan_id
data = image_analysis_service.get_images_for_scan_item(scan_id)
assert data == dummy_data
assert image_analysis_service.scan_id == scan_id
# Test with empty return
data = image_analysis_service.get_images_for_scan_item(scan_id)
assert data == []
def test_compute_statistics(image_analysis_service):
"""Test the compute_statistics method of the image analysis service."""
dummy_data = [np.zeros((10, 10)), np.ones((10, 10))]
stats = image_analysis_service.compute_statistics(dummy_data)
assert stats["min"].shape == (2,)
assert stats["max"].shape == (2,)
assert np.isclose(stats["mean"], np.array([0, 1])).all()
assert np.isclose(stats["min"], np.array([0, 1])).all()
def test_get_stream_output(image_analysis_service):
"""Test the get_stream_output method of the image analysis service."""
dummy_data = [np.ones((10, 10)), np.ones((10, 10))]
stats = image_analysis_service.compute_statistics(dummy_data)
image_analysis_service.return_type = ReturnType.MIN
stream_output = image_analysis_service._compute_stream_output(stats)
assert np.isclose(stream_output["x"], np.linspace(0, 1, 2)).all()
assert (stream_output["y"] == stats["min"]).all()
# Test center of Mass
image_analysis_service.return_type = ReturnType.CENTER_OF_MASS
stream_output = image_analysis_service._compute_stream_output(stats)
assert (stream_output["x"] == stats["center_of_mass"].T[0]).all()
assert (stream_output["y"] == stats["center_of_mass"].T[1]).all()