diff --git a/tests/test_pilatus_csaxs.py b/tests/test_pilatus_csaxs.py index 74f5322..2d70bfe 100644 --- a/tests/test_pilatus_csaxs.py +++ b/tests/test_pilatus_csaxs.py @@ -1,3 +1,4 @@ +import os import pytest from unittest import mock @@ -358,6 +359,167 @@ def test_stop_file_writer(mock_det, requests_state, expected_exception, url): mock_send_requests_put.assert_called_once_with(url=url) +@pytest.mark.parametrize( + "scaninfo, data_msgs, urls, requests_state, expected_exception", + [ + ( + { + "filepath_raw": "pilatus_2.h5", + "eacc": "e12345", + "scan_number": 1000, + "scan_directory": "S00000_00999", + "num_points": 500, + "frames_per_trigger": 1, + "headers": {"Content-Type": "application/json", "Accept": "application/json"}, + }, + [ + { + "source": [ + { + "searchPath": "/", + "searchPattern": "glob:*.cbf", + "destinationPath": "/sls/X12SA/data/e12345/Data10/pilatus_2/S00000_00999", + } + ] + }, + [ + "zmqWriter", + "e12345", + { + "addr": "tcp://x12sa-pd-2:8888", + "dst": ["file"], + "numFrm": 500, + "timeout": 2000, + "ifType": "PULL", + "user": "e12345", + }, + ], + [ + "zmqWriter", + "e12345", + { + "frmCnt": 500, + "timeout": 2000, + }, + ], + ], + [ + "http://x12sa-pd-2:8080/stream/pilatus_2", + "http://xbl-daq-34:8091/pilatus_2/run", + "http://xbl-daq-34:8091/pilatus_2/wait", + ], + True, + False, + ), + ( + { + "filepath_raw": "pilatus_2.h5", + "eacc": "e12345", + "scan_number": 1000, + "scan_directory": "S00000_00999", + "num_points": 500, + "frames_per_trigger": 1, + "headers": {"Content-Type": "application/json", "Accept": "application/json"}, + }, + [ + { + "source": [ + { + "searchPath": "/", + "searchPattern": "glob:*.cbf", + "destinationPath": "/sls/X12SA/data/e12345/Data10/pilatus_2/S00000_00999", + } + ] + }, + [ + "zmqWriter", + "e12345", + { + "addr": "tcp://x12sa-pd-2:8888", + "dst": ["file"], + "numFrm": 500, + "timeout": 2000, + "ifType": "PULL", + "user": "e12345", + }, + ], + [ + "zmqWriter", + "e12345", + { + "frmCnt": 500, + "timeout": 2000, + }, + ], + ], + [ + "http://x12sa-pd-2:8080/stream/pilatus_2", + "http://xbl-daq-34:8091/pilatus_2/run", + "http://xbl-daq-34:8091/pilatus_2/wait", + ], + False, # return of res.ok is False! + True, + ), + ], +) +def test_prep_file_writer(mock_det, scaninfo, data_msgs, urls, requests_state, expected_exception): + with mock.patch.object( + mock_det, "_close_file_writer" + ) as mock_close_file_writer, mock.patch.object( + mock_det, "_stop_file_writer" + ) as mock_stop_file_writer, mock.patch.object( + mock_det, "filewriter" + ) as mock_filewriter, mock.patch.object( + mock_det, "_create_directory" + ) as mock_create_directory, mock.patch.object( + mock_det, "_send_requests_put" + ) as mock_send_requests_put: + mock_det.scaninfo.scan_number = scaninfo["scan_number"] + mock_det.scaninfo.num_points = scaninfo["num_points"] + mock_det.scaninfo.frames_per_trigger = scaninfo["frames_per_trigger"] + mock_det.scaninfo.username = scaninfo["eacc"] + mock_filewriter.compile_full_filename.return_value = scaninfo["filepath_raw"] + mock_filewriter.get_scan_directory.return_value = scaninfo["scan_directory"] + instance = mock_send_requests_put.return_value + instance.ok = requests_state + instance.raise_for_status.side_effect = Exception + + if expected_exception: + with pytest.raises(Exception): + mock_det._prep_file_writer() + mock_close_file_writer.assert_called_once() + mock_stop_file_writer.assert_called_once() + instance.raise_for_status.assert_called_once() + else: + mock_det._prep_file_writer() + + mock_close_file_writer.assert_called_once() + mock_stop_file_writer.assert_called_once() + + # Assert values set on detector + assert mock_det.cam.file_path.get() == "/dev/shm/zmq/" + assert ( + mock_det.cam.file_name.get() + == f"{scaninfo['eacc']}_2_{scaninfo['scan_number']:05d}" + ) + assert mock_det.cam.auto_increment.get() == 1 + assert mock_det.cam.file_number.get() == 0 + assert mock_det.cam.file_format.get() == 0 + assert mock_det.cam.file_template.get() == "%s%s_%5.5d.cbf" + # Remove last / from destinationPath + mock_create_directory.assert_called_once_with( + os.path.join(data_msgs[0]["source"][0]["destinationPath"]) + ) + assert mock_send_requests_put.call_count == 3 + + calls = [ + mock.call(url=url, data=data_msg, headers=scaninfo["headers"]) + for url, data_msg in zip(urls, data_msgs) + ] + for call, mock_call in zip(calls, mock_send_requests_put.call_args_list): + assert call == mock_call + + # @pytest.mark.parametrize( # "scaninfo, daq_status, expected_exception", # [