"""Tests for ``StdDaqLiveProcessing``: mode selection, flat/dark references and correction."""

from unittest import mock

import fakeredis
import h5py
import numpy as np
import pytest
from bec_lib.redis_connector import RedisConnector
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from typeguard import TypeCheckError

from tomcat_bec.devices.std_daq.std_daq_live_processing import StdDaqLiveProcessing


def fake_redis_server(host, port, **kwargs):
    """Return an in-memory ``FakeRedis`` instance in place of a real Redis client.

    Passed to ``RedisConnector`` as ``redis_cls``. ``host``, ``port`` and any keyword
    arguments are accepted only for signature compatibility and are ignored.
    """
    return fakeredis.FakeRedis()


@pytest.fixture
def connected_connector():
    """Yield a FakeRedis-backed ``RedisConnector``, flushed before use and shut down after."""
    connector = RedisConnector("localhost:1", redis_cls=fake_redis_server)  # type: ignore
    connector._redis_conn.flushall()
    try:
        yield connector
    finally:
        connector.shutdown()


class MockPSIDeviceBase(PSIDeviceBase):
    """Minimal ``PSIDeviceBase`` subclass that additionally stores a ``device_manager``."""

    def __init__(self, *args, device_manager=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.device_manager = device_manager


@pytest.fixture
def mock_device(connected_connector):
    """Yield a mock device whose mocked device manager exposes the FakeRedis connector."""
    device_manager = mock.Mock()
    device_manager.connector = connected_connector
    device = MockPSIDeviceBase(name="mock_device", device_manager=device_manager)
    yield device


@pytest.fixture
def std_daq_live_processing(mock_device):
    """Yield a ``StdDaqLiveProcessing`` attached to ``mock_device``.

    The two extra constructor arguments are plain mocks; presumably they stand in for
    signals the processor writes to — confirm against ``StdDaqLiveProcessing.__init__``.
    """
    signal = mock.Mock()
    signal2 = mock.Mock()
    live_processing = StdDaqLiveProcessing(mock_device, signal, signal2)
    yield live_processing


def _reference_getter(live_processing, reference_type):
    """Return ``get_flat`` for ``"flat"`` and ``get_dark`` for anything else."""
    if reference_type == "flat":
        return live_processing.get_flat
    return live_processing.get_dark


def test_std_daq_live_processing_set_mode(std_daq_live_processing):
    """``set_mode`` accepts "sum"; an unsupported string or a non-string raises TypeCheckError."""
    std_daq_live_processing.set_mode("sum")
    assert std_daq_live_processing.get_mode() == "sum"

    with pytest.raises(TypeCheckError):
        std_daq_live_processing.set_mode("average")

    with pytest.raises(TypeCheckError):
        std_daq_live_processing.set_mode(123)


@pytest.fixture(params=["flat", "dark"])
def reference_type(request):
    """Parametrize a test over both reference types, "flat" and "dark"."""
    return request.param


def test_std_daq_live_processing_flat_default(std_daq_live_processing, reference_type):
    """With nothing stored in Redis, flat defaults to all ones and dark to all zeros."""
    with mock.patch.object(
        std_daq_live_processing, "_get_from_redis", return_value=None
    ) as mock_get_from_redis:
        get_method = _reference_getter(std_daq_live_processing, reference_type)
        out = get_method((100, 100))

    # The getter must have consulted Redis exactly once, under the endpoint for this type/shape.
    mock_get_from_redis.assert_called_once_with(
        std_daq_live_processing._redis_endpoint_name(ref_type=reference_type, shape=(100, 100))
    )
    assert isinstance(out, np.ndarray)
    assert out.shape == (100, 100)
    if reference_type == "flat":
        assert np.all(out == 1), "Default should be all ones"
    else:
        assert np.all(out == 0), "Default should be all zeros"


@pytest.mark.parametrize("value", [np.random.rand(100, 100), np.random.rand(3, 100, 100)])
def test_std_daq_live_processing_fetch(tmp_path, std_daq_live_processing, value, reference_type):
    """A reference loaded from HDF5 is returned as (100, 100) and cached locally and in Redis.

    Both a single 2D frame and a stack of 3 frames are written to the file; in both
    cases the getter is expected to yield a (100, 100) array.
    """
    file_path = tmp_path / "test_data.h5"
    with h5py.File(file_path, "w") as f:
        f.create_dataset("tomcat-pco/data", data=value)

    status = std_daq_live_processing.update_reference_with_file(
        reference_type, file_path, "tomcat-pco/data"
    )
    status.wait()

    get_method = _reference_getter(std_daq_live_processing, reference_type)
    out = get_method((100, 100))
    assert isinstance(out, np.ndarray)
    assert out.shape == (100, 100)

    # Check that the data is cached locally under "<type>_<shape>".
    assert np.array_equal(
        std_daq_live_processing.references[f"{reference_type}_(100, 100)"], out
    ), f"Cached {reference_type} data should match fetched data"

    # Check that the same data was published to Redis.
    redis_data = std_daq_live_processing._get_from_redis(
        std_daq_live_processing._redis_endpoint_name(ref_type=reference_type, shape=(100, 100))
    )
    assert isinstance(redis_data, np.ndarray)
    assert redis_data.shape == (100, 100)
    assert np.array_equal(redis_data, out), "Redis data should match the locally cached data"


def test_std_daq_live_processing_apply_flat_dark_correction(std_daq_live_processing):
    """Correction with a trivial flat (ones) and dark (zeros) keeps shape and is non-negative."""
    image = np.random.rand(100, 100)

    std_daq_live_processing.references["flat_(100, 100)"] = np.ones((100, 100))
    std_daq_live_processing.references["dark_(100, 100)"] = np.zeros((100, 100))

    corrected_image = std_daq_live_processing.apply_flat_dark_correction(image)

    assert isinstance(corrected_image, np.ndarray)
    assert corrected_image.shape == (100, 100)
    assert np.all(corrected_image >= 0), "Corrected image should not have negative values"


def test_std_daq_live_processing_apply_flat_dark_correction_with_dark(std_daq_live_processing):
    """Correction stays non-negative when the image contains the dark offset."""
    image = np.random.rand(100, 100) * 1000  # Scale to simulate a realistic image
    dark = np.random.rand(100, 100) * 100  # Simulate a dark reference
    image += dark  # Every pixel is >= its dark value, as in a real acquisition

    std_daq_live_processing.references["flat_(100, 100)"] = np.ones((100, 100))
    std_daq_live_processing.references["dark_(100, 100)"] = dark

    corrected_image = std_daq_live_processing.apply_flat_dark_correction(image)

    assert isinstance(corrected_image, np.ndarray)
    assert corrected_image.shape == (100, 100)
    assert np.all(corrected_image >= 0), "Corrected image should not have negative values"


def test_std_daq_live_processing_apply_flat_correction_zero_division(std_daq_live_processing):
    """Correction stays finite and non-negative when flat equals dark (degenerate references)."""
    image = np.random.rand(100, 100) * 1000 + 10  # Strictly above the dark level below

    # Flat and dark are deliberately identical (all twos), so any flat-minus-dark
    # normalization would have a zero denominator at every pixel.
    flat = np.ones((100, 100)) * 2
    std_daq_live_processing.references["flat_(100, 100)"] = flat
    dark = np.ones((100, 100)) * 2
    std_daq_live_processing.references["dark_(100, 100)"] = dark

    corrected_image = std_daq_live_processing.apply_flat_dark_correction(image)

    assert isinstance(corrected_image, np.ndarray)
    assert corrected_image.shape == (100, 100)
    assert np.all(corrected_image >= 0), "Corrected image should not have negative values"
    # Require every pixel to be finite (no inf/NaN), not merely at least one.
    assert np.all(np.isfinite(corrected_image)), "Corrected image should not have infinite values"