173 lines
6.1 KiB
Python
173 lines
6.1 KiB
Python
from unittest import mock
|
|
|
|
import fakeredis
|
|
import h5py
|
|
import numpy as np
|
|
import pytest
|
|
from bec_lib.redis_connector import RedisConnector
|
|
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
|
|
from typeguard import TypeCheckError
|
|
|
|
from tomcat_bec.devices.std_daq.std_daq_live_processing import StdDaqLiveProcessing
|
|
|
|
|
|
def fake_redis_server(host, port):
    """Return an in-memory ``fakeredis.FakeRedis`` client.

    Handed to ``RedisConnector`` as ``redis_cls`` so that tests never talk to a
    real Redis server. ``host`` and ``port`` are ignored; they are presumably
    accepted only to match how ``RedisConnector`` invokes ``redis_cls``.
    """
    return fakeredis.FakeRedis()
|
|
|
|
|
|
@pytest.fixture
def connected_connector():
    """Yield a ``RedisConnector`` backed by an in-memory FakeRedis client.

    The store is flushed before the connector is handed to the test, so every
    test starts from an empty state. ``shutdown()`` always runs on teardown,
    even when the initial flush raises.
    """
    connector = RedisConnector("localhost:1", redis_cls=fake_redis_server)  # type: ignore
    try:
        # Inside the try so a failing flush still reaches shutdown() below.
        connector._redis_conn.flushall()
        yield connector
    finally:
        connector.shutdown()
|
|
|
|
|
|
class MockPSIDeviceBase(PSIDeviceBase):
    """Test double of ``PSIDeviceBase`` that carries an explicit device manager.

    Every argument except ``device_manager`` is forwarded unchanged to
    ``PSIDeviceBase.__init__``. ``device_manager`` is stored on the instance
    after the base initialiser has run.
    """

    def __init__(self, *base_args, device_manager=None, **base_kwargs):
        super().__init__(*base_args, **base_kwargs)
        # Set after the base initialiser, so this value wins over anything the base may assign.
        self.device_manager = device_manager
|
|
|
|
|
|
@pytest.fixture
def mock_device(connected_connector):
    """Yield a ``MockPSIDeviceBase`` named ``"mock_device"``.

    Its device manager is a ``mock.Mock`` whose ``connector`` attribute is the
    FakeRedis-backed connector from the ``connected_connector`` fixture.
    """
    manager = mock.Mock(connector=connected_connector)
    yield MockPSIDeviceBase(name="mock_device", device_manager=manager)
|
|
|
|
|
|
@pytest.fixture
def std_daq_live_processing(mock_device):
    """Yield a ``StdDaqLiveProcessing`` bound to ``mock_device``.

    Its two signal arguments are independent ``mock.Mock`` stand-ins.
    """
    first_signal, second_signal = mock.Mock(), mock.Mock()
    yield StdDaqLiveProcessing(mock_device, first_signal, second_signal)
|
|
|
|
|
|
def test_std_daq_live_processing_set_mode(std_daq_live_processing):
    """``set_mode`` stores a valid mode and rejects unknown or wrongly typed ones."""
    std_daq_live_processing.set_mode("sum")
    assert std_daq_live_processing.get_mode() == "sum"

    # Both an unsupported mode name and a non-string value must fail typeguard's check.
    for invalid_mode in ("mode_that_does_not_exist", 123):
        with pytest.raises(TypeCheckError):
            std_daq_live_processing.set_mode(invalid_mode)
|
|
|
|
|
|
@pytest.fixture(params=["flat", "dark"])
def reference_type(request):
    """Run each dependent test once per reference kind: ``"flat"`` and ``"dark"``."""
    selected = request.param
    return selected
|
|
|
|
|
|
def test_std_daq_live_processing_flat_default(std_daq_live_processing, reference_type):
    """With nothing stored in Redis, the getters fall back to a default reference.

    The Redis lookup is patched to return ``None``. It must be called exactly
    once with the endpoint name for this reference type and shape. The
    fallback is an array of the requested shape: all ones for a flat, all
    zeros for a dark.
    """
    shape = (100, 100)

    with mock.patch.object(
        std_daq_live_processing, "_get_from_redis", return_value=None
    ) as redis_lookup:
        getters = {"flat": std_daq_live_processing.get_flat}
        fetch = getters.get(reference_type, std_daq_live_processing.get_dark)
        result = fetch(shape)

        expected_key = std_daq_live_processing._redis_endpoint_name(
            ref_type=reference_type, shape=shape
        )
        redis_lookup.assert_called_once_with(expected_key)

        assert isinstance(result, np.ndarray)
        assert result.shape == shape

        if reference_type == "flat":
            fill_value, message = 1, "Default should be all ones"
        else:
            fill_value, message = 0, "Default should be all zeros"
        assert np.all(result == fill_value), message
