Files
tomcat_bec/tests/tests_devices/test_std_daq_live_processing.py
gac-x05la ae85d179f5 refactor: std daq integration
Major refactor of the std daq integration for the PCO Edge camera and Gigafrost camera.
New live processing capabilities have been added, and the code has been cleaned up for better maintainability.
2025-06-16 16:59:08 +02:00

151 lines
5.2 KiB
Python

from unittest import mock
import fakeredis
import h5py
import numpy as np
import pytest
from bec_lib.redis_connector import RedisConnector
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from typeguard import TypeCheckError
from tomcat_bec.devices.std_daq.std_daq_live_processing import StdDaqLiveProcessing
def fake_redis_server(host, port):
    """Return a new in-memory ``fakeredis.FakeRedis`` client.

    ``host`` and ``port`` are accepted but ignored; they exist so this
    function can be handed to ``RedisConnector`` as its ``redis_cls``.
    """
    return fakeredis.FakeRedis()
@pytest.fixture
def connected_connector():
    """Yield a ``RedisConnector`` that talks to a fake Redis server.

    The fake database is flushed before the connector is handed to the test,
    and the connector is shut down once the test has finished.
    """
    connector = RedisConnector("localhost:1", redis_cls=fake_redis_server)  # type: ignore
    try:
        # Flush inside the ``try`` so the connector is still shut down if
        # clearing the database raises.
        connector._redis_conn.flushall()
        yield connector
    finally:
        connector.shutdown()
class MockPSIDeviceBase(PSIDeviceBase):
    """``PSIDeviceBase`` subclass that stores a ``device_manager`` for tests."""

    def __init__(self, *args, device_manager=None, **kwargs):
        """Initialise the base device, then attach ``device_manager``.

        ``device_manager`` is keyword-only and is not forwarded to
        ``PSIDeviceBase.__init__``; every other argument is passed through.
        """
        super().__init__(*args, **kwargs)
        # Assigned after the base initialiser, so this value is the one kept.
        self.device_manager = device_manager
@pytest.fixture
def mock_device(connected_connector):
    """Yield a ``MockPSIDeviceBase`` named ``mock_device``.

    Its device manager is a ``Mock`` whose ``connector`` attribute is the
    ``connected_connector`` fixture.
    """
    manager = mock.Mock(connector=connected_connector)
    yield MockPSIDeviceBase(name="mock_device", device_manager=manager)
@pytest.fixture
def std_daq_live_processing(mock_device):
    """Yield a ``StdDaqLiveProcessing`` built from the mock device and two mock signals."""
    first_signal, second_signal = mock.Mock(), mock.Mock()
    yield StdDaqLiveProcessing(mock_device, first_signal, second_signal)
def test_std_daq_live_processing_set_mode(std_daq_live_processing):
    """Check that ``set_mode`` stores "sum" and rejects the other test values."""
    processing = std_daq_live_processing

    processing.set_mode("sum")
    assert processing.get_mode() == "sum"

    # Both of these values are expected to fail the runtime type check.
    for rejected in ("average", 123):
        with pytest.raises(TypeCheckError):
            processing.set_mode(rejected)
@pytest.fixture(params=["flat", "dark"])
def reference_type(request):
    """Parametrized fixture: runs each dependent test once for "flat" and once for "dark"."""
    ref_type = request.param
    return ref_type
def test_std_daq_live_processing_flat_default(std_daq_live_processing, reference_type):
    """Check the default reference returned when Redis has no stored value.

    With ``_get_from_redis`` patched to return ``None``, ``get_flat`` must
    return an all-ones array and ``get_dark`` an all-zeros array of the
    requested shape.
    """
    shape = (100, 100)
    processing = std_daq_live_processing
    is_flat = reference_type == "flat"
    getter = processing.get_flat if is_flat else processing.get_dark

    # Make the Redis lookup report that nothing is stored.
    with mock.patch.object(processing, "_get_from_redis", return_value=None) as redis_lookup:
        out = getter(shape)

        expected_key = processing._redis_endpoint_name(ref_type=reference_type, shape=shape)
        redis_lookup.assert_called_once_with(expected_key)

        assert isinstance(out, np.ndarray)
        assert out.shape == shape
        if is_flat:
            assert np.all(out == 1), "Default should be all ones"
        else:
            assert np.all(out == 0), "Default should be all zeros"
