Major refactor of the std DAQ integration for the PCO Edge and GigaFRoST cameras. Adds new live-processing capabilities and cleans up the code for better maintainability.
151 lines
5.2 KiB
Python
151 lines
5.2 KiB
Python
from unittest import mock
|
|
|
|
import fakeredis
|
|
import h5py
|
|
import numpy as np
|
|
import pytest
|
|
from bec_lib.redis_connector import RedisConnector
|
|
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
|
|
from typeguard import TypeCheckError
|
|
|
|
from tomcat_bec.devices.std_daq.std_daq_live_processing import StdDaqLiveProcessing
|
|
|
|
|
|
def fake_redis_server(host, port):
    """Return a fresh ``fakeredis.FakeRedis`` client instead of a real Redis connection.

    ``host`` and ``port`` are accepted (presumably to match how ``RedisConnector``
    invokes its ``redis_cls`` factory) and are ignored.
    """
    return fakeredis.FakeRedis()
|
|
|
|
|
|
@pytest.fixture
def connected_connector():
    """Yield a ``RedisConnector`` backed by ``fake_redis_server`` and shut it down afterwards.

    The address is presumably a placeholder, since ``fake_redis_server`` ignores
    host and port. The store is flushed before the connector is handed to the
    test. The flush happens inside the ``try`` so that ``shutdown()`` still runs
    if flushing fails during setup.
    """
    connector = RedisConnector("localhost:1", redis_cls=fake_redis_server)  # type: ignore
    try:
        # Start every test from an empty store.
        connector._redis_conn.flushall()
        yield connector
    finally:
        connector.shutdown()