|
|
|
|
|
|
@pytest.mark.parametrize(
    "value",
    [
        # Seeded so a failing run can be reproduced with identical input data.
        np.random.default_rng(0).random((100, 100)),
        np.random.default_rng(1).random((3, 100, 100)),
    ],
)
def test_std_daq_live_processing_fetch(tmp_path, std_daq_live_processing, value, reference_type):
    """Updating a reference from an HDF5 file makes it available locally and in Redis.

    ``value`` (either a 100x100 array or a 3x100x100 stack) is written to a
    temporary HDF5 file and loaded via ``update_reference_with_file``. The
    reference then returned by ``get_flat``/``get_dark`` must be 100x100 and
    must match both the local ``references`` cache and the copy read back
    from Redis.
    """
    file_path = tmp_path / "test_data.h5"
    dataset_name = "tomcat-pco/data"
    shape = (100, 100)

    with h5py.File(file_path, "w") as h5_file:
        h5_file.create_dataset(dataset_name, data=value)

    status = std_daq_live_processing.update_reference_with_file(
        reference_type, file_path, dataset_name
    )
    # NOTE(review): wait() is unbounded; a status that never finishes hangs the test run.
    status.wait()

    get_method = (
        std_daq_live_processing.get_flat
        if reference_type == "flat"
        else std_daq_live_processing.get_dark
    )
    out = get_method(shape)
    assert isinstance(out, np.ndarray)
    assert out.shape == shape

    # Check that the data is cached locally under "<type>_<shape>".
    assert np.array_equal(
        std_daq_live_processing.references[f"{reference_type}_(100, 100)"], out
    ), f"Cached {reference_type} data should match fetched data"

    redis_data = std_daq_live_processing._get_from_redis(
        std_daq_live_processing._redis_endpoint_name(ref_type=reference_type, shape=shape)
    )
    assert isinstance(redis_data, np.ndarray)
    assert redis_data.shape == shape
    assert np.array_equal(redis_data, out), "Redis data should match the locally cached data"
|
|
|
|
|
|
def test_std_daq_live_processing_apply_flat_dark_correction(std_daq_live_processing):
    """Correcting with a unit flat and a zero dark gives a non-negative 100x100 array."""
    shape = (100, 100)
    raw = np.random.rand(*shape)

    # Trivial references: flat of ones, dark of zeros.
    refs = std_daq_live_processing.references
    refs["flat_(100, 100)"] = np.ones(shape)
    refs["dark_(100, 100)"] = np.zeros(shape)

    corrected = std_daq_live_processing.apply_flat_dark_correction(raw)
    assert isinstance(corrected, np.ndarray)
    assert corrected.shape == shape
    assert np.all(corrected >= 0), "Corrected image should not have negative values"
|
|
|
|
|
|
def test_std_daq_live_processing_apply_flat_dark_correction_with_dark(std_daq_live_processing):
    """Removing a known dark offset from an image that contains it stays non-negative."""
    # Random signal scaled to [0, 1000) plus a random dark offset in [0, 100).
    signal = np.random.rand(100, 100) * 1000
    dark = np.random.rand(100, 100) * 100
    image = signal + dark

    refs = std_daq_live_processing.references
    refs["flat_(100, 100)"] = np.ones((100, 100))
    refs["dark_(100, 100)"] = dark

    corrected = std_daq_live_processing.apply_flat_dark_correction(image)
    assert isinstance(corrected, np.ndarray)
    assert corrected.shape == (100, 100)
    assert np.all(corrected >= 0), "Corrected image should not have negative values"
|
|
|
|
|
|
def test_std_daq_live_processing_apply_flat_correction_zero_division(std_daq_live_processing):
    """Identical flat and dark references must yield non-negative, finite pixels.

    Flat and dark are equal everywhere, so a normalisation by (flat - dark)
    would presumably divide by zero. The correction is expected to guard
    against that rather than emit inf/NaN.
    """
    # Image values lie in [10, 1010), well above the dark level of 2.
    image = np.random.rand(100, 100) * 1000 + 10

    # Flat and dark are both constant 2, so their difference is zero everywhere.
    flat = np.ones((100, 100)) * 2
    dark = np.ones((100, 100)) * 2
    std_daq_live_processing.references["flat_(100, 100)"] = flat
    std_daq_live_processing.references["dark_(100, 100)"] = dark

    corrected_image = std_daq_live_processing.apply_flat_dark_correction(image)
    assert isinstance(corrected_image, np.ndarray)
    assert corrected_image.shape == (100, 100)
    assert np.all(corrected_image >= 0), "Corrected image should not have negative values"
    # Every pixel must be finite (isfinite also rejects NaN); np.any(x < inf) would
    # pass as soon as a single pixel was finite.
    assert np.all(np.isfinite(corrected_image)), "Corrected image should not have infinite values"
|