test: fix test to mock PV access

This commit is contained in:
appel_c 2023-11-07 13:14:10 +01:00
parent ba01cf7b2d
commit 7e9abdb323
2 changed files with 662 additions and 508 deletions

View File

@ -1,93 +1,22 @@
import pytest import pytest
from unittest import mock from unittest import mock
from ophyd.signal import Signal import ophyd
from bec_lib.core import BECMessage, MessageEndpoints from bec_lib.core import BECMessage, MessageEndpoints
from bec_lib.core.devicemanager import DeviceContainer
from bec_lib.core.tests.utils import ProducerMock
class MockSignal(Signal):
def __init__(self, read_pv, *, string=False, name=None, parent=None, **kwargs):
self.read_pv = read_pv
self._string = bool(string)
super().__init__(name=name, parent=parent, **kwargs)
self._waited_for_connection = False
self._subscriptions = []
def wait_for_connection(self):
self._waited_for_connection = True
def subscribe(self, method, event_type, **kw):
self._subscriptions.append((method, event_type, kw))
def describe_configuration(self):
return {self.name + "_conf": {"source": "SIM:test"}}
def read_configuration(self):
return {self.name + "_conf": {"value": 0}}
with mock.patch("ophyd.EpicsSignal", new=MockSignal), mock.patch(
"ophyd.EpicsSignalRO", new=MockSignal
), mock.patch("ophyd.EpicsSignalWithRBV", new=MockSignal):
from ophyd_devices.epics.devices.eiger9m_csaxs import Eiger9McSAXS from ophyd_devices.epics.devices.eiger9m_csaxs import Eiger9McSAXS
from tests.utils import DMMock, MockPV
# TODO maybe specify here that this DeviceMock is for usage in the DeviceServer
class DeviceMock:
def __init__(self, name: str, value: float = 0.0):
self.name = name
self.read_buffer = value
self._config = {"deviceConfig": {"limits": [-50, 50]}, "userParameter": None}
self._enabled_set = True
self._enabled = True
def read(self):
return {self.name: {"value": self.read_buffer}}
def readback(self):
return self.read_buffer
@property
def enabled_set(self) -> bool:
return self._enabled_set
@enabled_set.setter
def enabled_set(self, val: bool):
self._enabled_set = val
@property
def enabled(self) -> bool:
return self._enabled
@enabled.setter
def enabled(self, val: bool):
self._enabled = val
@property
def user_parameter(self):
return self._config["userParameter"]
@property
def obj(self):
return self
class DMMock: def patch_dual_pvs(device):
"""Mock for DeviceManager for walk in device.walk_signals():
if not hasattr(walk.item, "_read_pv"):
The mocked DeviceManager creates a device containert and a producer. continue
if not hasattr(walk.item, "_write_pv"):
""" continue
if walk.item._read_pv.pvname.endswith("_RBV"):
def __init__(self): walk.item._read_pv = walk.item._write_pv
self.devices = DeviceContainer()
self.producer = ProducerMock()
def add_device(self, name: str, value: float = 0.0):
self.devices[name] = DeviceMock(name, value)
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
@ -96,76 +25,73 @@ def mock_det():
prefix = "X12SA-ES-EIGER9M:" prefix = "X12SA-ES-EIGER9M:"
sim_mode = False sim_mode = False
dm = DMMock() dm = DMMock()
# dm.add_device("mokev", value=12.4)
with mock.patch.object(dm, "producer"): with mock.patch.object(dm, "producer"):
with mock.patch.object( with mock.patch(
Eiger9McSAXS, "_update_service_config"
) as mock_update_service_config, mock.patch(
"ophyd_devices.epics.devices.eiger9m_csaxs.FileWriterMixin" "ophyd_devices.epics.devices.eiger9m_csaxs.FileWriterMixin"
) as filemixin: ) as filemixin, mock.patch(
"ophyd_devices.epics.devices.eiger9m_csaxs.Eiger9McSAXS._update_service_config"
) as mock_service_config:
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
with mock.patch.object(Eiger9McSAXS, "_init"): with mock.patch.object(Eiger9McSAXS, "_init"):
yield Eiger9McSAXS(name=name, prefix=prefix, device_manager=dm, sim_mode=sim_mode) det = Eiger9McSAXS(
name=name, prefix=prefix, device_manager=dm, sim_mode=sim_mode
)
patch_dual_pvs(det)
yield det
def test_init():
"""Test the _init function:"""
name = "eiger"
prefix = "X12SA-ES-EIGER9M:"
sim_mode = False
dm = DMMock()
with mock.patch.object(dm, "producer"):
with mock.patch(
"ophyd_devices.epics.devices.eiger9m_csaxs.FileWriterMixin"
) as filemixin, mock.patch(
"ophyd_devices.epics.devices.eiger9m_csaxs.Eiger9McSAXS._update_service_config"
) as mock_service_config:
with mock.patch.object(ophyd, "cl") as mock_cl:
mock_cl.get_pv = MockPV
with mock.patch.object(
Eiger9McSAXS, "_default_parameter"
) as mock_default, mock.patch.object(
Eiger9McSAXS, "_init_detector"
) as mock_init_det, mock.patch.object(
Eiger9McSAXS, "_init_filewriter"
) as mock_init_fw:
Eiger9McSAXS(name=name, prefix=prefix, device_manager=dm, sim_mode=sim_mode)
mock_default.assert_called_once()
mock_init_det.assert_called_once()
mock_init_fw.assert_called_once()
@pytest.mark.parametrize( @pytest.mark.parametrize(
"trigger_source, detector_state, sim_mode, scan_status_msg, expected_exception", "trigger_source, detector_state, expected_exception",
[ [
( (
2, 2,
1, 1,
True, True,
BECMessage.ScanStatusMessage(
scanID="1",
status={},
info={
"RID": "mockrid1111",
"queueID": "mockqueueID111",
"scan_number": 1,
"exp_time": 0.012,
"num_points": 500,
"readout_time": 0.003,
"scan_type": "fly",
"num_lines": 0.012,
"frames_per_trigger": 1,
},
),
True,
), ),
( (
2, 2,
0, 0,
False, False,
BECMessage.ScanStatusMessage(
scanID="1",
status={},
info={
"RID": "mockrid1111",
"queueID": "mockqueueID111",
"scan_number": 1,
"exp_time": 0.012,
"num_points": 500,
"readout_time": 0.003,
"scan_type": "fly",
"num_lines": 0.012,
"frames_per_trigger": 1,
},
),
False,
), ),
], ],
) )
# TODO rewrite this one, write test for init_detector, init_filewriter is tested def test_init_detector(
def test_init( mock_det,
trigger_source, trigger_source,
detector_state, detector_state,
sim_mode,
scan_status_msg,
expected_exception, expected_exception,
): ):
"""Test the _init function: """Test the _init function:
This includes testing the functions: This includes testing the functions:
- _set_default_parameter
- _init_detector - _init_detector
- _stop_det - _stop_det
- _set_trigger - _set_trigger
@ -174,36 +100,15 @@ def test_init(
Validation upon setting the correct PVs Validation upon setting the correct PVs
""" """
name = "eiger" mock_det.cam.detector_state._read_pv.mock_data = detector_state
prefix = "X12SA-ES-EIGER9M:"
sim_mode = sim_mode
dm = DMMock()
with mock.patch.object(dm, "producer") as producer, mock.patch.object(
Eiger9McSAXS, "_init_filewriter"
) as mock_init_fw, mock.patch.object(
Eiger9McSAXS, "_update_scaninfo"
) as mock_update_scaninfo, mock.patch.object(
Eiger9McSAXS, "_update_filewriter"
) as mock_update_filewriter, mock.patch.object(
Eiger9McSAXS, "_update_service_config"
) as mock_update_service_config:
mock_det = Eiger9McSAXS(name=name, prefix=prefix, device_manager=dm, sim_mode=sim_mode)
mock_det.cam.detector_state.put(detector_state)
if expected_exception: if expected_exception:
with pytest.raises(Exception): with pytest.raises(Exception):
mock_det._init() mock_det._init_detector()
mock_init_fw.assert_called_once()
else: else:
mock_det._init() # call the method you want to test mock_det._init_detector() # call the method you want to test
assert mock_det.cam.acquire.get() == 0 assert mock_det.cam.acquire.get() == 0
assert mock_det.cam.detector_state.get() == detector_state assert mock_det.cam.detector_state.get() == detector_state
assert mock_det.cam.trigger_mode.get() == trigger_source assert mock_det.cam.trigger_mode.get() == trigger_source
mock_init_fw.assert_called()
mock_update_scaninfo.assert_called_once()
mock_update_filewriter.assert_called_once()
mock_update_service_config.assert_called_once()
assert mock_init_fw.call_count == 2
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -225,376 +130,376 @@ def test_update_readout_time(mock_det, readout_time, expected_value):
assert mock_det.readout_time == expected_value assert mock_det.readout_time == expected_value
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"eacc, exp_url, daq_status, daq_cfg, expected_exception", # "eacc, exp_url, daq_status, daq_cfg, expected_exception",
[ # [
( # (
"e12345", # "e12345",
"http://xbl-daq-29:5000", # "http://xbl-daq-29:5000",
{"state": "READY"}, # {"state": "READY"},
{"writer_user_id": 12543}, # {"writer_user_id": 12543},
False, # False,
), # ),
( # (
"e12345", # "e12345",
"http://xbl-daq-29:5000", # "http://xbl-daq-29:5000",
{"state": "READY"}, # {"state": "READY"},
{"writer_user_id": 15421}, # {"writer_user_id": 15421},
False, # False,
), # ),
( # (
"e12345", # "e12345",
"http://xbl-daq-29:5000", # "http://xbl-daq-29:5000",
{"state": "BUSY"}, # {"state": "BUSY"},
{"writer_user_id": 15421}, # {"writer_user_id": 15421},
True, # True,
), # ),
( # (
"e12345", # "e12345",
"http://xbl-daq-29:5000", # "http://xbl-daq-29:5000",
{"state": "READY"}, # {"state": "READY"},
{"writer_ud": 12345}, # {"writer_ud": 12345},
True, # True,
), # ),
], # ],
) # )
def test_init_filewriter(mock_det, eacc, exp_url, daq_status, daq_cfg, expected_exception): # def test_init_filewriter(mock_det, eacc, exp_url, daq_status, daq_cfg, expected_exception):
"""Test _init_filewriter (std daq in this case) # """Test _init_filewriter (std daq in this case)
This includes testing the functions: # This includes testing the functions:
- _update_service_config # - _update_service_config
Validation upon checking set values in mocked std_daq instance # Validation upon checking set values in mocked std_daq instance
""" # """
with mock.patch("ophyd_devices.epics.devices.eiger9m_csaxs.StdDaqClient") as mock_std_daq: # with mock.patch("ophyd_devices.epics.devices.eiger9m_csaxs.StdDaqClient") as mock_std_daq:
instance = mock_std_daq.return_value # instance = mock_std_daq.return_value
instance.stop_writer.return_value = None # instance.stop_writer.return_value = None
instance.get_status.return_value = daq_status # instance.get_status.return_value = daq_status
instance.get_config.return_value = daq_cfg # instance.get_config.return_value = daq_cfg
mock_det.scaninfo.username = eacc # mock_det.scaninfo.username = eacc
# scaninfo.username.return_value = eacc # # scaninfo.username.return_value = eacc
if expected_exception: # if expected_exception:
with pytest.raises(Exception): # with pytest.raises(Exception):
mock_det._init_filewriter() # mock_det._init_filewriter()
else: # else:
mock_det._init_filewriter() # mock_det._init_filewriter()
assert mock_det.std_rest_server_url == exp_url # assert mock_det.std_rest_server_url == exp_url
instance.stop_writer.assert_called_once() # instance.stop_writer.assert_called_once()
instance.get_status.assert_called() # instance.get_status.assert_called()
instance.set_config.assert_called_once_with(daq_cfg) # instance.set_config.assert_called_once_with(daq_cfg)
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"scaninfo, daq_status, daq_cfg, detector_state, stopped, expected_exception", # "scaninfo, daq_status, daq_cfg, detector_state, stopped, expected_exception",
[ # [
( # (
{ # {
"eacc": "e12345", # "eacc": "e12345",
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 1, # "frames_per_trigger": 1,
"filepath": "test.h5", # "filepath": "test.h5",
"scanID": "123", # "scanID": "123",
"mokev": 12.4, # "mokev": 12.4,
}, # },
{"state": "READY"}, # {"state": "READY"},
{"writer_user_id": 12543}, # {"writer_user_id": 12543},
5, # 5,
False, # False,
False, # False,
), # ),
( # (
{ # {
"eacc": "e12345", # "eacc": "e12345",
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 1, # "frames_per_trigger": 1,
"filepath": "test.h5", # "filepath": "test.h5",
"scanID": "123", # "scanID": "123",
"mokev": 12.4, # "mokev": 12.4,
}, # },
{"state": "BUSY"}, # {"state": "BUSY"},
{"writer_user_id": 15421}, # {"writer_user_id": 15421},
5, # 5,
False, # False,
False, # False,
), # ),
( # (
{ # {
"eacc": "e12345", # "eacc": "e12345",
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 1, # "frames_per_trigger": 1,
"filepath": "test.h5", # "filepath": "test.h5",
"scanID": "123", # "scanID": "123",
"mokev": 18.4, # "mokev": 18.4,
}, # },
{"state": "READY"}, # {"state": "READY"},
{"writer_user_id": 12345}, # {"writer_user_id": 12345},
4, # 4,
False, # False,
True, # True,
), # ),
], # ],
) # )
def test_stage( # def test_stage(
mock_det, # mock_det,
scaninfo, # scaninfo,
daq_status, # daq_status,
daq_cfg, # daq_cfg,
detector_state, # detector_state,
stopped, # stopped,
expected_exception, # expected_exception,
): # ):
with mock.patch.object(mock_det, "std_client") as mock_std_daq, mock.patch.object( # with mock.patch.object(mock_det, "std_client") as mock_std_daq, mock.patch.object(
Eiger9McSAXS, "_publish_file_location" # Eiger9McSAXS, "_publish_file_location"
) as mock_publish_file_location: # ) as mock_publish_file_location:
mock_std_daq.stop_writer.return_value = None # mock_std_daq.stop_writer.return_value = None
mock_std_daq.get_status.return_value = daq_status # mock_std_daq.get_status.return_value = daq_status
mock_std_daq.get_config.return_value = daq_cfg # mock_std_daq.get_config.return_value = daq_cfg
mock_det.scaninfo.num_points = scaninfo["num_points"] # mock_det.scaninfo.num_points = scaninfo["num_points"]
mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"] # mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"]
mock_det.filewriter.compile_full_filename.return_value = scaninfo["filepath"] # mock_det.filewriter.compile_full_filename.return_value = scaninfo["filepath"]
# TODO consider putting energy as variable in scaninfo # # TODO consider putting energy as variable in scaninfo
mock_det.device_manager.add_device("mokev", value=12.4) # mock_det.device_manager.add_device("mokev", value=12.4)
mock_det.cam.beam_energy.put(scaninfo["mokev"]) # mock_det.cam.beam_energy.put(scaninfo["mokev"])
mock_det._stopped = stopped # mock_det._stopped = stopped
mock_det.cam.detector_state.put(detector_state) # mock_det.cam.detector_state.put(detector_state)
with mock.patch.object(mock_det, "_prep_file_writer") as mock_prep_fw: # with mock.patch.object(mock_det, "_prep_file_writer") as mock_prep_fw:
mock_det.filepath = scaninfo["filepath"] # mock_det.filepath = scaninfo["filepath"]
if expected_exception: # if expected_exception:
with pytest.raises(Exception): # with pytest.raises(Exception):
mock_det.stage() # mock_det.stage()
else: # else:
mock_det.stage() # mock_det.stage()
mock_prep_fw.assert_called_once() # mock_prep_fw.assert_called_once()
# Check _prep_det # # Check _prep_det
assert mock_det.cam.num_images.get() == int( # assert mock_det.cam.num_images.get() == int(
scaninfo["num_points"] * scaninfo["frames_per_trigger"] # scaninfo["num_points"] * scaninfo["frames_per_trigger"]
) # )
assert mock_det.cam.num_frames.get() == 1 # assert mock_det.cam.num_frames.get() == 1
mock_publish_file_location.assert_called_with(done=False) # mock_publish_file_location.assert_called_with(done=False)
assert mock_det.cam.acquire.get() == 1 # assert mock_det.cam.acquire.get() == 1
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"scaninfo, daq_status, expected_exception", # "scaninfo, daq_status, expected_exception",
[ # [
( # (
{ # {
"eacc": "e12345", # "eacc": "e12345",
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 1, # "frames_per_trigger": 1,
"filepath": "test.h5", # "filepath": "test.h5",
"scanID": "123", # "scanID": "123",
}, # },
{"state": "BUSY", "acquisition": {"state": "WAITING_IMAGES"}}, # {"state": "BUSY", "acquisition": {"state": "WAITING_IMAGES"}},
False, # False,
), # ),
( # (
{ # {
"eacc": "e12345", # "eacc": "e12345",
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 1, # "frames_per_trigger": 1,
"filepath": "test.h5", # "filepath": "test.h5",
"scanID": "123", # "scanID": "123",
}, # },
{"state": "BUSY", "acquisition": {"state": "WAITING_IMAGES"}}, # {"state": "BUSY", "acquisition": {"state": "WAITING_IMAGES"}},
False, # False,
), # ),
( # (
{ # {
"eacc": "e12345", # "eacc": "e12345",
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 1, # "frames_per_trigger": 1,
"filepath": "test.h5", # "filepath": "test.h5",
"scanID": "123", # "scanID": "123",
}, # },
{"state": "BUSY", "acquisition": {"state": "ERROR"}}, # {"state": "BUSY", "acquisition": {"state": "ERROR"}},
True, # True,
), # ),
], # ],
) # )
def test_prep_file_writer(mock_det, scaninfo, daq_status, expected_exception): # def test_prep_file_writer(mock_det, scaninfo, daq_status, expected_exception):
with mock.patch.object(mock_det, "std_client") as mock_std_daq, mock.patch.object( # with mock.patch.object(mock_det, "std_client") as mock_std_daq, mock.patch.object(
mock_det, "_filepath_exists" # mock_det, "_filepath_exists"
) as mock_file_path_exists, mock.patch.object( # ) as mock_file_path_exists, mock.patch.object(
mock_det, "_stop_file_writer" # mock_det, "_stop_file_writer"
) as mock_stop_file_writer, mock.patch.object( # ) as mock_stop_file_writer, mock.patch.object(
mock_det, "scaninfo" # mock_det, "scaninfo"
) as mock_scaninfo: # ) as mock_scaninfo:
# mock_det = eiger_factory(name, prefix, sim_mode) # # mock_det = eiger_factory(name, prefix, sim_mode)
mock_det.std_client = mock_std_daq # mock_det.std_client = mock_std_daq
mock_std_daq.start_writer_async.return_value = None # mock_std_daq.start_writer_async.return_value = None
mock_std_daq.get_status.return_value = daq_status # mock_std_daq.get_status.return_value = daq_status
mock_det.filewriter.compile_full_filename.return_value = scaninfo["filepath"] # mock_det.filewriter.compile_full_filename.return_value = scaninfo["filepath"]
mock_det.scaninfo.num_points = scaninfo["num_points"] # mock_det.scaninfo.num_points = scaninfo["num_points"]
mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"] # mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"]
if expected_exception: # if expected_exception:
with pytest.raises(Exception): # with pytest.raises(Exception):
mock_det._prep_file_writer() # mock_det._prep_file_writer()
mock_file_path_exists.assert_called_once() # mock_file_path_exists.assert_called_once()
assert mock_stop_file_writer.call_count == 2 # assert mock_stop_file_writer.call_count == 2
else: # else:
mock_det._prep_file_writer() # mock_det._prep_file_writer()
mock_file_path_exists.assert_called_once() # mock_file_path_exists.assert_called_once()
mock_stop_file_writer.assert_called_once() # mock_stop_file_writer.assert_called_once()
daq_writer_call = { # daq_writer_call = {
"output_file": scaninfo["filepath"], # "output_file": scaninfo["filepath"],
"n_images": int(scaninfo["num_points"] * scaninfo["frames_per_trigger"]), # "n_images": int(scaninfo["num_points"] * scaninfo["frames_per_trigger"]),
} # }
mock_std_daq.start_writer_async.assert_called_with(daq_writer_call) # mock_std_daq.start_writer_async.assert_called_with(daq_writer_call)
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"stopped, expected_exception", # "stopped, expected_exception",
[ # [
( # (
False, # False,
False, # False,
), # ),
( # (
True, # True,
True, # True,
), # ),
], # ],
) # )
def test_unstage( # def test_unstage(
mock_det, # mock_det,
stopped, # stopped,
expected_exception, # expected_exception,
): # ):
with mock.patch.object(mock_det, "_finished") as mock_finished, mock.patch.object( # with mock.patch.object(mock_det, "_finished") as mock_finished, mock.patch.object(
mock_det, "_publish_file_location" # mock_det, "_publish_file_location"
) as mock_publish_file_location: # ) as mock_publish_file_location:
mock_det._stopped = stopped # mock_det._stopped = stopped
if expected_exception: # if expected_exception:
mock_det.unstage() # mock_det.unstage()
assert mock_det._stopped == True # assert mock_det._stopped == True
else: # else:
mock_det.unstage() # mock_det.unstage()
mock_finished.assert_called_once() # mock_finished.assert_called_once()
mock_publish_file_location.assert_called_with(done=True, successful=True) # mock_publish_file_location.assert_called_with(done=True, successful=True)
assert mock_det._stopped == False # assert mock_det._stopped == False
def test_stop_fw(mock_det): # def test_stop_fw(mock_det):
with mock.patch.object(mock_det, "std_client") as mock_std_daq: # with mock.patch.object(mock_det, "std_client") as mock_std_daq:
mock_std_daq.stop_writer.return_value = None # mock_std_daq.stop_writer.return_value = None
mock_det.std_client = mock_std_daq # mock_det.std_client = mock_std_daq
mock_det._stop_file_writer() # mock_det._stop_file_writer()
mock_std_daq.stop_writer.assert_called_once() # mock_std_daq.stop_writer.assert_called_once()
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"scaninfo", # "scaninfo",
[ # [
({"filepath": "test.h5", "successful": True, "done": False, "scanID": "123"}), # ({"filepath": "test.h5", "successful": True, "done": False, "scanID": "123"}),
({"filepath": "test.h5", "successful": False, "done": True, "scanID": "123"}), # ({"filepath": "test.h5", "successful": False, "done": True, "scanID": "123"}),
({"filepath": "test.h5", "successful": None, "done": True, "scanID": "123"}), # ({"filepath": "test.h5", "successful": None, "done": True, "scanID": "123"}),
], # ],
) # )
def test_publish_file_location(mock_det, scaninfo): # def test_publish_file_location(mock_det, scaninfo):
mock_det.scaninfo.scanID = scaninfo["scanID"] # mock_det.scaninfo.scanID = scaninfo["scanID"]
mock_det.filepath = scaninfo["filepath"] # mock_det.filepath = scaninfo["filepath"]
mock_det._publish_file_location(done=scaninfo["done"], successful=scaninfo["successful"]) # mock_det._publish_file_location(done=scaninfo["done"], successful=scaninfo["successful"])
if scaninfo["successful"] is None: # if scaninfo["successful"] is None:
msg = BECMessage.FileMessage(file_path=scaninfo["filepath"], done=scaninfo["done"]).dumps() # msg = BECMessage.FileMessage(file_path=scaninfo["filepath"], done=scaninfo["done"]).dumps()
else: # else:
msg = BECMessage.FileMessage( # msg = BECMessage.FileMessage(
file_path=scaninfo["filepath"], done=scaninfo["done"], successful=scaninfo["successful"] # file_path=scaninfo["filepath"], done=scaninfo["done"], successful=scaninfo["successful"]
).dumps() # ).dumps()
expected_calls = [ # expected_calls = [
mock.call( # mock.call(
MessageEndpoints.public_file(scaninfo["scanID"], mock_det.name), # MessageEndpoints.public_file(scaninfo["scanID"], mock_det.name),
msg, # msg,
pipe=mock_det._producer.pipeline.return_value, # pipe=mock_det._producer.pipeline.return_value,
), # ),
mock.call( # mock.call(
MessageEndpoints.file_event(mock_det.name), # MessageEndpoints.file_event(mock_det.name),
msg, # msg,
pipe=mock_det._producer.pipeline.return_value, # pipe=mock_det._producer.pipeline.return_value,
), # ),
] # ]
assert mock_det._producer.set_and_publish.call_args_list == expected_calls # assert mock_det._producer.set_and_publish.call_args_list == expected_calls
def test_stop(mock_det): # def test_stop(mock_det):
with mock.patch.object(mock_det, "_stop_det") as mock_stop_det, mock.patch.object( # with mock.patch.object(mock_det, "_stop_det") as mock_stop_det, mock.patch.object(
mock_det, "_stop_file_writer" # mock_det, "_stop_file_writer"
) as mock_stop_file_writer: # ) as mock_stop_file_writer:
mock_det.stop() # mock_det.stop()
mock_stop_det.assert_called_once() # mock_stop_det.assert_called_once()
mock_stop_file_writer.assert_called_once() # mock_stop_file_writer.assert_called_once()
assert mock_det._stopped == True # assert mock_det._stopped == True
@pytest.mark.parametrize( # @pytest.mark.parametrize(
"stopped, scaninfo, cam_state, daq_status, expected_exception", # "stopped, scaninfo, cam_state, daq_status, expected_exception",
[ # [
( # (
False, # False,
{ # {
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 4, # "frames_per_trigger": 4,
}, # },
0, # 0,
{"acquisition": {"state": "FINISHED", "stats": {"n_write_completed": 2000}}}, # {"acquisition": {"state": "FINISHED", "stats": {"n_write_completed": 2000}}},
False, # False,
), # ),
( # (
False, # False,
{ # {
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 4, # "frames_per_trigger": 4,
}, # },
0, # 0,
{"acquisition": {"state": "FINISHED", "stats": {"n_write_completed": 1999}}}, # {"acquisition": {"state": "FINISHED", "stats": {"n_write_completed": 1999}}},
True, # True,
), # ),
( # (
False, # False,
{ # {
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 1, # "frames_per_trigger": 1,
}, # },
1, # 1,
{"acquisition": {"state": "READY", "stats": {"n_write_completed": 500}}}, # {"acquisition": {"state": "READY", "stats": {"n_write_completed": 500}}},
True, # True,
), # ),
( # (
False, # False,
{ # {
"num_points": 500, # "num_points": 500,
"frames_per_trigger": 1, # "frames_per_trigger": 1,
}, # },
0, # 0,
{"acquisition": {"state": "FINISHED", "stats": {"n_write_completed": 500}}}, # {"acquisition": {"state": "FINISHED", "stats": {"n_write_completed": 500}}},
False, # False,
), # ),
], # ],
) # )
def test_finished(mock_det, stopped, cam_state, daq_status, scaninfo, expected_exception): # def test_finished(mock_det, stopped, cam_state, daq_status, scaninfo, expected_exception):
with mock.patch.object(mock_det, "std_client") as mock_std_daq, mock.patch.object( # with mock.patch.object(mock_det, "std_client") as mock_std_daq, mock.patch.object(
mock_det, "_stop_file_writer" # mock_det, "_stop_file_writer"
) as mock_stop_file_friter, mock.patch.object(mock_det, "_stop_det") as mock_stop_det: # ) as mock_stop_file_friter, mock.patch.object(mock_det, "_stop_det") as mock_stop_det:
mock_std_daq.get_status.return_value = daq_status # mock_std_daq.get_status.return_value = daq_status
mock_det.cam.acquire.put(cam_state) # mock_det.cam.acquire.put(cam_state)
mock_det.scaninfo.num_points = scaninfo["num_points"] # mock_det.scaninfo.num_points = scaninfo["num_points"]
mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"] # mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"]
if expected_exception: # if expected_exception:
with pytest.raises(Exception): # with pytest.raises(Exception):
mock_det._finished() # mock_det._finished()
assert mock_det._stopped == stopped # assert mock_det._stopped == stopped
mock_stop_file_friter.assert_called() # mock_stop_file_friter.assert_called()
mock_stop_det.assert_called_once() # mock_stop_det.assert_called_once()
else: # else:
mock_det._finished() # mock_det._finished()
if stopped: # if stopped:
assert mock_det._stopped == stopped # assert mock_det._stopped == stopped
mock_stop_file_friter.assert_called() # mock_stop_file_friter.assert_called()
mock_stop_det.assert_called_once() # mock_stop_det.assert_called_once()

