# tomcat_bec/tests/tests_devices/test_std_daq_live_processing.py
# (file-viewer header: 173 lines, 6.1 KiB, Python)

from unittest import mock
import fakeredis
import h5py
import numpy as np
import pytest
from bec_lib.redis_connector import RedisConnector
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from typeguard import TypeCheckError
from tomcat_bec.devices.std_daq.std_daq_live_processing import StdDaqLiveProcessing
def fake_redis_server(host, port):
    """Return a fresh in-memory ``fakeredis.FakeRedis`` client.

    Used as the ``redis_cls`` factory for ``RedisConnector`` in the fixtures
    below. ``host`` and ``port`` are accepted for call compatibility and are
    ignored.
    """
    return fakeredis.FakeRedis()
@pytest.fixture
def connected_connector():
    """Yield a ``RedisConnector`` backed by fakeredis, starting from an empty store.

    The connector is always shut down after the test, even when the test fails.
    """
    conn = RedisConnector("localhost:1", redis_cls=fake_redis_server)  # type: ignore
    # Start from an empty store so no keys leak in from earlier tests.
    conn._redis_conn.flushall()
    try:
        yield conn
    finally:
        conn.shutdown()
class MockPSIDeviceBase(PSIDeviceBase):
    """``PSIDeviceBase`` subclass for tests that accepts an injected device manager.

    Every argument except ``device_manager`` is forwarded unchanged to the base
    class initializer.
    """

    def __init__(self, *pos_args, device_manager=None, **extra_kwargs):
        super().__init__(*pos_args, **extra_kwargs)
        # Assigned after the base initializer runs, so this value takes
        # precedence over anything the base class may have set.
        self.device_manager = device_manager
@pytest.fixture
def mock_device(connected_connector):
    """Provide a ``MockPSIDeviceBase`` whose mock device manager exposes the fake connector."""
    manager = mock.Mock()
    manager.connector = connected_connector
    return MockPSIDeviceBase(name="mock_device", device_manager=manager)
@pytest.fixture
def std_daq_live_processing(mock_device):
    """Provide a ``StdDaqLiveProcessing`` instance attached to ``mock_device``.

    Both signal arguments are plain ``mock.Mock`` objects; no test in this
    module inspects them.
    """
    return StdDaqLiveProcessing(mock_device, mock.Mock(), mock.Mock())
def test_std_daq_live_processing_set_mode(std_daq_live_processing):
    """A valid mode round-trips through ``get_mode``; invalid values fail type checking."""
    processor = std_daq_live_processing
    processor.set_mode("sum")
    assert processor.get_mode() == "sum"

    # The runtime type checker rejects both an unknown string and a non-string.
    for bad_mode in ("mode_that_does_not_exist", 123):
        with pytest.raises(TypeCheckError):
            processor.set_mode(bad_mode)
@pytest.fixture(params=["flat", "dark"])
def reference_type(request):
    """Run each dependent test once per reference kind: ``"flat"``, then ``"dark"``."""
    kind = request.param
    return kind
def test_std_daq_live_processing_flat_default(std_daq_live_processing, reference_type):
    """When Redis holds no reference, the getters fall back to constant defaults.

    Flats default to all ones and darks to all zeros, at the requested shape.
    The Redis lookup must be made exactly once, with the endpoint name for this
    reference type and shape.
    """
    processor = std_daq_live_processing
    shape = (100, 100)
    with mock.patch.object(processor, "_get_from_redis", return_value=None) as fake_fetch:
        if reference_type == "flat":
            getter = processor.get_flat
        else:
            getter = processor.get_dark
        result = getter(shape)
        fake_fetch.assert_called_once_with(
            processor._redis_endpoint_name(ref_type=reference_type, shape=shape)
        )

    assert isinstance(result, np.ndarray)
    assert result.shape == shape
    if reference_type == "flat":
        expected_fill, message = 1, "Default should be all ones"
    else:
        expected_fill, message = 0, "Default should be all zeros"
    assert np.all(result == expected_fill), message
