test: cleanup, refactor and add tests

This commit is contained in:
appel_c 2024-12-06 14:11:31 +01:00
parent 1c225aabfe
commit c48a97d9bc
3 changed files with 158 additions and 21 deletions

View File

@ -502,3 +502,30 @@ def test_interactive_scan(bec_client_lib):
while len(report.scan.live_data) != 10:
time.sleep(0.1)
assert len(report.scan.live_data.samx.samx.val) == 10
def test_image_analysis(bec_client_lib):
bec = bec_client_lib
bec.metadata.update({"unit_test": "test_image_analysis"})
dev = bec.device_manager.devices
scans = bec.scans
dev.eiger.sim.select_model("gaussian")
dev.eiger.sim.params = {
"amplitude": 100,
"center_offset": np.array([0, 0]),
"covariance": np.array([[1, 0], [0, 1]]),
"noise": "uniform",
"noise_multiplier": 10,
"hot_pixel_coords": np.array([[24, 24], [50, 20], [4, 40]]),
"hot_pixel_types": ["fluctuating", "constant", "fluctuating"],
"hot_pixel_values": np.array([1000.0, 10000.0, 1000.0]),
}
res = scans.line_scan(dev.samx, -5, 5, steps=10, relative=False, exp_time=0)
res.wait()
fit_res = bec.dap.image_analysis.run(res.scan.scan_id, "eiger")
assert (fit_res[1]["stats"]["max"] == 10000.0).all()
assert (fit_res[1]["stats"]["min"] == 0.0).all()
assert (np.isclose(fit_res[1]["stats"]["mean"], 3.3, atol=0.5)).all()
# Center of mass is not in the middle due to hot (fluctuating) pixels
assert (np.isclose(fit_res[1]["stats"]["center_of_mass"], [49.5, 40.8], atol=1)).all()

View File

@ -8,6 +8,7 @@ from scipy.ndimage import center_of_mass
from typeguard import typechecked
from bec_lib.device_monitor_plugin import DeviceMonitorPlugin
from bec_lib.logger import bec_logger
from bec_server.data_processing.dap_service import DAPError, DAPServiceBase
if TYPE_CHECKING:
@ -15,6 +16,8 @@ if TYPE_CHECKING:
from bec_lib.device import DeviceBase
from bec_lib.scan_items import ScanItem
logger = bec_logger.logger
class ReturnType(str, Enum):
"""The possible return data types for the image analysis service."""
@ -37,15 +40,14 @@ class ImageAnalysisService(DAPServiceBase):
self.device = None
self.return_type = None
def configure(
self, *args, scan_item=None, device=None, images=None, return_type=None, **kwargs
):
# TODO somehow the serialisation of these arguments was noch properly working.
# It crashed on the ScanItem/DeviceBase/np.ndarray types. (I believe)
def configure(self, scan_item=None, device=None, images=None, return_type=None, **kwargs):
# TODO Add type hints for np.ndarray and list[np.ndarray] do not work yet in the signature_serializer
# This will be adressed in a different MR, issue is created #395
# scan_item: ScanItem | str = None,
# device: DeviceBase | str = None,
# images: np.ndarray | list[np.ndarray] | None = None,
# **kwargs,
# ):
"""Configure the image analysis service. Either provide a scan item and a device which
has a 2D monitor active, or provide the images directly. If no data is found for the input
the service will return an empty stream output.
@ -56,31 +58,34 @@ class ImageAnalysisService(DAPServiceBase):
images: Alternatively, you can provide the images directly
return_type: The type of data to return, can be "min", "max", "mean", "median", "std", "center_of_mass"
"""
self.device = str(device)
if return_type is None:
return_type = ReturnType.CENTER_OF_MASS
else:
return_type = ReturnType(return_type)
self.return_type = return_type
if images is None:
self.data = self.get_images_for_scan_item(scan_item=scan_item)
# If images are provided, use them
if images is not None:
if isinstance(images, np.ndarray):
self.data = [images]
elif isinstance(images, list) and all(
isinstance(image, np.ndarray) for image in images
):
self.data = images
return
if isinstance(images, np.ndarray):
self.data = [images]
elif isinstance(images, list) and all(isinstance(image, np.ndarray) for image in images):
self.data = images
else:
raise DAPError(f"Invalid format for images: {images} provided")
# Else if scan item is provided, get the images
if device is None or scan_item is None:
raise DAPError(
f"Either provide a device: {device} and scan_id {scan_item} or images {images}"
)
self.device = str(device)
self.data = self.get_images_for_scan_item(scan_id=scan_item)
def get_images_for_scan_item(self, scan_item: ScanItem | str) -> list[np.ndarray]:
def get_images_for_scan_item(self, scan_id: str) -> list[np.ndarray]:
"""Get the data for the scan item."""
scan_id = scan_item
if scan_id != self.scan_id or not self.current_scan_item:
scan_item = self.client.queue.scan_storage.find_scan_by_ID(scan_id)
self.scan_id = scan_id
else:
scan_item = self.current_scan_item
self.scan_id = scan_id
data = self.device_monitor_plugin.get_data_for_scan(device=self.device, scan=self.scan_id)
if len(data) == 0:
logger.warning(f"No data found for scan {scan_id} and device {self.device}")
return data
@typechecked