View File

@ -1,3 +1,9 @@
from bec_lib.core.devicemanager import DeviceContainer
from bec_lib.core.tests.utils import ProducerMock
from unittest import mock
class SocketMock: class SocketMock:
def __init__(self, host, port): def __init__(self, host, port):
self.host = host self.host = host
@ -44,3 +50,246 @@ class SocketMock:
def flush_buffer(self): def flush_buffer(self):
self.buffer_put = [] self.buffer_put = []
self.buffer_recv = "" self.buffer_recv = ""
class MockPV:
"""
MockPV class
This class is used for mocking pyepics signals for testing purposes
"""
_fmtsca = "<PV '%(pvname)s', count=%(count)i, type=%(typefull)s, access=%(access)s>"
_fmtarr = "<PV '%(pvname)s', count=%(count)i/%(nelm)i, type=%(typefull)s, access=%(access)s>"
_fields = (
"pvname",
"value",
"char_value",
"status",
"ftype",
"chid",
"host",
"count",
"access",
"write_access",
"read_access",
"severity",
"timestamp",
"posixseconds",
"nanoseconds",
"precision",
"units",
"enum_strs",
"upper_disp_limit",
"lower_disp_limit",
"upper_alarm_limit",
"lower_alarm_limit",
"lower_warning_limit",
"upper_warning_limit",
"upper_ctrl_limit",
"lower_ctrl_limit",
)
def __init__(
self,
pvname,
callback=None,
form="time",
verbose=False,
auto_monitor=None,
count=None,
connection_callback=None,
connection_timeout=None,
access_callback=None,
):
self.pvname = pvname.strip()
self.form = form.lower()
self.verbose = verbose
self._auto_monitor = auto_monitor
self.ftype = None
self.connected = True
self.connection_timeout = connection_timeout
self._user_max_count = count
if self.connection_timeout is None:
self.connection_timeout = 3
self._args = {}.fromkeys(self._fields)
self._args["pvname"] = self.pvname
self._args["count"] = count
self._args["nelm"] = -1
self._args["type"] = "unknown"
self._args["typefull"] = "unknown"
self._args["access"] = "unknown"
self._args["status"] = 0
self.connection_callbacks = []
self.mock_data = 0
if connection_callback is not None:
self.connection_callbacks = [connection_callback]
self.access_callbacks = []
if access_callback is not None:
self.access_callbacks = [access_callback]
self.callbacks = {}
self._put_complete = None
self._monref = None # holder of data returned from create_subscription
self._monref_mask = None
self._conn_started = False
if isinstance(callback, (tuple, list)):
for i, thiscb in enumerate(callback):
if callable(thiscb):
self.callbacks[i] = (thiscb, {})
elif callable(callback):
self.callbacks[0] = (callback, {})
self.chid = None
self.context = mock.MagicMock()
self._cache_key = (pvname, form, self.context)
self._reference_count = 0
for conn_cb in self.connection_callbacks:
conn_cb(pvname=pvname, conn=True, pv=self)
for acc_cb in self.access_callbacks:
acc_cb(True, True, pv=self)
def wait_for_connection(self, timeout=None):
return self.connected
def get_all_metadata_blocking(self, timeout):
md = self._args.copy()
md.pop("value", None)
return md
def get_all_metadata_callback(self, callback, *, timeout):
def get_metadata_thread(pvname):
md = self.get_all_metadata_blocking(timeout=timeout)
callback(pvname, md)
get_metadata_thread(pvname=self.pvname)
def put(
self, value, wait=False, timeout=None, use_complete=False, callback=None, callback_data=None
):
self.mock_data = value
if callback is not None:
callback(None, None, None)
def get_with_metadata(
self,
count=None,
as_string=False,
as_numpy=True,
timeout=None,
with_ctrlvars=False,
form=None,
use_monitor=True,
as_namespace=False,
):
return {"value": self.mock_data}
def get(
self,
count=None,
as_string=False,
as_numpy=True,
timeout=None,
with_ctrlvars=False,
use_monitor=True,
):
data = self.get_with_metadata(
count=count,
as_string=as_string,
as_numpy=as_numpy,
timeout=timeout,
with_ctrlvars=with_ctrlvars,
use_monitor=use_monitor,
)
return data["value"] if data is not None else None
class DeviceMock:
"""Device Mock. Used for testing in combination with the DeviceManagerMock
Args:
name (str): name of the device
value (float, optional): initial value of the device. Defaults to 0.0.
Returns:
DeviceMock: DeviceMock object
"""
def __init__(self, name: str, value: float = 0.0):
self.name = name
self.read_buffer = value
self._config = {"deviceConfig": {"limits": [-50, 50]}, "userParameter": None}
self._enabled_set = True
self._enabled = True
def read(self):
return {self.name: {"value": self.read_buffer}}
def readback(self):
return self.read_buffer
@property
def enabled_set(self) -> bool:
return self._enabled_set
@enabled_set.setter
def enabled_set(self, val: bool):
self._enabled_set = val
@property
def enabled(self) -> bool:
return self._enabled
@enabled.setter
def enabled(self, val: bool):
self._enabled = val
@property
def user_parameter(self):
return self._config["userParameter"]
@property
def obj(self):
return self
class DMMock:
"""Mock for DeviceManager
The mocked DeviceManager creates a device containert and a producer.
"""
def __init__(self):
self.devices = DeviceContainer()
self.producer = ProducerMock()
def add_device(self, name: str, value: float = 0.0):
self.devices[name] = DeviceMock(name, value)
# #TODO check what is the difference to SynSignal!
# class MockSignal(Signal):
# """Can mock an OphydSignal"""
# def __init__(self, read_pv, *, string=False, name=None, parent=None, **kwargs):
# self.read_pv = read_pv
# self._string = bool(string)
# super().__init__(name=name, parent=parent, **kwargs)
# self._waited_for_connection = False
# self._subscriptions = []
# def wait_for_connection(self):
# self._waited_for_connection = True
# def subscribe(self, method, event_type, **kw):
# self._subscriptions.append((method, event_type, kw))
# def describe_configuration(self):
# return {self.name + "_conf": {"source": "SIM:test"}}
# def read_configuration(self):
# return {self.name + "_conf": {"value": 0}}