@pytest.mark.parametrize(
    "value",
    [
        # Seeded generators keep the parametrized data reproducible between runs.
        np.random.default_rng(0).random((100, 100)),
        np.random.default_rng(1).random((3, 100, 100)),
    ],
)
def test_std_daq_live_processing_fetch(tmp_path, std_daq_live_processing, value, reference_type):
    """Loading a reference from an HDF5 file caches it locally and stores it in Redis.

    Covers a single 2D frame and a 3D stack of frames. In both cases the
    reference returned for a (100, 100) request must be a 2D array of that
    shape. It must match both the locally cached copy and the copy read back
    from Redis.
    """
    shape = (100, 100)
    file_path = tmp_path / "test_data.h5"
    dataset_name = "tomcat-pco/data"
    with h5py.File(file_path, "w") as f:
        f.create_dataset(dataset_name, data=value)

    status = std_daq_live_processing.update_reference_with_file(
        reference_type, file_path, dataset_name
    )
    status.wait()

    get_method = (
        std_daq_live_processing.get_flat
        if reference_type == "flat"
        else std_daq_live_processing.get_dark
    )
    out = get_method(shape)
    assert isinstance(out, np.ndarray)
    assert out.shape == shape

    # The data must be cached locally under "<type>_<shape>".
    assert np.array_equal(
        std_daq_live_processing.references[f"{reference_type}_(100, 100)"], out
    ), f"Cached {reference_type} data should match fetched data"

    # ... and must also be retrievable from Redis.
    redis_data = std_daq_live_processing._get_from_redis(
        std_daq_live_processing._redis_endpoint_name(ref_type=reference_type, shape=shape)
    )
    assert isinstance(redis_data, np.ndarray)
    assert redis_data.shape == shape
    assert np.array_equal(redis_data, out), "Redis data should match the locally cached data"
def test_std_daq_live_processing_apply_flat_dark_correction(std_daq_live_processing):
    """Correction with a flat of ones and a dark of zeros keeps shape and non-negativity."""
    shape = (100, 100)
    raw = np.random.rand(*shape)

    # Install references: flat of ones, dark of zeros.
    for key, reference in (
        ("flat_(100, 100)", np.ones(shape)),
        ("dark_(100, 100)", np.zeros(shape)),
    ):
        std_daq_live_processing.references[key] = reference

    result = std_daq_live_processing.apply_flat_dark_correction(raw)

    assert isinstance(result, np.ndarray)
    assert result.shape == shape
    assert np.all(result >= 0), "Corrected image should not have negative values"
def test_std_daq_live_processing_apply_flat_dark_correction_with_dark(std_daq_live_processing):
    """Correction of an image that contains a dark offset keeps shape and non-negativity.

    The raw image is a scaled random signal plus the dark reference itself, so
    every raw pixel is at or above the corresponding dark pixel.
    """
    shape = (100, 100)
    signal = np.random.rand(*shape) * 1000  # scaled to resemble a realistic image
    dark_frame = np.random.rand(*shape) * 100  # simulated dark reference
    raw = signal
    raw += dark_frame  # in place, matching the original accumulation

    std_daq_live_processing.references["flat_(100, 100)"] = np.ones(shape)
    std_daq_live_processing.references["dark_(100, 100)"] = dark_frame

    result = std_daq_live_processing.apply_flat_dark_correction(raw)

    assert isinstance(result, np.ndarray)
    assert result.shape == shape
    assert np.all(result >= 0), "Corrected image should not have negative values"
def test_std_daq_live_processing_apply_flat_correction_zero_division(std_daq_live_processing):
    """Identical flat and dark references must not produce non-finite pixels.

    Flat and dark are both 2 everywhere, so ``flat - dark`` is zero at every
    pixel. This is the zero-division case this test targets, presumably because
    the correction normalizes by that difference. The result must be finite and
    non-negative at every pixel.
    """
    # Offset by 10 so every raw pixel lies above the dark level of 2.
    image = np.random.rand(100, 100) * 1000 + 10

    # flat == dark, so flat - dark is zero at every pixel.
    flat = np.ones((100, 100)) * 2
    std_daq_live_processing.references["flat_(100, 100)"] = flat
    dark = np.ones((100, 100)) * 2
    std_daq_live_processing.references["dark_(100, 100)"] = dark

    corrected_image = std_daq_live_processing.apply_flat_dark_correction(image)

    assert isinstance(corrected_image, np.ndarray)
    assert corrected_image.shape == (100, 100)
    # Every pixel must be finite. The previous check, np.any(x < np.inf), passed
    # if even one pixel was finite and always ignored NaN. It runs before the
    # sign check so that a NaN (which fails any comparison) is reported as a
    # finiteness problem, not as a negative value.
    assert np.all(np.isfinite(corrected_image)), "Corrected image should only contain finite values"
    assert np.all(corrected_image >= 0), "Corrected image should not have negative values"