@pytest.mark.parametrize("value", [np.random.rand(100, 100), np.random.rand(3, 100, 100)])
def test_std_daq_live_processing_fetch(tmp_path, std_daq_live_processing, value, reference_type):
    """Load a reference from an HDF5 file and check both caches.

    ``value`` (a single 100x100 frame or a stack of three) is written to a
    temporary HDF5 file and loaded through ``update_reference_with_file``.
    The reference returned by ``get_flat``/``get_dark`` must then be a
    100x100 array that matches both the local ``references`` entry and the
    value stored in Redis.
    """
    shape = (100, 100)
    dataset = "tomcat-pco/data"
    file_path = tmp_path / "test_data.h5"

    with h5py.File(file_path, "w") as f:
        f.create_dataset(dataset, data=value)

    status = std_daq_live_processing.update_reference_with_file(
        reference_type, file_path, dataset
    )
    status.wait()

    get_method = (
        std_daq_live_processing.get_flat
        if reference_type == "flat"
        else std_daq_live_processing.get_dark
    )
    out = get_method(shape)
    assert isinstance(out, np.ndarray)
    assert out.shape == shape

    # Check that the data is cached locally. The key is
    # "<reference_type>_(100, 100)"; the message names the actual reference
    # type so a failure in the "dark" case is not reported as "flat".
    assert np.array_equal(
        std_daq_live_processing.references[f"{reference_type}_{shape}"], out
    ), f"Cached {reference_type} data should match fetched data"

    redis_data = std_daq_live_processing._get_from_redis(
        std_daq_live_processing._redis_endpoint_name(ref_type=reference_type, shape=shape)
    )
    assert isinstance(redis_data, np.ndarray)
    assert redis_data.shape == shape
    assert np.array_equal(redis_data, out), "Redis data should match the locally cached data"
def test_std_daq_live_processing_apply_flat_dark_correction(std_daq_live_processing):
    """Correct a random image using an all-ones flat and an all-zeros dark.

    The result must be a 100x100 ``ndarray`` with no negative values.
    """
    shape = (100, 100)
    raw = np.random.rand(*shape)

    # Install the flat and dark references directly into the local cache.
    cache = std_daq_live_processing.references
    cache["flat_(100, 100)"] = np.ones(shape)
    cache["dark_(100, 100)"] = np.zeros(shape)

    corrected = std_daq_live_processing.apply_flat_dark_correction(raw)

    assert isinstance(corrected, np.ndarray)
    assert corrected.shape == shape
    assert np.all(corrected >= 0), "Corrected image should not have negative values"
def test_std_daq_live_processing_apply_flat_dark_correction_with_dark(std_daq_live_processing):
    """Correct an image that contains a non-zero dark contribution.

    The image is a scaled random signal plus a random dark frame. That same
    dark frame is installed as the dark reference, with an all-ones flat.
    The result must be a 100x100 ``ndarray`` with no negative values.
    """
    shape = (100, 100)
    signal = np.random.rand(*shape) * 1000  # Scale to simulate a realistic image
    dark = np.random.rand(*shape) * 100  # Simulate a dark reference
    raw = signal + dark  # The measured image includes the dark contribution

    # Install the flat and dark references directly into the local cache.
    cache = std_daq_live_processing.references
    cache["flat_(100, 100)"] = np.ones(shape)
    cache["dark_(100, 100)"] = dark

    corrected = std_daq_live_processing.apply_flat_dark_correction(raw)

    assert isinstance(corrected, np.ndarray)
    assert corrected.shape == shape
    assert np.all(corrected >= 0), "Corrected image should not have negative values"