|
|
|
|
|
|
class MockPSIDeviceBase(PSIDeviceBase):
    """Test double of ``PSIDeviceBase`` that also accepts a ``device_manager`` keyword.

    The keyword is captured here and not forwarded to the base-class constructor.
    It is stored on the instance once base initialisation has completed.
    """

    def __init__(self, *args, device_manager=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Assigned after the base constructor, so this value is the one that remains.
        self.device_manager = device_manager
|
|
|
|
|
|
@pytest.fixture
def mock_device(connected_connector):
    """Yield a ``MockPSIDeviceBase`` whose mock device manager exposes the fake-Redis connector."""
    # Mock(**kwargs) sets ``connector`` as an attribute on the mock.
    manager = mock.Mock(connector=connected_connector)
    yield MockPSIDeviceBase(name="mock_device", device_manager=manager)
|
|
|
|
|
|
@pytest.fixture
def std_daq_live_processing(mock_device):
    """Yield a ``StdDaqLiveProcessing`` bound to the mock device and two mock signal objects."""
    first_signal, second_signal = mock.Mock(), mock.Mock()
    yield StdDaqLiveProcessing(mock_device, first_signal, second_signal)
|
|
|
|
|
|
def test_std_daq_live_processing_set_mode(std_daq_live_processing):
    """``set_mode`` accepts "sum" and raises ``TypeCheckError`` for "average" and for a non-string."""
    std_daq_live_processing.set_mode("sum")
    assert std_daq_live_processing.get_mode() == "sum"

    # Each rejected value is checked in its own ``pytest.raises`` context.
    for rejected in ("average", 123):
        with pytest.raises(TypeCheckError):
            std_daq_live_processing.set_mode(rejected)
|
|
|
|
|
|
@pytest.fixture(params=["flat", "dark"])
def reference_type(request):
    """Run each dependent test once per reference kind: "flat", then "dark"."""
    selected = request.param
    return selected
|
|
|
|
|
|
def test_std_daq_live_processing_flat_default(std_daq_live_processing, reference_type):
    """When Redis returns nothing, the getter falls back to ones (flat) or zeros (dark).

    The test also checks that exactly one Redis lookup is made, using the
    endpoint name for the requested reference type and shape.
    """
    shape = (100, 100)
    processing = std_daq_live_processing
    with mock.patch.object(processing, "_get_from_redis", return_value=None) as fake_get:
        # Resolves to ``get_flat`` or ``get_dark`` for the parametrized type.
        getter = getattr(processing, f"get_{reference_type}")
        reference = getter(shape)
        expected_endpoint = processing._redis_endpoint_name(ref_type=reference_type, shape=shape)
        fake_get.assert_called_once_with(expected_endpoint)

    assert isinstance(reference, np.ndarray)
    assert reference.shape == shape
    if reference_type == "flat":
        assert np.all(reference == 1), "Default should be all ones"
    else:
        assert np.all(reference == 0), "Default should be all zeros"
|
|
|
|
|
|
@pytest.mark.parametrize(
    "value",
    [
        # Seeded generators keep the parametrized data reproducible across runs.
        np.random.default_rng(0).random((100, 100)),
        np.random.default_rng(1).random((3, 100, 100)),
    ],
)
def test_std_daq_live_processing_fetch(tmp_path, std_daq_live_processing, value, reference_type):
    """Loading a reference from an HDF5 file caches it locally and makes it readable from Redis.

    ``value`` is either a 2D array or a 3D array (presumably a stack of frames).
    In both cases the retrieved reference is expected to have shape (100, 100).
    """
    shape = (100, 100)
    h5_path = tmp_path / "test_data.h5"
    dataset_name = "tomcat-pco/data"

    with h5py.File(h5_path, "w") as f:
        f.create_dataset(dataset_name, data=value)

    status = std_daq_live_processing.update_reference_with_file(
        reference_type, h5_path, dataset_name
    )
    # Wait for the update to finish (it may run asynchronously).
    status.wait()

    get_method = (
        std_daq_live_processing.get_flat
        if reference_type == "flat"
        else std_daq_live_processing.get_dark
    )
    out = get_method(shape)
    assert isinstance(out, np.ndarray)
    assert out.shape == shape

    # Check that the data is cached locally under "<type>_<shape>".
    assert np.array_equal(
        std_daq_live_processing.references[f"{reference_type}_(100, 100)"], out
    ), f"Cached {reference_type} data should match fetched data"

    redis_data = std_daq_live_processing._get_from_redis(
        std_daq_live_processing._redis_endpoint_name(ref_type=reference_type, shape=shape)
    )
    assert isinstance(redis_data, np.ndarray)
    assert redis_data.shape == shape
    assert np.array_equal(redis_data, out), "Redis data should match the locally cached data"
|
|
|
|
|
|
def test_std_daq_live_processing_apply_flat_dark_correction(std_daq_live_processing):
    """With a flat of ones and a dark of zeros, correction yields a non-negative (100, 100) array."""
    raw = np.random.rand(100, 100)

    # Install references: flat of ones, dark of zeros.
    refs = std_daq_live_processing.references
    refs["flat_(100, 100)"] = np.ones((100, 100))
    refs["dark_(100, 100)"] = np.zeros((100, 100))

    result = std_daq_live_processing.apply_flat_dark_correction(raw)
    assert isinstance(result, np.ndarray)
    assert result.shape == (100, 100)
    assert np.all(result >= 0), "Corrected image should not have negative values"
|
|
|
|
|
|
def test_std_daq_live_processing_apply_flat_dark_correction_with_dark(std_daq_live_processing):
    """A non-zero dark reference must not drive the corrected image below zero."""
    # Random signal in [0, 1000) sitting on top of a random dark offset in [0, 100).
    signal = np.random.rand(100, 100) * 1000
    dark = np.random.rand(100, 100) * 100
    image = signal + dark

    # NOTE(review): the flat (all ones) lies below most dark pixels here. If
    # apply_flat_dark_correction divides by (flat - dark), non-negativity would
    # depend on clipping or similar handling. Confirm the intended formula.
    refs = std_daq_live_processing.references
    refs["flat_(100, 100)"] = np.ones((100, 100))
    refs["dark_(100, 100)"] = dark

    result = std_daq_live_processing.apply_flat_dark_correction(image)
    assert isinstance(result, np.ndarray)
    assert result.shape == (100, 100)
    assert np.all(result >= 0), "Corrected image should not have negative values"
|