View File

@ -0,0 +1,105 @@
from unittest import mock
import numpy as np
import pytest
from bec_server.data_processing.image_analysis_service import (
DAPError,
ImageAnalysisService,
ReturnType,
)
@pytest.fixture
def image_analysis_service():
yield ImageAnalysisService(client=mock.MagicMock())
def test_image_analysis_configure(image_analysis_service):
"""Test the configure method of the image analysis service."""
# Test with scan item
scan_item = mock.MagicMock()
dummy_data = [np.linspace(0, 1, 100) for _ in range(10)]
scan_item.return_value = "mock_scan_id"
with mock.patch.object(
image_analysis_service, "get_images_for_scan_item", return_value=dummy_data
):
image_analysis_service.configure(scan_item=scan_item, device="eiger")
assert image_analysis_service.data == dummy_data
assert image_analysis_service.device == "eiger"
assert image_analysis_service.return_type == ReturnType.CENTER_OF_MASS
# Reset the imageanalysisService
image_analysis_service.data = []
image_analysis_service.device = None
image_analysis_service.return_type = None
# Missing device argument
with pytest.raises(DAPError):
image_analysis_service.configure(scan_item=scan_item)
# Reset the imageanalysisService
image_analysis_service.data = []
image_analysis_service.device = None
image_analysis_service.return_type = None
# Missing scan item
with pytest.raises(DAPError):
image_analysis_service.configure(device="eiger")
# Reset the imageanalysisService
image_analysis_service.data = []
image_analysis_service.device = None
image_analysis_service.return_type = None
# Test with images
image_analysis_service.configure(images=dummy_data)
assert image_analysis_service.data == dummy_data
def test_get_images_for_scan_item(image_analysis_service):
"""Test the get_images_for_scan_item method of the image analysis service."""
dummy_data = [np.linspace(0, 1, 100) for _ in range(10)]
with mock.patch.object(
image_analysis_service.device_monitor_plugin,
"get_data_for_scan",
side_effect=[dummy_data, []],
):
# Test with existing scan id
scan_id = "mock_scan_id"
image_analysis_service.scan_id = scan_id
data = image_analysis_service.get_images_for_scan_item(scan_id)
assert data == dummy_data
assert image_analysis_service.scan_id == scan_id
# Test with empty return
data = image_analysis_service.get_images_for_scan_item(scan_id)
assert data == []
def test_compute_statistics(image_analysis_service):
"""Test the compute_statistics method of the image analysis service."""
dummy_data = [np.zeros((10, 10)), np.ones((10, 10))]
stats = image_analysis_service.compute_statistics(dummy_data)
assert stats["min"].shape == (2,)
assert stats["max"].shape == (2,)
assert np.isclose(stats["mean"], np.array([0, 1])).all()
assert np.isclose(stats["min"], np.array([0, 1])).all()
def test_get_stream_output(image_analysis_service):
"""Test the get_stream_output method of the image analysis service."""
dummy_data = [np.ones((10, 10)), np.ones((10, 10))]
stats = image_analysis_service.compute_statistics(dummy_data)
image_analysis_service.return_type = ReturnType.MIN
stream_output = image_analysis_service._compute_stream_output(stats)
assert np.isclose(stream_output["x"], np.linspace(0, 1, 2)).all()
assert (stream_output["y"] == stats["min"]).all()
# Test center of Mass
image_analysis_service.return_type = ReturnType.CENTER_OF_MASS
stream_output = image_analysis_service._compute_stream_output(stats)
assert (stream_output["x"] == stats["center_of_mass"].T[0]).all()
assert (stream_output["y"] == stats["center_of_mass"].T[1]).all()