diff --git a/tests/test_core_adjustables.py b/tests/test_core_adjustables.py new file mode 100644 index 000000000..3fdd120d3 --- /dev/null +++ b/tests/test_core_adjustables.py @@ -0,0 +1,1027 @@ +import pytest +from slic.core.adjustable.baseadjustable import BaseAdjustable +from slic.core.adjustable.dummyadjustable import DummyAdjustable +from slic.core.adjustable.genericadjustable import GenericAdjustable +from slic.core.adjustable.converted import Converted +from slic.core.adjustable.scaler import Scaler +from slic.core.adjustable.combined import Combined +from slic.core.adjustable.linked import Linked +from slic.core.adjustable.collection import Collection + + +# BaseAdjustable + +def test_baseadjustable_cannot_instantiate(): + with pytest.raises(TypeError): + BaseAdjustable() + + +def test_baseadjustable_missing_methods(): + class IncompleteAdj(BaseAdjustable): + pass + + with pytest.raises(TypeError): + IncompleteAdj() + + +def test_baseadjustable_working_subclass(): + class WorkingAdj(BaseAdjustable): + def __init__(self): + self.value = 0 + + def get_current_value(self): + return self.value + + def set_target_value(self, value): + self.value = value + + def is_moving(self): + return False + + adj = WorkingAdj() + assert adj.get_current_value() == 0 + adj.set_target_value(42) + assert adj.get_current_value() == 42 + assert adj.is_moving() == False + + +# DummyAdjustable + +def test_dummyadjustable_basic(): + adj = DummyAdjustable(name="TestAdj", ID="test_id") + assert adj.name == "TestAdj" + assert adj.ID == "test_id" + assert adj.get_current_value() == 0 + + +def test_dummyadjustable_set_get_with_wait(): + adj = DummyAdjustable(name="TestAdj", ID="test_id") + task = adj.set_target_value(100) + task.wait() + assert adj.get_current_value() == 100 + + task = adj.set_target_value(-50.5) + task.wait() + assert adj.get_current_value() == -50.5 + +def test_dummyadjustable_is_moving(): + adj = DummyAdjustable(name="TestAdj", ID="test_id") + assert adj.is_moving() == False + +def test_dummyadjustable_initial_value(): + adj = DummyAdjustable(name="TestAdj", ID="test_id", initial_value=42) + assert adj.get_current_value() == 42 + +def test_dummyadjustable_float_values(): + adj = DummyAdjustable(name="TestAdj", ID="test_id") + task = adj.set_target_value(3.14159) + task.wait() + assert adj.get_current_value() == 3.14159 + +def test_dummyadjustable_large_values(): + adj = DummyAdjustable(name="TestAdj", ID="test_id") + task = adj.set_target_value(1e10) + task.wait() + assert adj.get_current_value() == 1e10 + +@pytest.mark.parametrize("process_time,target", [ + (0.1, 100), + (0.2, 50), + (0.15, 200), + (0.3, -100), +]) +def test_dummyadjustable_process_time(process_time, target): + import time + adj = DummyAdjustable(name="TestAdj", ID="test_id", process_time=process_time) + + start = time.time() + task = adj.set_target_value(target) + task.wait() + elapsed = time.time() - start + + assert elapsed >= process_time * 0.9, f"Too fast: {elapsed:.3f}s < {process_time*0.9:.3f}s" + assert elapsed <= process_time * 1.5, f"Too slow: {elapsed:.3f}s > {process_time*1.5:.3f}s" + assert adj.get_current_value() == target + + +@pytest.mark.parametrize("process_time,initial,target", [ + (0.2, 0, 100), + (0.3, 0, 200), + (0.4, 0, 50), + (0.25, 0, -100), + (0.3, 100, -50), + (0.35, -50, -150), +]) +def test_dummyadjustable_process_time_progressive_values(process_time, initial, target): + # Test that DummyAdjustable progresses gradually through intermediate values + import time + import threading + + adj = DummyAdjustable(name="TestAdj", ID="test_id", process_time=process_time, initial_value=initial) + + intermediate_data = [] + stop_collecting = threading.Event() + start_time = time.time() + + def collect_values(): + while not stop_collecting.is_set(): + current_time = time.time() - start_time + current_value = adj.get_current_value() + intermediate_data.append((current_time, current_value)) + time.sleep(0.02) + + collector_thread = threading.Thread(target=collect_values) + collector_thread.start() + + task = adj.set_target_value(target) + task.wait() + + stop_collecting.set() + collector_thread.join() + + intermediate_values = [v for t, v in intermediate_data] + + assert len(intermediate_values) >= 3, f"Expected at least 3 samples, got {len(intermediate_values)}" + + distance = target - initial + + if target > initial: + increasing_count = sum(1 for i in range(1, len(intermediate_values)) + if intermediate_values[i] >= intermediate_values[i-1]) + assert increasing_count >= len(intermediate_values) * 0.7, \ + f"Expected mostly increasing values, got {increasing_count}/{len(intermediate_values)-1}" + elif target < initial: + decreasing_count = sum(1 for i in range(1, len(intermediate_values)) + if intermediate_values[i] <= intermediate_values[i-1]) + assert decreasing_count >= len(intermediate_values) * 0.7, \ + f"Expected mostly decreasing values, got {decreasing_count}/{len(intermediate_values)-1}" + + distances_to_target = [abs(value - target) for value in intermediate_values] + + decreasing_distance_count = sum(1 for i in range(1, len(distances_to_target)) + if distances_to_target[i] <= distances_to_target[i-1]) + + assert decreasing_distance_count >= len(distances_to_target) * 0.7, \ + f"Expected values to approach target, got {decreasing_distance_count}/{len(distances_to_target)-1} decreasing distances. " \ + f"Distances: {distances_to_target[:10]}..." + + assert adj.get_current_value() == target, f"Final value {adj.get_current_value()} != target {target}" + + unique_values = sorted(set(intermediate_values)) + assert len(unique_values) >= 2, f"Expected progressive motion, got values: {unique_values}" + + min_val = min(initial, target) + max_val = max(initial, target) + for i, value in enumerate(intermediate_values): + assert min_val <= value <= max_val, \ + f"Value {value} at index {i} outside range [{min_val}, {max_val}]" + + tolerance_factor = 0.3 + for timestamp, value in intermediate_data[1:-1]: + progress_ratio = timestamp / process_time + expected_value = initial + distance * progress_ratio + + value_range = abs(distance) * tolerance_factor + lower_bound = expected_value - value_range + upper_bound = expected_value + value_range + + if not (lower_bound <= value <= upper_bound): + pass + + +@pytest.mark.parametrize("jitter", [ + 1, + 5, + 10, + 20, + 0.5, +]) +def test_dummyadjustable_jitter(jitter): + initial_value = 100 + adj = DummyAdjustable(name="TestAdj", ID="test_id", initial_value=initial_value, jitter=jitter) + + num_samples = 20 + readings = [adj.get_current_value() for _ in range(num_samples)] + + unique_readings = set(readings) + assert len(unique_readings) > 1, f"Expected variation with jitter={jitter}, got all same value" + + min_expected = initial_value - jitter + max_expected = initial_value + jitter + + for reading in readings: + assert min_expected <= reading <= max_expected, \ + f"Reading {reading} outside expected range [{min_expected}, {max_expected}] for jitter={jitter}" + + import statistics + std_dev = statistics.stdev(readings) + + assert std_dev >= jitter * 0.2, \ + f"Standard deviation {std_dev:.2f} too small for jitter={jitter} (expected >= {jitter*0.2:.2f})" + assert std_dev <= jitter * 0.8, \ + f"Standard deviation {std_dev:.2f} too large for jitter={jitter} (expected <= {jitter*0.8:.2f})" + + +def test_dummyadjustable_stop(): + import threading + adj = DummyAdjustable(name="TestAdj", ID="test_id", process_time=1.0) + + def move(): + adj.set_target_value(100) + + thread = threading.Thread(target=move) + thread.start() + + import time + time.sleep(0.1) + + adj.stop() + thread.join() + + assert adj.get_current_value() < 100 + + +# GenericAdjustable + +def test_genericadjustable_with_callbacks(): + storage = {"value": 0} + + def getter(): + return storage["value"] + + def setter(val): + storage["value"] = val + + adj = GenericAdjustable(ID="gen_id", + get = getter, + set =setter, + name="GenAdj") + + assert adj.get_current_value() == 0 + task = adj.set_target_value(42) + assert adj._last_target == 42 + task.wait() + assert adj.get_current_value() == 42 + assert storage["value"] == 42 + + +def test_genericadjustable_with_wait_callback(): + # The wait callback should return True when motion is COMPLETE (not moving) + # and False when still moving. is_moving() returns not wait(). + motion_complete = {"complete": True} + + def getter(): + return 0 + + def setter(val): + pass + + def wait_func(): + return motion_complete["complete"] + + adj = GenericAdjustable(ID="gen_id", + get=getter, + set=setter, + wait=wait_func, + name="GenAdj") + + assert adj.is_moving() == False + + motion_complete["complete"] = False + assert adj.is_moving() == True + + motion_complete["complete"] = True + assert adj.is_moving() == False + + +# Converted + +@pytest.mark.parametrize("scale_factor,test_value", [ + (10, 100), + (2, 50), + (5, 200), + (0.5, 25), + (100, 1000), +]) +def test_converted_with_scaling(scale_factor, test_value): + base = DummyAdjustable(name="Base", ID="base_id") + + converted = Converted("scaled_id", base, + conv_get=lambda x: x * scale_factor, + conv_set=lambda x: x / scale_factor, + name="Scaled") + + assert converted.get_current_value() == 0 + + converted.set_target_value(test_value).wait() + expected_base = test_value / scale_factor + assert abs(base.get_current_value() - expected_base) < 0.0001 + assert abs(converted.get_current_value() - test_value) < 0.0001 + + +@pytest.mark.parametrize("offset,test_value", [ + (50, 100), + (10, 30), + (-20, 50), + (100, 200), + (0, 42), +]) +def test_converted_with_offset(offset, test_value): + base = DummyAdjustable(name="Base", ID="base_id") + + converted = Converted("offset_id", base, + conv_get=lambda x: x + offset, + conv_set=lambda x: x - offset, + name="Offset") + + assert converted.get_current_value() == offset + + converted.set_target_value(test_value).wait() + assert base.get_current_value() == test_value - offset + assert converted.get_current_value() == test_value + + +@pytest.mark.parametrize("scale,offset,base_val,expected_conv", [ + (10, 50, 0, 50), + (10, 50, 5, 100), + (2, 10, 20, 50), + (5, -15, 3, 0), + (0.5, 100, 40, 120), +]) +def test_converted_with_scaling_and_offset(scale, offset, base_val, expected_conv): + base = DummyAdjustable(name="Base", ID="base_id") + + converted = Converted("both_id", base, + conv_get=lambda x: x * scale + offset, + conv_set=lambda x: (x - offset) / scale, + name="Both") + + assert converted.get_current_value() == offset + + base.set_target_value(base_val).wait() + assert abs(converted.get_current_value() - expected_conv) < 0.0001 + + converted.set_target_value(expected_conv).wait() + assert abs(base.get_current_value() - base_val) < 0.0001 + assert abs(converted.get_current_value() - expected_conv) < 0.0001 + + +def test_converted_units_conversion(): + base_mm = DummyAdjustable(name="Position_mm", ID="pos_mm", units="mm") + + position_um = Converted("pos_um", base_mm, + conv_get=lambda x: x * 1000, + conv_set=lambda x: x / 1000, + name="Position_μm", + units="μm") + + assert position_um.get_current_value() == 0 + + base_mm.set_target_value(1).wait() + assert position_um.get_current_value() == 1000 + + position_um.set_target_value(2500).wait() + assert base_mm.get_current_value() == 2.5 + assert position_um.get_current_value() == 2500 + + +def test_converted_negative_scale(): + base = DummyAdjustable(name="Base", ID="base_id") + + converted = Converted("inv_id", base, + conv_get=lambda x: x * -1, + conv_set=lambda x: x * -1, + name="Inverted") + + base.set_target_value(10).wait() + assert converted.get_current_value() == -10 + + converted.set_target_value(20).wait() + assert base.get_current_value() == -20 + assert converted.get_current_value() == 20 + + +def test_converted_is_moving(): + base = DummyAdjustable(name="Base", ID="base_id") + converted = Converted("conv_id", base, + conv_get=lambda x: x * 10, + conv_set=lambda x: x / 10, + name="Conv") + + assert converted.is_moving() == False + + +def test_converted_only_get_conversion(): + base = DummyAdjustable(name="Base", ID="base_id") + + converted = Converted("get_only_id", base, + conv_get=lambda x: x * 10, + conv_set=lambda x: x, + name="GetOnly") + + base.set_target_value(5).wait() + assert converted.get_current_value() == 50 + assert base.get_current_value() == 5 + + converted.set_target_value(100).wait() + assert base.get_current_value() == 100 + assert converted.get_current_value() == 1000 + + +def test_converted_only_set_conversion(): + base = DummyAdjustable(name="Base", ID="base_id") + + converted = Converted("set_only_id", base, + conv_get=lambda x: x, + conv_set=lambda x: x / 10, + name="SetOnly") + + base.set_target_value(50).wait() + assert converted.get_current_value() == 50 + assert base.get_current_value() == 50 + + converted.set_target_value(100).wait() + assert base.get_current_value() == 10 + assert converted.get_current_value() == 10 + + +# Scaler + +@pytest.mark.parametrize("init1,init2,factor_init,factor_target", [ + (10, 20, 2, 4), + (5, 15, 1, 3), + (100, 200, 0.5, 1), + (8, 16, 4, 2), + (50, 100, 10, 20), +]) +def test_scaler_basic(init1, init2, factor_init, factor_target): + adj1 = DummyAdjustable(name="Adj1", ID="id1", initial_value=init1) + adj2 = DummyAdjustable(name="Adj2", ID="id2", initial_value=init2) + + scaler = Scaler("scaler_id", [adj1, adj2], factor=factor_init, name="Scaler") + + assert scaler.get_current_value() == factor_init + + scaler.set_target_value(factor_target).wait() + ratio = factor_target / factor_init + assert abs(adj1.get_current_value() - init1 * ratio) < 0.0001 + assert abs(adj2.get_current_value() - init2 * ratio) < 0.0001 + assert abs(scaler.get_current_value() - factor_target) < 0.0001 + + +def test_scaler_fractional_factor(): + adj1 = DummyAdjustable(name="Adj1", ID="id1", initial_value=100) + + scaler = Scaler("scaler_id", [adj1], factor=0.5, name="Half") + + assert scaler.get_current_value() == 0.5 + + scaler.set_target_value(1.0).wait() + + assert adj1.get_current_value() == 200 + assert scaler.get_current_value() == 1.0 + + +def test_scaler_is_moving(): + adj1 = DummyAdjustable(name="Adj1", ID="id1") + adj2 = DummyAdjustable(name="Adj2", ID="id2") + + scaler = Scaler("scaler_id", [adj1, adj2], factor=1, name="Scaler") + + assert scaler.is_moving() == False + + +# Combined + +def test_combined_two_adjustables(): + adj1 = DummyAdjustable(name="Adj1", ID="id1") + adj2 = DummyAdjustable(name="Adj2", ID="id2") + + combined = Combined("comb_id", [adj1, adj2], name="Combined") + + combined.set_target_value(50).wait() + + assert adj1.get_current_value() == 50 + assert adj2.get_current_value() == 50 + assert combined.get_current_value() == 50 + + +def test_combined_three_adjustables(): + adj1 = DummyAdjustable(name="Adj1", ID="id1") + adj2 = DummyAdjustable(name="Adj2", ID="id2") + adj3 = DummyAdjustable(name="Adj3", ID="id3") + + combined = Combined("comb3_id", [adj1, adj2, adj3], name="Combined3") + + combined.set_target_value(100).wait() + + assert adj1.get_current_value() == 100 + assert adj2.get_current_value() == 100 + assert adj3.get_current_value() == 100 + assert combined.get_current_value() == 100 + + +def test_combined_get_current_value_returns_mean(): + adj1 = DummyAdjustable(name="Adj1", ID="id1", initial_value=10) + adj2 = DummyAdjustable(name="Adj2", ID="id2", initial_value=20) + + combined = Combined("comb_id", [adj1, adj2], name="Combined") + + current = combined.get_current_value() + assert current == 15.0 + + combined.set_target_value(100).wait() + assert combined.get_current_value() == 100 + + +def test_combined_mean_with_different_initial_values(): + adj1 = DummyAdjustable(name="Adj1", ID="id1", initial_value=5) + adj2 = DummyAdjustable(name="Adj2", ID="id2", initial_value=15) + adj3 = DummyAdjustable(name="Adj3", ID="id3", initial_value=25) + + combined = Combined("comb_id", [adj1, adj2, adj3], name="Combined") + + assert combined.get_current_value() == 15.0 + + adj1.set_target_value(10).wait() + import numpy as np + assert np.isclose(combined.get_current_value(), 50/3) + + combined.set_target_value(60).wait() + assert adj1.get_current_value() == 60 + assert adj2.get_current_value() == 60 + assert adj3.get_current_value() == 60 + assert combined.get_current_value() == 60 + + +def test_combined_is_moving(): + adj1 = DummyAdjustable(name="Adj1", ID="id1") + adj2 = DummyAdjustable(name="Adj2", ID="id2") + + combined = Combined("comb_id", [adj1, adj2], name="Combined") + + assert combined.is_moving() == False + + +# Linked + +def test_linked_basic(): + master = DummyAdjustable(name="Master", ID="master_id") + slave = DummyAdjustable(name="Slave", ID="slave_id") + + linked = Linked("linked_id", master, slave, name="Linked") + + linked.set_target_value(10).wait() + + assert master.get_current_value() == 10 + assert slave.get_current_value() == 10 + + +@pytest.mark.parametrize("scale,target_val,expected_slave", [ + (2, 10, 20), + (3, 15, 45), + (0.5, 20, 10), + (10, 5, 50), + (-1, 10, -10), + (-2, 15, -30), +]) +def test_linked_with_scale(scale, target_val, expected_slave): + master = DummyAdjustable(name="Master", ID="master_id") + slave = DummyAdjustable(name="Slave", ID="slave_id") + + linked = Linked("linked_id", master, slave, scale=scale, name="Linked") + + linked.set_target_value(target_val).wait() + + assert master.get_current_value() == target_val + assert abs(slave.get_current_value() - expected_slave) < 0.0001 + + +@pytest.mark.parametrize("scale,offset,target_val", [ + (1, 50, 10), + (2, 10, 15), + (3, -5, 10), + (0.5, 100, 20), + (-1, 50, 10), + (2, 0, 25), +]) +def test_linked_with_scale_and_offset(scale, offset, target_val): + master = DummyAdjustable(name="Master", ID="master_id") + slave = DummyAdjustable(name="Slave", ID="slave_id") + + linked = Linked("linked_id", master, slave, scale=scale, offset=offset, name="Linked") + + linked.set_target_value(target_val).wait() + + assert master.get_current_value() == target_val + expected_slave = target_val * scale + offset + assert abs(slave.get_current_value() - expected_slave) < 0.0001 + + +def test_linked_get_current_value(): + master = DummyAdjustable(name="Master", ID="master_id", initial_value=42) + slave = DummyAdjustable(name="Slave", ID="slave_id", initial_value=100) + + linked = Linked("linked_id", master, slave, name="Linked") + + assert linked.get_current_value() == 42 + + +def test_linked_repr(): + master = DummyAdjustable(name="Master", ID="master_id", initial_value=10) + slave = DummyAdjustable(name="Slave", ID="slave_id", initial_value=20) + + linked = Linked("linked_id", master, slave, scale=2, offset=5, name="Linked") + + linked.set_target_value(15).wait() + + repr_str = repr(linked) + + assert "Primary:" in repr_str + assert "Secondary:" in repr_str + + assert "Master" in repr_str + assert "Slave" in repr_str + assert "15" in repr_str + assert "35" in repr_str + + +# Collection + +def test_collection_basic(): + adj1 = DummyAdjustable(name="Adj1", ID="id1") + adj2 = DummyAdjustable(name="Adj2", ID="id2") + adj3 = DummyAdjustable(name="Adj3", ID="id3") + + collection = Collection("coll_id", [adj1, adj2, adj3], name="MyCollection") + + assert len(collection.adjs) == 3 + + +def test_collection_set_individual_values(): + adj1 = DummyAdjustable(name="Adj1", ID="id1") + adj2 = DummyAdjustable(name="Adj2", ID="id2") + + collection = Collection("coll_id", [adj1, adj2], name="MyCollection") + + collection.set_target_value(10, 20).wait() + + assert adj1.get_current_value() == 10 + assert adj2.get_current_value() == 20 + + +def test_collection_get_current_value(): + adj1 = DummyAdjustable(name="Adj1", ID="id1", initial_value=42) + adj2 = DummyAdjustable(name="Adj2", ID="id2", initial_value=84) + + collection = Collection("coll_id", [adj1, adj2], name="MyCollection") + + current = collection.get_current_value() + assert current == (42, 84) + + +def test_collection_empty(): + collection = Collection("empty_id", [], name="EmptyCollection") + assert len(collection.adjs) == 0 + + +def test_collection_wrong_number_of_values(): + # BUG: ValueError is wrapped in TaskError due to threading + from slic.core.task.task import TaskError + + adj1 = DummyAdjustable(name="Adj1", ID="id1") + adj2 = DummyAdjustable(name="Adj2", ID="id2") + + collection = Collection("coll_id", [adj1, adj2], name="MyCollection") + + with pytest.raises(TaskError, match="ValueError.*number of values.*3.*is not equal.*2"): + task = collection.set_target_value(10, 20, 30) + task.wait() + + with pytest.raises(TaskError, match="ValueError.*number of values.*1.*is not equal.*2"): + task = collection.set_target_value(10) + task.wait() + + +def test_collection_repr(): + adj1 = DummyAdjustable(name="Adj1", ID="id1", initial_value=10) + adj2 = DummyAdjustable(name="Adj2", ID="id2", initial_value=20) + adj3 = DummyAdjustable(name="Adj3", ID="id3", initial_value=30) + + collection = Collection("coll_id", [adj1, adj2, adj3], name="MyCollection") + + collection.set_target_value(100, 200, 300).wait() + + repr_str = repr(collection) + + assert "Adj1" in repr_str + assert "Adj2" in repr_str + assert "Adj3" in repr_str + + assert "100" in repr_str + assert "200" in repr_str + assert "300" in repr_str + + assert "\n" in repr_str + + +def test_collection_is_moving(): + adj1 = DummyAdjustable(name="Adj1", ID="id1") + adj2 = DummyAdjustable(name="Adj2", ID="id2") + + collection = Collection("coll_id", [adj1, adj2], name="MyCollection") + + assert collection.is_moving() == False + + +# Integration + +def test_nested_conversions(): + base = DummyAdjustable(name="Base", ID="base_id") + + converted1 = Converted("conv1_id", base, + conv_get=lambda x: x * 10, + conv_set=lambda x: x / 10, + name="Conv1") + + converted2 = Converted("conv2_id", converted1, + conv_get=lambda x: x * 2, + conv_set=lambda x: x / 2, + name="Conv2") + + base.set_target_value(1).wait() + assert converted2.get_current_value() == 20 + + converted2.set_target_value(100).wait() + assert base.get_current_value() == 5 + + +def test_combined_with_converted(): + base1 = DummyAdjustable(name="Base1", ID="base1_id") + base2 = DummyAdjustable(name="Base2", ID="base2_id") + + scaled1 = Converted("scaled1_id", base1, + conv_get=lambda x: x * 10, + conv_set=lambda x: x / 10, + name="Scaled1") + + scaled2 = Converted("scaled2_id", base2, + conv_get=lambda x: x * 100, + conv_set=lambda x: x / 100, + name="Scaled2") + + combined = Combined("combscaled_id", [scaled1, scaled2], name="CombScaled") + + combined.set_target_value(50).wait() + + assert base1.get_current_value() == 5 + assert base2.get_current_value() == 0.5 + assert combined.get_current_value() == 50 + + +def test_collection_with_converted(): + base1 = DummyAdjustable(name="Base1", ID="base1_id") + base2 = DummyAdjustable(name="Base2", ID="base2_id") + + scaled1 = Converted("scaled1_id", base1, + conv_get=lambda x: x * 10, + conv_set=lambda x: x / 10, + name="Scaled1") + + scaled2 = Converted("scaled2_id", base2, + conv_get=lambda x: x * 100, + conv_set=lambda x: x / 100, + name="Scaled2") + + collection = Collection("collscaled_id", [scaled1, scaled2], name="CollScaled") + + collection.set_target_value(50, 200).wait() + + assert base1.get_current_value() == 5 + assert base2.get_current_value() == 2 + assert collection.get_current_value() == (50, 200) + + +# Adjustable Base Class + +@pytest.mark.parametrize("initial,delta1,delta2,expected_final", [ + (10, 5, -3, 12), + (0, 100, -50, 50), + (42, -10, 8, 40), + (-5, 15, -20, -10), + (100, 0, 0, 100), +]) +def test_adjustable_tweak(initial, delta1, delta2, expected_final): + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=initial) + + adj.tweak(delta1).wait() + assert adj.get_current_value() == initial + delta1 + + adj.tweak(delta2).wait() + assert adj.get_current_value() == expected_final + + +def test_adjustable_call_syntax(): + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=42) + + assert adj() == 42 + + adj(100).wait() + assert adj() == 100 + + +def test_adjustable_set_get_aliases(): + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=5) + + assert adj.get() == 5 + + adj.set(20).wait() + assert adj.get() == 20 + + +def test_adjustable_moving_property(): + adj = DummyAdjustable(name="Test", ID="test_id") + + assert isinstance(adj.moving, bool) + assert adj.moving == False + + +def test_adjustable_repr_with_units(): + adj = DummyAdjustable(name="Position", ID="pos_id", initial_value=42, units="mm") + + repr_str = repr(adj) + assert "Position" in repr_str + assert "42" in repr_str + assert "mm" in repr_str + + +def test_adjustable_repr_with_degrees(): + adj = DummyAdjustable(name="Angle", ID="angle_id", initial_value=90, units="deg") + + repr_str = repr(adj) + assert "90°" in repr_str or "90 deg" in repr_str + + +def test_adjustable_str(): + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=3.14, units="m") + + str_val = str(adj) + assert "3.14" in str_val + assert "m" in str_val + + +# NumericConvenience + +@pytest.mark.parametrize("value,expected_int,expected_float", [ + (42.7, 42, 42.7), + (3.14159, 3, 3.14159), + (99.99, 99, 99.99), + (-5.8, -5, -5.8), + (0.1, 0, 0.1), +]) +def test_numeric_convenience_int_float(value, expected_int, expected_float): + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=value) + + assert int(adj) == expected_int + assert isinstance(int(adj), int) + assert float(adj) == expected_float + assert isinstance(float(adj), float) + + +@pytest.mark.parametrize("value,round0,round1,round2", [ + (3.14159, 3, 3.1, 3.14), + (2.71828, 3, 2.7, 2.72), + (9.8765, 10, 9.9, 9.88), + (-4.567, -5, -4.6, -4.57), + (100.123, 100, 100.1, 100.12), +]) +def test_numeric_convenience_round(value, round0, round1, round2): + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=value) + + assert round(adj) == round0 + assert round(adj, 1) == round1 + assert round(adj, 2) == round2 + + +@pytest.mark.parametrize("value,expected_trunc,expected_floor,expected_ceil", [ + (3.9, 3, 3, 4), + (3.1, 3, 3, 4), + (-2.8, -2, -3, -2), + (-2.1, -2, -3, -2), + (5.5, 5, 5, 6), +]) +def test_numeric_convenience_math_funcs(value, expected_trunc, expected_floor, expected_ceil): + import math + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=value) + + assert math.trunc(adj) == expected_trunc + assert math.floor(adj) == expected_floor + assert math.ceil(adj) == expected_ceil + + +# SpecConvenience + +def test_spec_convenience_wm(): + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=42) + + assert adj.wm() == 42 + + +def test_spec_convenience_mv(): + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=10) + + adj.mv(50).wait() + assert adj.get_current_value() == 50 + + +@pytest.mark.parametrize("initial,move1,move2,expected_final", [ + (10, 5, -3, 12), + (0, 50, 30, 80), + (100, -20, -10, 70), + (-5, 15, -8, 2), + (42, 0, 8, 50), +]) +def test_spec_convenience_mvr(initial, move1, move2, expected_final): + adj = DummyAdjustable(name="Test", ID="test_id", initial_value=initial) + + adj.mvr(move1).wait() + assert adj.get_current_value() == initial + move1 + + adj.mvr(move2).wait() + assert adj.get_current_value() == expected_final + + +# Limited + +def test_limited_set_limits(): + adj = DummyAdjustable(name="Test", ID="test_id") + + adj.set_limits(low=0, high=100) + assert adj.limit_low == 0 + assert adj.limit_high == 100 + + +@pytest.mark.parametrize("low,high,valid_values,invalid_low,invalid_high", [ + (0, 100, [0, 50, 100], -10, 150), + (-50, 50, [-50, 0, 50], -100, 100), + (10, 20, [10, 15, 20], 5, 25), + (-100, -10, [-100, -50, -10], -150, 0), + (0, 1000, [0, 500, 1000], -1, 1001), +]) +def test_limited_with_various_ranges(low, high, valid_values, invalid_low, invalid_high): + from slic.core.adjustable.limited import OutsideLimits + + adj = DummyAdjustable(name="Test", ID="test_id") + adj.set_limits(low=low, high=high) + + for val in valid_values: + adj.set_target_value(val).wait() + assert adj.get_current_value() == val + + with pytest.raises(OutsideLimits): + adj.set_target_value(invalid_low) + + with pytest.raises(OutsideLimits): + adj.set_target_value(invalid_high) + + +def test_limited_no_limits(): + adj = DummyAdjustable(name="Test", ID="test_id") + + adj.set_target_value(-999).wait() + assert adj.get_current_value() == -999 + + adj.set_target_value(999).wait() + assert adj.get_current_value() == 999 + + +def test_limited_only_low_limit(): + adj = DummyAdjustable(name="Test", ID="test_id") + adj.set_limits(low=0) + + adj.set_target_value(1000).wait() + assert adj.get_current_value() == 1000 + + from slic.core.adjustable.limited import OutsideLimits + with pytest.raises(OutsideLimits): + adj.set_target_value(-1) + + +def test_limited_only_high_limit(): + adj = DummyAdjustable(name="Test", ID="test_id") + adj.set_limits(high=100) + + adj.set_target_value(-1000).wait() + assert adj.get_current_value() == -1000 + + from slic.core.adjustable.limited import OutsideLimits + with pytest.raises(OutsideLimits): + adj.set_target_value(150) + + +def test_limited_reversed_limits(): + adj = DummyAdjustable(name="Test", ID="test_id") + adj.set_limits(low=100, high=10) + + adj.set_target_value(50).wait() + assert adj.get_current_value() == 50 diff --git a/tests/test_core_devices.py b/tests/test_core_devices.py new file mode 100644 index 000000000..ca9eddcd6 --- /dev/null +++ b/tests/test_core_devices.py @@ -0,0 +1,715 @@ +import pytest +import warnings +from types import SimpleNamespace + +from slic.core.device import Device, SimpleDevice +from slic.core.device.device import decide_z, read_z_from_channel, recursive_adjustables +from slic.core.device.filtered import by_type, by_name, filtered +from slic.core.device.auto import auto +from slic.core.adjustable import Adjustable +from slic.core.adjustable.dummyadjustable import DummyAdjustable + + +# Device + +@pytest.mark.parametrize("device_id,expected_name", [ + ("DEV001", "DEV001"), + ("MOTOR_X", "MOTOR_X"), + ("dev123", "dev123"), + ("X", "X"), + ("A_B_C_D", "A_B_C_D"), +]) +def test_device_init_name_defaults_to_id(device_id, expected_name): + dev = Device(ID=device_id) + assert dev.ID == device_id + assert dev.name == expected_name + assert dev.description is None + + +@pytest.mark.parametrize("device_id,custom_name,description", [ + ("DEV001", "Motor X", "X-axis motor controller"), + ("SENS002", "Temperature Sensor", "Beamline temp sensor"), + ("UND100", "Undulator", "Main undulator device"), + ("CAM01", "Camera 1", None), + ("STAGE", "XY Stage", ""), +]) +def test_device_init_with_custom_attributes(device_id, custom_name, description): + dev = Device(ID=device_id, name=custom_name, description=description) + assert dev.ID == device_id + assert dev.name == custom_name + assert dev.description == description + + +@pytest.mark.parametrize("device_id,z_value", [ + ("UND100", 100), + ("DEV999", 999), + ("MOTOR123", 123), + ("UND050:EXTRA", 50), + ("ABC250", 250), +]) +def test_device_z_undulator_from_id(device_id, z_value): + dev = Device(ID=device_id) + assert dev.z_undulator == z_value + + +@pytest.mark.parametrize("device_id,explicit_z", [ + ("ANYID", 42), + ("NOZINID", 0), + ("TEST", 999), + ("ABC", -10), + ("XYZ", 12345), +]) +def test_device_z_undulator_explicit(device_id, explicit_z): + dev = Device(ID=device_id, z_undulator=explicit_z) + assert dev.z_undulator == explicit_z + + +def test_device_with_adjustables(): + dev = Device(ID="TESTDEV", name="Test Device") + dev.motor_x = DummyAdjustable(ID="mx", name="Motor X") + dev.motor_y = DummyAdjustable(ID="my", name="Motor Y") + + assert hasattr(dev, "motor_x") + assert hasattr(dev, "motor_y") + assert isinstance(dev.motor_x, Adjustable) + assert isinstance(dev.motor_y, Adjustable) + + +def test_device_iteration(): + dev = Device(ID="TESTDEV") + dev.z_motor = DummyAdjustable(ID="z", name="Z") + dev.a_motor = DummyAdjustable(ID="a", name="A") + dev.m_motor = DummyAdjustable(ID="m", name="M") + + items = list(dev) + + assert len(items) == 3 + assert items[0].name == "A" + assert items[1].name == "M" + assert items[2].name == "Z" + + +def test_device_repr_with_adjustables(): + dev = Device(ID="TESTDEV", name="Test Device", description="A test device") + dev.motor = DummyAdjustable(ID="m1", name="Motor1", initial_value=42) + + repr_str = repr(dev) + + assert "test device" in repr_str.lower() + assert "motor" in repr_str + + +def test_device_repr_uses_description_then_name_then_id(): + dev1 = Device(ID="ID1", name="Name1", description="Description1") + repr1 = repr(dev1) + assert "Description1" in repr1 + + dev2 = Device(ID="ID2", name="Name2") + repr2 = repr(dev2) + assert "Name2" in repr2 + + dev3 = Device(ID="ID3") + repr3 = repr(dev3) + assert "ID3" in repr3 + + +def test_device_nested(): + parent = Device(ID="PARENT") + child = Device(ID="CHILD") + child.motor = DummyAdjustable(ID="m1", name="Motor") + parent.subsystem = child + + adjs = recursive_adjustables(parent) + assert "subsystem.motor" in adjs + assert adjs["subsystem.motor"].name == "Motor" + + +def test_device_recursive_detection(): + dev1 = Device(ID="DEV1") + dev2 = Device(ID="DEV2") + + dev1.sub = dev2 + dev2.parent = dev1 + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + adjs = recursive_adjustables(dev1) + + assert len(w) >= 1 + assert "Recursive Device" in str(w[0].message) + + +# Helper Functions + +@pytest.mark.parametrize("channel_id,expected_z", [ + ("UND100", 100), + ("DEV999", 999), + ("MOTOR123", 123), + ("UND001", 1), + ("TEST000", 0), + ("XYZ250", 250), +]) +def test_read_z_from_channel_valid(channel_id, expected_z): + assert read_z_from_channel(channel_id) == expected_z + + +@pytest.mark.parametrize("channel_id", [ + "NODIGITS", + "ABC", + "TEST", + "XYZ", + "MOTOR_X", + "AB1", + "X12", +]) +def test_read_z_from_channel_invalid(channel_id): + assert read_z_from_channel(channel_id) is None + + +@pytest.mark.parametrize("channel_id,colon_part", [ + ("MAIN:SUB:123", "123"), + ("CH:100", "100"), + ("A:B:C:456", "456"), + ("PREFIX:789:SUFFIX", "789"), +]) +def test_read_z_from_channel_with_colons(channel_id, colon_part): + z = read_z_from_channel(channel_id) + assert z is not None or colon_part[-3:].isdigit() + + +@pytest.mark.parametrize("channel_id,explicit_z,expected", [ + ("UND100", None, 100), + ("UND100", 200, 200), + ("NODIGITS", None, None), + ("NODIGITS", 42, 42), + ("TEST123", 0, 0), +]) +def test_decide_z(channel_id, explicit_z, expected): + assert decide_z(channel_id, explicit_z) == expected + + +# recursive_adjustables + +def test_recursive_adjustables_flat(): + dev = Device(ID="FLAT") + dev.adj1 = DummyAdjustable(ID="a1", name="Adj1") + dev.adj2 = DummyAdjustable(ID="a2", name="Adj2") + + adjs = recursive_adjustables(dev) + + assert len(adjs) == 2 + assert "adj1" in adjs + assert "adj2" in adjs + + +def test_recursive_adjustables_nested(): + parent = Device(ID="PARENT") + child1 = Device(ID="CHILD1") + child2 = Device(ID="CHILD2") + + parent.adj_top = DummyAdjustable(ID="top", name="Top") + child1.adj_c1 = DummyAdjustable(ID="c1", name="Child1") + child2.adj_c2 = DummyAdjustable(ID="c2", name="Child2") + + parent.subsys1 = child1 + parent.subsys2 = child2 + + adjs = recursive_adjustables(parent) + + assert len(adjs) == 3 + assert "adj_top" in adjs + assert "subsys1.adj_c1" in adjs + assert "subsys2.adj_c2" in adjs + + +def test_recursive_adjustables_deep_nesting(): + dev1 = Device(ID="L1") + dev2 = Device(ID="L2") + dev3 = Device(ID="L3") + + dev3.adj = DummyAdjustable(ID="deep", name="Deep") + dev2.level3 = dev3 + dev1.level2 = dev2 + + adjs = recursive_adjustables(dev1) + + assert "level2.level3.adj" in adjs + + +def test_recursive_adjustables_ignores_non_adjustables(): + dev = Device(ID="TEST") + dev.adjustable = DummyAdjustable(ID="adj", name="Adj") + dev.some_string = "not an adjustable" + dev.some_number = 42 + dev.some_list = [1, 2, 3] + + adjs = recursive_adjustables(dev) + + assert len(adjs) == 1 + assert "adjustable" in adjs + assert "some_string" not in adjs + assert "some_number" not in adjs + + +# SimpleDevice + +@pytest.mark.parametrize("device_id,kwargs", [ + ("SIMPLE1", {"x": 10, "y": 20}), + ("SIMPLE2", {"value": 42}), + ("SIMPLE3", {"a": 1, "b": 2, "c": 3}), + ("SIMPLE4", {}), + ("SIMPLE5", {"name_attr": "test", "count": 100}), +]) +def test_simpledevice_with_kwargs(device_id, kwargs): + dev = SimpleDevice(ID=device_id, **kwargs) + + assert dev.ID == device_id + for key, value in kwargs.items(): + assert hasattr(dev, key) + assert getattr(dev, key) == value + + +def test_simpledevice_inherits_device(): + dev = SimpleDevice(ID="SIMPLE", name="Simple Device") + dev.motor = DummyAdjustable(ID="m1", name="Motor") + + assert dev.name == "Simple Device", f"BUG: Expected 'Simple Device', got '{dev.name}'" + assert hasattr(dev, "motor") + + items = list(dev) + assert len(items) == 1 + + +def test_simpledevice_namespace_behavior(): + dev = SimpleDevice(ID="NS", x=10, y=20, z=30) + + assert dev.x == 10 + assert dev.y == 20 + assert dev.z == 30 + + +def test_simpledevice_mixed_adjustables_and_data(): + dev = SimpleDevice( + ID="MIXED", + name="Mixed Device", + calibration=1.5, + offset=10 + ) + dev.motor = DummyAdjustable(ID="m1", name="Motor") + + assert dev.calibration == 1.5 + assert dev.offset == 10 + + assert hasattr(dev, "motor") + assert isinstance(dev.motor, Adjustable) + + +def test_simpledevice_bug_name_and_description_ignored(): + dev = SimpleDevice( + ID="TESTID", + name="Custom Name Should Be Used", + description="Custom Description Should Be Used", + z_undulator=42 + ) + + assert dev.name == "Custom Name Should Be Used", \ + f"BUG: name parameter ignored! Expected 'Custom Name Should Be Used', got '{dev.name}'" + assert dev.description == "Custom Description Should Be Used", \ + f"BUG: description parameter ignored! Expected 'Custom Description Should Be Used', got '{dev.description}'" + assert dev.z_undulator == 42, \ + f"BUG: z_undulator parameter ignored! Expected 42, got {dev.z_undulator}" + + +def test_simpledevice_bug_cascade_effect(): + fake_globals = {"adj": DummyAdjustable(ID="a1", name="Adj")} + auto_dev = auto(fake_globals, "AUTO_ID", name="Auto Name", description="Auto Desc") + + assert auto_dev.name == "Auto Name", \ + f"BUG CASCADE in auto(): Expected 'Auto Name', got '{auto_dev.name}'" + assert auto_dev.description == "Auto Desc", \ + f"BUG CASCADE in auto(): Expected 'Auto Desc', got {auto_dev.description}" + + orig_dev = Device(ID="ORIG_ID", name="Original Name", description="Original Desc") + orig_dev.motor = DummyAdjustable(ID="m1", name="Motor") + filtered_dev = by_name(orig_dev, "motor") + + assert filtered_dev.name == "Original Name", \ + f"BUG CASCADE in filtered(): Expected 'Original Name', got '{filtered_dev.name}'" + assert filtered_dev.description == "Original Desc", \ + f"BUG CASCADE in filtered(): Expected 'Original Desc', got {filtered_dev.description}" + + +# Filtering + +def test_filtered_by_type_single(): + dev = Device(ID="TEST") + dev.dummy1 = DummyAdjustable(ID="d1", name="Dummy1") + dev.dummy2 = DummyAdjustable(ID="d2", name="Dummy2") + + filtered_dev = by_type(dev, DummyAdjustable) + + assert filtered_dev is not None + assert hasattr(filtered_dev, "dummy1") + assert hasattr(filtered_dev, "dummy2") + + +def test_filtered_by_type_excludes_other_types(): + from slic.core.adjustable.genericadjustable import GenericAdjustable + + dev = Device(ID="TEST") + dev.dummy = DummyAdjustable(ID="d1", name="Dummy") + dev.generic = GenericAdjustable( + ID="g1", + get=lambda: 0, + set=lambda x: None, + name="Generic" + ) + + filtered_dev = by_type(dev, DummyAdjustable) + + assert filtered_dev is not None + assert hasattr(filtered_dev, "dummy") + assert not hasattr(filtered_dev, "generic") + + +def test_filtered_by_type_multiple_types(): + from slic.core.adjustable.genericadjustable import GenericAdjustable + + dev = Device(ID="TEST") + dev.dummy = DummyAdjustable(ID="d1", name="Dummy") + dev.generic = GenericAdjustable( + ID="g1", + get=lambda: 0, + set=lambda x: None, + name="Generic" + ) + + filtered_dev = by_type(dev, (DummyAdjustable, GenericAdjustable)) + + assert filtered_dev is not None + assert hasattr(filtered_dev, "dummy") + assert hasattr(filtered_dev, "generic") + + +@pytest.mark.parametrize("pattern,expected_attrs", [ + ("motor", ["motor_x", "motor_y"]), + ("x", ["motor_x", "pos_x"]), + ("y", ["motor_y", "pos_y"]), + ("pos", ["pos_x", "pos_y"]), + ("sensor", ["temp_sensor"]), +]) +def test_filtered_by_name(pattern, expected_attrs): + dev = Device(ID="TEST") + dev.motor_x = DummyAdjustable(ID="mx", name="MotorX") + dev.motor_y = DummyAdjustable(ID="my", name="MotorY") + dev.pos_x = DummyAdjustable(ID="px", name="PosX") + dev.pos_y = DummyAdjustable(ID="py", name="PosY") + dev.temp_sensor = DummyAdjustable(ID="ts", name="TempSensor") + + filtered_dev = by_name(dev, pattern) + + assert filtered_dev is not None + for attr in expected_attrs: + assert hasattr(filtered_dev, attr) + + +def test_filtered_by_name_matches_adjustable_name(): + dev = Device(ID="TEST") + dev.attr1 = DummyAdjustable(ID="a1", name="SpecialMotor") + dev.attr2 = DummyAdjustable(ID="a2", name="NormalSensor") + + filtered_dev = by_name(dev, "Motor") + + assert filtered_dev is not None + assert hasattr(filtered_dev, "attr1") + assert not hasattr(filtered_dev, "attr2") + + +def test_filtered_custom_condition(): + dev = Device(ID="TEST") + dev.small = DummyAdjustable(ID="s1", name="Small", initial_value=5) + dev.large = DummyAdjustable(ID="l1", name="Large", initial_value=100) + + def condition(k, v): + return v.get_current_value() > 50 + + filtered_dev = filtered(dev, condition) + + assert filtered_dev is not None + assert not hasattr(filtered_dev, "small") + assert hasattr(filtered_dev, "large") + + +def test_filtered_nested_devices(): + parent = Device(ID="PARENT") + child = Device(ID="CHILD") + + parent.adj1 = DummyAdjustable(ID="a1", name="Motor1") + child.adj2 = DummyAdjustable(ID="a2", name="Motor2") + child.adj3 = DummyAdjustable(ID="a3", name="Sensor1") + parent.subsystem = child + + filtered_dev = by_name(parent, "Motor") + + assert filtered_dev is not None + assert hasattr(filtered_dev, "adj1") + assert hasattr(filtered_dev, "subsystem") + assert hasattr(filtered_dev.subsystem, "adj2") + assert not hasattr(filtered_dev.subsystem, "adj3") + + +def test_filtered_prunes_empty_subdevices(): + parent = Device(ID="PARENT") + child = Device(ID="CHILD") + + parent.motor = DummyAdjustable(ID="m1", name="Motor") + child.sensor = DummyAdjustable(ID="s1", name="Sensor") + parent.subsystem = child + + filtered_dev = by_name(parent, "Motor") + + assert filtered_dev is not None + assert hasattr(filtered_dev, "motor") + assert not hasattr(filtered_dev, "subsystem") + + +def test_filtered_returns_none_when_empty(): + dev = Device(ID="TEST") + dev.motor = DummyAdjustable(ID="m1", name="Motor") + + filtered_dev = by_name(dev, "NOMATCH") + + assert filtered_dev is None + + +def test_filtered_preserves_device_metadata(): + dev = Device(ID="ORIGINAL", name="Original Name", description="Original Desc") + dev.motor = DummyAdjustable(ID="m1", name="Motor") + + filtered_dev = by_name(dev, "motor") + + assert filtered_dev.ID == "ORIGINAL" + assert filtered_dev.name == "Original Name", \ + f"BUG: Expected 'Original Name', got '{filtered_dev.name}'" + assert filtered_dev.description == "Original Desc", \ + f"BUG: Expected 'Original Desc', got {filtered_dev.description}" + + +# Auto Device + +def test_auto_creates_device_from_globals(): + fake_globals = { + "motor_x": DummyAdjustable(ID="mx", name="MotorX"), + "motor_y": DummyAdjustable(ID="my", name="MotorY"), + "sensor": DummyAdjustable(ID="s1", name="Sensor"), + "_private": DummyAdjustable(ID="p1", name="Private"), + "not_adjustable": 42, + } + + dev = auto(fake_globals, "AUTO_DEV") + + assert dev.ID == "AUTO_DEV" + assert hasattr(dev, "motor_x") + assert hasattr(dev, "motor_y") + assert hasattr(dev, "sensor") + assert not hasattr(dev, "_private") + assert not hasattr(dev, "not_adjustable") + + +def test_auto_ignores_private_variables(): + fake_globals = { + "public": DummyAdjustable(ID="pub", name="Public"), + "_private": DummyAdjustable(ID="priv", name="Private"), + "__dunder": DummyAdjustable(ID="dun", name="Dunder"), + } + + dev = auto(fake_globals, "TEST") + + assert hasattr(dev, "public") + assert not hasattr(dev, "_private") + assert not hasattr(dev, "__dunder") + + +def test_auto_includes_devices(): + fake_globals = { + "adjustable": DummyAdjustable(ID="adj", name="Adj"), + "device": Device(ID="dev", name="Dev"), + } + + dev = auto(fake_globals, "AUTO") + + assert hasattr(dev, "adjustable") + assert hasattr(dev, "device") + assert isinstance(dev.adjustable, Adjustable) + assert isinstance(dev.device, Device) + + +def test_auto_with_kwargs(): + fake_globals = { + "motor": DummyAdjustable(ID="m1", name="Motor"), + } + + dev = auto(fake_globals, "AUTO", name="Custom Name", description="Custom Desc") + + assert dev.name == "Custom Name", \ + f"BUG: Expected 'Custom Name', got '{dev.name}'" + assert dev.description == "Custom Desc", \ + f"BUG: Expected 'Custom Desc', got {dev.description}" + assert hasattr(dev, "motor") + + +def test_auto_empty_globals(): + fake_globals = { + "some_string": "text", + "some_number": 42, + "_hidden": DummyAdjustable(ID="h1", name="Hidden"), + } + + dev = auto(fake_globals, "EMPTY") + + assert dev.ID == "EMPTY" + assert not hasattr(dev, "some_string") + assert not hasattr(dev, "some_number") + + +@pytest.mark.parametrize("prefix", ["_", "__", "___"]) +def test_auto_ignores_various_underscore_prefixes(prefix): + fake_globals = { + f"{prefix}var": DummyAdjustable(ID="v1", name="Var"), + "public": DummyAdjustable(ID="p1", name="Public"), + } + + dev = auto(fake_globals, "TEST") + + assert hasattr(dev, "public") + assert not hasattr(dev, f"{prefix}var") + + +# Integration +# Tests full hierarchy: root Device with sub-devices and adjustables +# Verifies recursive_adjustables finds all nested adjustables +# and by_name filters correctly through the hierarchy + +def test_full_device_hierarchy(): + root = Device(ID="ROOT", name="Root System") + subsys1 = Device(ID="SUB1", name="Subsystem 1") + subsys2 = Device(ID="SUB2", name="Subsystem 2") + + root.motor = DummyAdjustable(ID="m0", name="RootMotor") + subsys1.motor_x = DummyAdjustable(ID="mx", name="SubMotorX") + subsys1.sensor = DummyAdjustable(ID="s1", name="SubSensor") + subsys2.motor_y = DummyAdjustable(ID="my", name="SubMotorY") + + root.system1 = subsys1 + root.system2 = subsys2 + + all_adjs = recursive_adjustables(root) + assert len(all_adjs) == 4 + assert "motor" in all_adjs + assert "system1.motor_x" in all_adjs + assert "system1.sensor" in all_adjs + assert "system2.motor_y" in all_adjs + + motor_only = by_name(root, "motor") + assert motor_only is not None + motor_adjs = recursive_adjustables(motor_only) + assert len(motor_adjs) == 3 + + dummy_only = by_type(root, DummyAdjustable) + assert dummy_only is not None + + +# Tests that Device can contain adjustables, dicts, ints +# Only Adjustables are returned by recursive_adjustables + +def test_device_with_mixed_content(): + dev = Device(ID="MIXED") + dev.motor = DummyAdjustable(ID="m1", name="Motor") + dev.config = {"key": "value"} + dev.counter = 42 + sub = Device(ID="SUB") + sub.sensor = DummyAdjustable(ID="s1", name="Sensor") + dev.subsystem = sub + + adjs = recursive_adjustables(dev) + assert len(adjs) == 2 + assert "motor" in adjs + assert "subsystem.sensor" in adjs + assert "config" not in adjs + assert "counter" not in adjs + + +# Tests full chain: auto() creates SimpleDevice from globals +# then by_name filters adjustables by pattern + +def test_simpledevice_with_auto(): + fake_globals = { + "motor_x": DummyAdjustable(ID="mx", name="MotorX"), + "motor_y": DummyAdjustable(ID="my", name="MotorY"), + "sensor_temp": DummyAdjustable(ID="st", name="TempSensor"), + "sensor_pressure": DummyAdjustable(ID="sp", name="PressureSensor"), + } + + dev = auto(fake_globals, "AUTO", name="Auto Device") + + assert dev.name == "Auto Device", \ + f"BUG: Expected 'Auto Device', got '{dev.name}'" + assert hasattr(dev, "motor_x") + assert hasattr(dev, "sensor_temp") + + motors = by_name(dev, "motor") + assert motors is not None + assert hasattr(motors, "motor_x") + assert hasattr(motors, "motor_y") + assert not hasattr(motors, "sensor_temp") + + sensors = by_name(dev, "sensor") + assert sensors is not None + assert hasattr(sensors, "sensor_temp") + assert hasattr(sensors, "sensor_pressure") + assert not hasattr(sensors, "motor_x") + + +# Tests iteration on Device with sub-devices +# Returns all adjustables sorted by name + +def test_device_iteration_with_nested_structure(): + parent = Device(ID="PARENT") + child1 = Device(ID="CHILD1") + child2 = Device(ID="CHILD2") + + parent.adj_a = DummyAdjustable(ID="a", name="A") + child1.adj_b = DummyAdjustable(ID="b", name="B") + child2.adj_c = DummyAdjustable(ID="c", name="C") + + parent.sub1 = child1 + parent.sub2 = child2 + + items = list(parent) + assert len(items) == 3 + + names = [item.name for item in items] + assert names == ["A", "B", "C"] + + +@pytest.mark.parametrize("num_adjustables,num_devices", [ + (5, 0), + (0, 3), + (3, 2), + (10, 5), + (1, 1), +]) +def test_device_scalability(num_adjustables, num_devices): + dev = Device(ID="SCALE_TEST") + + for i in range(num_adjustables): + setattr(dev, f"adj_{i}", DummyAdjustable(ID=f"a{i}", name=f"Adj{i}")) + + for i in range(num_devices): + sub = Device(ID=f"SUB{i}") + sub.motor = DummyAdjustable(ID=f"m{i}", name=f"Motor{i}") + setattr(dev, f"sub_{i}", sub) + + adjs = recursive_adjustables(dev) + expected_total = num_adjustables + num_devices + assert len(adjs) == expected_total diff --git a/tests/test_runname.py b/tests/test_runname.py new file mode 100644 index 000000000..a6ff6a2ad --- /dev/null +++ b/tests/test_runname.py @@ -0,0 +1,276 @@ +import pytest +from pathlib import Path +from slic.core.scanner.runname import ( + extract_runnumber, + extract_runnumbers, + RunFilenameGenerator, + EVERYTHING +) + + +@pytest.fixture +def tmpdir_runs(tmp_path): + d = tmp_path / "runs" + d.mkdir() + return d + + +# extract_runnumber + +@pytest.mark.parametrize( + "fname,prefix,separator,expected", + [ + ("scan0004_test.json", "scan", "_", 4), + ("scan0001_alpha.txt", "scan", "_", 1), + ("run-0042_demo.csv", "run-", "_", 42), + ("data9999_final.txt", "data", "_", 9999), + ("scan5_test.json", "scan", "_", 5), + (str(Path("/tmp/data/scan0007_exp.json")), "scan", "_", 7), + (str(Path("/Users/yasmine_tligui/test_pv/runs/scan0123_test.json")), "scan", "_", 123), + ("scan0010_report.csv", "scan", "_", 10), + ("scan0300_analysis.yaml", "scan", "_", 300), + ("experimentscan0070_trial.txt", "experimentscan", "_", 70), + ("meas:001_test.txt", "meas:", "_", 1), + (str(Path("a/b/c/scan0042_demo.json")), "scan", "_", 42), + ("scan1234_extra_part.json", "scan", "_", 1234), + ("scan_0042_test.json", "scan_", "_", 42), + ], +) +def test_extract_runnumber_valid_cases(fname, prefix, separator, expected): + assert extract_runnumber(fname, prefix, separator) == expected + + +@pytest.mark.parametrize( + "fname,prefix,separator", + [ + ("scan00A3_test.json", "scan", "_"), + ("file_test.json", "scan", "_"), + ("scan_test.json", "scan", "_"), + ("run#_info.txt", "run#", "#"), + ("1234_scan_test.json", "scan", "_"), + ], +) +def test_extract_runnumber_invalid_cases(fname, prefix, separator): + with pytest.raises(ValueError): + extract_runnumber(fname, prefix, separator) + + +# extract_runnumbers + +@pytest.mark.parametrize( + "fnames,prefix,separator,expected", + [ + (["scan0001_a", "scan0002_b", "scan0030_c"], "scan", "_", [1, 2, 30]), + (["scan1_a", "scan02_b", "scan300_c"], "scan", "_", [1, 2, 300]), + (["run-0004_x", "run-0005_y", "run-0010_z"], "run-", "_", [4, 5, 10]), + (["data9_file", "data10_file", "data11_file"], "data", "_", [9, 10, 11]), + (["scan001--a", "scan002--b", "scan003--c"], "scan", "--", [1, 2, 3]), + ], +) +def test_extract_runnumbers_valid(fnames, prefix, separator, expected): + result = extract_runnumbers(fnames, prefix, separator) + assert result == expected + + +# RunFilenameGenerator + +@pytest.mark.parametrize( + "prefix,run_index,separator,name,suffix,expected", + [ + ("pfx", "7", "-", "abc", ".txt", "pfx7-abc"), + ("scan", "001", "_", "data", "_scan_info.json", "scan001_data"), + ("run", "42", "-", "test", ".meta", "run42-test"), + ("exp", "003", "--", "trial", ".log", "exp003--trial"), + ("data", "99", "::", "final", ".json", "data99::final"), + ("experiment_", "005", "_", "measure", ".csv", "experiment_005_measure"), + ], +) +def test_fill_filename_pattern(prefix, run_index, separator, name, suffix, expected): + gen = RunFilenameGenerator( + base_dir=".", + prefix=prefix, + separator=separator, + suffix=suffix + ) + formatted = gen._fill_filename_pattern(run_index, name) + assert formatted == expected + + +@pytest.mark.parametrize( + "prefix,n_digits,separator,suffix,expected", + [ + ("scan", 3, "_", "_scan_info.json", "scan" + "[0-9]"*3 + "_" + EVERYTHING + "_scan_info.json"), + ("run-", 4, "-", ".json", "run-" + "[0-9]"*4 + "-" + EVERYTHING + ".json"), + ("data", 2, "#", ".meta", "data" + "[0-9]"*2 + "#" + EVERYTHING + ".meta"), + ("exp", 1, ".", "_info.txt", "exp" + "[0-9]"*1 + "." + EVERYTHING + "_info.txt"), + ], +) +def test_pattern_exact_match(tmpdir_runs, prefix, n_digits, separator, suffix, expected): + gen = RunFilenameGenerator( + base_dir=tmpdir_runs, + prefix=prefix, + n_digits=n_digits, + separator=separator, + suffix=suffix, + ) + pattern = gen.pattern + assert pattern == expected + + +def test_get_existing_runnumbers_empty(monkeypatch, tmpdir_runs): + gen = RunFilenameGenerator(tmpdir_runs) + monkeypatch.setattr("slic.core.scanner.runname.glob_files", lambda b, p: []) + assert gen.get_existing_runnumbers() == [] + + +@pytest.mark.parametrize( + "filenames,prefix,separator,suffix,n_digits,expected_runnumbers", + [ + ( + ["scan0001_test_scan_info.json", "scan0002_demo_scan_info.json", "scan0005_exp_scan_info.json"], + "scan", "_", "_scan_info.json", 4, + [1, 2, 5] + ), + ( + ["run-0010_alpha_data.json", "run-0020_beta_data.json"], + "run-", "_", "_data.json", 4, + [10, 20] + ), + ( + ["data01_test.meta", "data02_test.meta", "data99_test.meta"], + "data", "_", ".meta", 2, + [1, 2, 99] + ), + ], +) +def test_get_existing_runnumbers_with_files(tmpdir_runs, filenames, prefix, separator, suffix, n_digits, expected_runnumbers): + for fname in filenames: + (tmpdir_runs / fname).write_text("test content") + + gen = RunFilenameGenerator( + base_dir=tmpdir_runs, + prefix=prefix, + separator=separator, + suffix=suffix, + n_digits=n_digits, + ) + runnums = gen.get_existing_runnumbers() + assert sorted(runnums) == sorted(expected_runnumbers) + + +def test_get_existing_runnumbers_mixed_files(tmpdir_runs): + files = [ + "scan0001_test_scan_info.json", + "scan0002_demo_scan_info.json", + "other_file.json", + "random_data.txt", + "scan_without_number.json", + ] + for fname in files: + (tmpdir_runs / fname).write_text("test") + + gen = RunFilenameGenerator(tmpdir_runs) + runnums = gen.get_existing_runnumbers() + assert sorted(runnums) == [1, 2] + + +def test_sequential_run_generation(tmpdir_runs): + gen = RunFilenameGenerator(tmpdir_runs) + + first = gen.get_next_run_filename("test") + assert first == "scan0000_test" + (tmpdir_runs / f"{first}_scan_info.json").write_text("data") + + second = gen.get_next_run_filename("demo") + assert second == "scan0001_demo" + (tmpdir_runs / f"{second}_scan_info.json").write_text("data") + + third = gen.get_next_run_filename("exp") + assert third == "scan0002_exp" + + +def test_get_next_run_filename_non_contiguous(tmpdir_runs): + files = ["scan0001_a_scan_info.json", "scan0005_b_scan_info.json", "scan0010_c_scan_info.json"] + for fname in files: + (tmpdir_runs / fname).write_text("test") + + gen = RunFilenameGenerator(tmpdir_runs) + next_file = gen.get_next_run_filename("test") + assert next_file == "scan0011_test" + + +@pytest.mark.parametrize( + "n_digits,expected_format", + [ + (1, "scan0_test"), + (2, "scan00_test"), + (3, "scan000_test"), + (5, "scan00000_test"), + (6, "scan000000_test"), + ], +) +def test_different_n_digits(tmpdir_runs, n_digits, expected_format): + gen = RunFilenameGenerator( + base_dir=tmpdir_runs, + n_digits=n_digits, + ) + next_file = gen.get_next_run_filename("test") + assert next_file == expected_format + + +@pytest.mark.parametrize( + "file_structure,search_base_dir,prefix,separator,suffix,n_digits,expected_next", + [ + ( + { + "experiments": ["scan0001_test_scan_info.json", "scan0002_test_scan_info.json"], + "other_folder": ["scan0003_test_scan_info.json", "scan0004_test_scan_info.json"] + }, + "experiments", + "scan", "_", "_scan_info.json", 4, + "scan0003_test" + ), + ( + { + "project/data": ["scan0001_test_scan_info.json"], + "project/data/raw": ["scan0002_test_scan_info.json"], + "project/backup": ["scan0003_test_scan_info.json"], + "other_project": ["scan0004_test_scan_info.json"] + }, + "project/data", + "scan", "_", "_scan_info.json", 4, + "scan0002_test" + ), + ( + { + "empty_dir": [], + "full_dir": ["scan0001_test_scan_info.json", "scan0002_test_scan_info.json"], + }, + "empty_dir", + "scan", "_", "_scan_info.json", 4, + "scan0000_test" + ), + ], +) +def test_get_next_run_filename(tmpdir, file_structure, search_base_dir, prefix, separator, suffix, n_digits, expected_next): + for dir_path, filenames in file_structure.items(): + current_dir = tmpdir + for part in dir_path.split("/"): + current_dir = current_dir / part + if not current_dir.exists(): + current_dir.mkdir() + + for filename in filenames: + (current_dir / filename).write("test") + + actual_base_dir = tmpdir / search_base_dir + + gen = RunFilenameGenerator( + base_dir=str(actual_base_dir), + prefix=prefix, + n_digits=n_digits, + separator=separator, + suffix=suffix, + ) + next_file = gen.get_next_run_filename("test") + assert next_file == expected_next diff --git a/tests/test_scanbackend.py b/tests/test_scanbackend.py new file mode 100644 index 000000000..91c4e4c49 --- /dev/null +++ b/tests/test_scanbackend.py @@ -0,0 +1,1823 @@ +import pytest +from pathlib import Path +import os + +from slic.core.scanner.scanbackend import ( + ScanBackend, + is_sfdaq, is_only_sfdaq, + print_all_current_values, get_all_current_values, + set_all_target_values_and_wait, set_all_target_values, + wait_for_all, stop_all, +) + +from slic.core.acquisition import SFAcquisition +from slic.core.acquisition.fakeacquisition import FakeAcquisition as _BaseFakeAcquisition +from slic.core.adjustable.dummyadjustable import DummyAdjustable +from slic.core.task import DAQTask + + +# FakeAcquisition + +class FakeAcquisition(_BaseFakeAcquisition): + + def __init__(self, instrument, pgroup): + super().__init__(instrument, pgroup) + self.call_count = 0 + self.acquire_called = False + self.last_filename = None + self.last_data_base_dir = None + self.last_channels = None + self.last_n_pulses = None + self.last_wait = None + + def acquire(self, filename=None, data_base_dir=None, detectors=None, channels=None, pvs=None, scan_info=None, n_pulses=100, n_repeat=1, is_scan_step=False, wait=True): + self.call_count += 1 + self.acquire_called = True + self.last_filename = filename + self.last_data_base_dir = data_base_dir + self.last_channels = channels + self.last_n_pulses = n_pulses + self.last_wait = wait + return super().acquire(filename, data_base_dir, detectors, channels, pvs, scan_info, n_pulses, n_repeat, is_scan_step, wait) + + def reset(self): + self.call_count = 0 + self.acquire_called = False + self.last_filename = None + self.last_data_base_dir = None + self.last_channels = None + self.last_n_pulses = None + self.last_wait = None + + +# DummyCondition + +class DummyCondition: + def __init__(self, repeats=0): + self.repeats = repeats + self._stopped = False + def wants_repeat(self): + self.repeats -= 1 + return self.repeats >= 0 + def stop(self): + self._stopped = True + + +# DummySensor + +class DummySensor: + counter = 0 + def __init__(self, name=None): + DummySensor.counter += 1 + self.name = name or f"sensor_{DummySensor.counter}" + self.started = False + self.stopped = False + self._cache = {} + def start(self): + self.started = True + def stop(self): + self.stopped = True + def get(self): + return 3.14 + + +# DummyRemotePlot + +class DummyRemotePlot: + def __init__(self, fail=False): + self.fail = fail + self.created = False + self.appended = False + self.last_data = None + self.last_filename = None + + def new_plot(self, filename, cfg): + self.created = True + if self.fail: + raise ConnectionRefusedError + + def append_data(self, filename, data): + self.appended = True + self.last_data = data + self.last_filename = filename + if self.fail: + raise ConnectionRefusedError + + +# NonSFDAQAcquisition + +class NonSFDAQAcquisition: + + def __init__(self, name="NonSFDAQ", default_dir=None): + self.name = name + self.default_dir = default_dir + self.call_count = 0 + self.last_filename = None + self.last_n_pulses = None + self.filenames = [] + + def acquire(self, filename=None, n_pulses=100, **_kwargs): + self.call_count += 1 + self.last_filename = filename + self.last_n_pulses = n_pulses + + def fake_acquire_func(): + return [f"{filename}_nonsfdaq_{self.call_count}.h5"] + + task = DAQTask(fake_acquire_func) + self.filenames.extend([f"{filename}_nonsfdaq_{self.call_count}.h5"]) + return task + + def __repr__(self): + return f"NonSFDAQAcquisition({self.name})" + + +# is_sfdaq + +def test_is_sfdaq_and_only_sfdaq(): + class MockConfig: + pgroup = None + + class MockClient: + config = MockConfig() + + class MockSFAcquisition(SFAcquisition): + def __init__(self, instrument, pgroup): + self.client = MockClient() + self.instrument = instrument + self._pgroup = pgroup + + s1, s2 = MockSFAcquisition("test_instrument", "test_pgroup"), MockSFAcquisition("test_instrument", "test_pgroup") + f1 = FakeAcquisition("test_instrument", "test_pgroup") + random_obj = object() + + assert is_sfdaq(s1) + assert is_sfdaq(f1) + assert not is_sfdaq(random_obj) + + assert is_only_sfdaq([s1, s2]) + assert is_only_sfdaq([f1, s1]) + assert not is_only_sfdaq([s1, random_obj]) + + +# get_filename + +def test_get_filename(tmp_path): + adjs = [DummyAdjustable(ID="AX", name="AX")] + acqs = [FakeAcquisition("test_instrument", "test_pgroup")] + + sb1 = ScanBackend( + adjs, [[1]], acqs, "scanfile", + detectors=[], channels=[], pvs=[], + n_pulses=1, data_base_dir="data", scan_info_dir=tmp_path, + make_scan_sub_dir=False, condition=None, + return_to_initial_values=True, n_repeat=1, + sensor=None, remote_plot=None + ) + f1 = sb1.get_filename(7) + assert f1.endswith("scanfile_step0007") + assert os.path.basename(f1).startswith("scanfile") + + sb2 = ScanBackend( + adjs, [[1]], acqs, "scanfile", + detectors=[], channels=[], pvs=[], + n_pulses=1, data_base_dir="data", scan_info_dir=tmp_path, + make_scan_sub_dir=True, condition=None, + return_to_initial_values=True, n_repeat=1, + sensor=None, remote_plot=None + ) + f2 = sb2.get_filename(3) + expected_sub = os.path.join("scanfile", "scanfile_step0003") + assert f2.endswith(expected_sub) + + sb3 = ScanBackend( + adjs, [[1]], acqs, "/tmp/path/to/custom_name", + detectors=[], channels=[], pvs=[], + n_pulses=1, data_base_dir="data", scan_info_dir=tmp_path, + make_scan_sub_dir=True, condition=None, + return_to_initial_values=True, n_repeat=1, + sensor=None, remote_plot=None + ) + f3 = sb3.get_filename(1) + assert f3.endswith(os.path.join("custom_name", "custom_name_step0001")) + + +# create_output_dirs + +def test_create_output_dirs(tmp_path): + class MockConfig: + pgroup = None + + class MockClient: + config = MockConfig() + + class MockSFAcquisition(SFAcquisition): + def __init__(self, instrument, pgroup): + self.client = MockClient() + self.instrument = instrument + self._pgroup = pgroup + + adjs = [DummyAdjustable(ID="A1")] + + sfdaq_acq = MockSFAcquisition("test_instrument", "test_pgroup") + sb = ScanBackend( + adjs, [[1]], [sfdaq_acq], + filename="scan_sfdaq", + detectors=[], channels=[], pvs=[], + n_pulses=1, data_base_dir="data", scan_info_dir=tmp_path, + make_scan_sub_dir=True, condition=None, + return_to_initial_values=True, n_repeat=1, + sensor=None, remote_plot=None + ) + + sb.create_output_dirs() + for root, dirs, files in os.walk(tmp_path): + assert not dirs + + fake_acq = FakeAcquisition("test_instrument", "test_pgroup") + fake_acq.default_dir = str(tmp_path / "fake_default") + sb = ScanBackend( + adjs, [[1]], [fake_acq], + filename="scan_fake", + detectors=[], channels=[], pvs=[], + n_pulses=1, data_base_dir="data", scan_info_dir=tmp_path, + make_scan_sub_dir=False, condition=None, + return_to_initial_values=True, n_repeat=1, + sensor=None, remote_plot=None + ) + + sb.create_output_dirs() + expected_data_dir = os.path.join(fake_acq.default_dir, sb.data_base_dir) + assert not os.path.exists(expected_data_dir) + + +# store_initial_values + +def test_store_and_change_initial_values_restores_correctly(tmp_path): + adjs = [ + DummyAdjustable(ID="A", initial_value=9, process_time=0), + DummyAdjustable(ID="B", initial_value=8, process_time=0) + ] + + sb = ScanBackend( + adjs, [[1]], [FakeAcquisition("test_instrument", "test_pgroup")], + filename="fn", detectors=[], channels=[], pvs=[], + n_pulses=1, data_base_dir="data", scan_info_dir=tmp_path, + make_scan_sub_dir=True, condition=None, + return_to_initial_values=True, n_repeat=1, + sensor=None, remote_plot=None + ) + + sb.store_initial_values() + initial_values = [a.get_current_value() for a in adjs] + assert initial_values == [9, 8] + assert sb.initial_values == [9, 8] + + for a, new_val in zip(adjs, [100, 200]): + a.set_target_value(new_val) + + changed_values = [a.get_current_value() for a in adjs] + assert changed_values == [100, 200] + + sb.change_to_initial_values() + + restored_values = [a.get_current_value() for a in adjs] + assert restored_values == [9, 8] + + +# acquire_all + +def test_acquire_all_with_fake_acquisitions(tmp_path): + adjs = [DummyAdjustable(name="A", ID="A")] + + fake_acq1 = FakeAcquisition("test_instrument", "test_pgroup") + fake_acq2 = FakeAcquisition("test_instrument", "test_pgroup") + acqs = [fake_acq1, fake_acq2] + + sb = ScanBackend( + adjs, [[1]], acqs, "test_scan", + ["detector1"], ["bs_channel1"], ["pv1"], 3, + "data", tmp_path, True, None, True, 1, None, None + ) + + filenames = sb.acquire_all("test_filename") + + assert hasattr(sb, 'current_tasks') + assert len(sb.current_tasks) == 2 + + assert all(t.status == "done" for t in sb.current_tasks) + + for acq in acqs: + assert acq.acquire_called + assert "test_scan" in acq.last_filename or acq.last_filename == "test_scan" + assert acq.last_data_base_dir == "data" + assert acq.last_channels == ["bs_channel1"] + assert acq.last_n_pulses == 3 + assert acq.last_wait == False + + assert len(filenames) >= 2 + assert all(isinstance(fname, str) for fname in filenames) + assert all(len(fname) > 0 for fname in filenames) + + sb.stop() + assert not sb.running + for t in sb.current_tasks: + assert t.status == "done" + + +# do_step + +def test_do_step_with_fake_acquisitions(tmp_path): + adjs = [DummyAdjustable(name="motor1", ID="M1"), DummyAdjustable(name="motor2", ID="M2")] + + fake_acq = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1, 2], [3, 4]], [fake_acq], "test_scan", + [], ["bs_channel"], [], 1, + "data", tmp_path, True, None, True, 1, None, None + ) + + step_values = [1, 3] + n_step = 0 + + sb.do_step(n_step, step_values) + + current_values = [adj.get_current_value() for adj in adjs] + assert current_values == step_values + + assert len(sb.scan_info_sfdaq.values) > 0 + assert len(sb.scan_info_sfdaq.readbacks) > 0 + + assert fake_acq.last_filename == "test_scan" + + assert len(sb.scan_info.values) > 0 + assert len(sb.scan_info.readbacks) > 0 + + +def test_do_step_with_sensor_and_remote_plot(tmp_path): + adjs = [DummyAdjustable(name="motor", ID="M1")] + + fake_acq = FakeAcquisition("test_instrument", "test_pgroup") + dummy_sensor = DummySensor() + dummy_remote_plot = DummyRemotePlot() + + sb = ScanBackend( + adjs, [[1, 2, 3]], [fake_acq], "test_scan", + [], ["bs_channel"], [], 1, + "data", tmp_path, True, None, True, 1, dummy_sensor, dummy_remote_plot + ) + + step_values = [2] + n_step = 1 + + sb.do_step(n_step, step_values) + + assert dummy_sensor.started + assert dummy_sensor.stopped + + assert dummy_remote_plot.appended + + x_value = adjs[0].get_current_value() + y_value = dummy_sensor.get() + expected_data = (float(x_value), float(y_value)) + assert dummy_remote_plot.last_data == expected_data + + +def test_do_step_multiple_steps(tmp_path): + adjs = [DummyAdjustable(name="motor", ID="M1")] + + fake_acq = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1, 2, 3]], [fake_acq], "test_scan", + [], ["bs_channel"], [], 1, + "data", tmp_path, True, None, True, 1, None, None + ) + + step_sequences = [ + (0, [1]), + (1, [2]), + (2, [3]) + ] + + for n_step, step_values in step_sequences: + fake_acq.reset() + + sb.do_step(n_step, step_values) + + assert adjs[0].get_current_value() == step_values[0] + + assert "test_scan" in fake_acq.last_filename + + assert len(sb.scan_info.values) == n_step + 1 + + +# do_checked_step + +def test_do_checked_step_with_condition_repeats(tmp_path): + adjs = [DummyAdjustable(name="motor", ID="M1")] + fake_acq = FakeAcquisition("test_instrument", "test_pgroup") + + condition = DummyCondition(repeats=2) + + sb = ScanBackend( + adjs, [[1, 2]], [fake_acq], "test_scan", + [], ["bs_channel"], [], 1, + "data", tmp_path, True, condition, True, 1, None, None + ) + + step_values = [1] + n_step = 0 + + do_step_call_count = 0 + original_do_step = sb.do_step + + def mock_do_step(*args, **kwargs): + nonlocal do_step_call_count + do_step_call_count += 1 + return original_do_step(*args, **kwargs) + + sb.do_step = mock_do_step + sb.running = True + + sb.do_checked_step(n_step, step_values) + + assert do_step_call_count == 2 + + assert condition.repeats == -1 + + +# _make_summary + +def test_make_summary_and_repr(tmp_path): + adjs = [DummyAdjustable(name="A", ID="A")] + sb = ScanBackend( + adjs, [[1, 2]], [FakeAcquisition("test_instrument", "test_pgroup")], + "fn", [], [], [], 2, + "data", tmp_path, True, None, True, 2, None, None + ) + s = sb._make_summary() + assert "record" in s and "pulse" in s + assert isinstance(repr(sb), str) + + +def test_make_summary_single_repeat(tmp_path): + adjs = [DummyAdjustable(name="motor1", ID="M1"), DummyAdjustable(name="motor2", ID="M2")] + fake_acq = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1, 2], [3, 4]], [fake_acq], "test_scan", + [], ["bs_channel"], [], 5, + "data", tmp_path, True, None, True, 1, + None, None + ) + + summary = sb._make_summary() + + assert "perform the following scan" in summary + + assert "motor1" in summary + assert "motor2" in summary + + assert "5 pulses" in summary + + assert "test_scan" in summary + + assert "FakeAcquisition" in summary + + +def test_make_summary_multiple_repeats(tmp_path): + adjs = [DummyAdjustable(name="motor", ID="M1")] + fake_acq = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1]], [fake_acq], "multi_scan", + [], ["bs_channel"], [], 1, + "data", tmp_path, True, None, True, 3, + None, None + ) + + summary = sb._make_summary() + + assert "repeat the following scan 3 times" in summary + + assert "1 pulse" in summary + + assert "multi_scan" in summary + + +# scan_loop + +def test_scan_loop_fake_only(tmp_path, capsys): + adjs = [ + DummyAdjustable(name="A", ID="A", initial_value=0), + DummyAdjustable(name="B", ID="B", initial_value=0), + ] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + values = [[1, 10], [2, 20], [3, 30]] + + sb = ScanBackend( + adjs, values, [fake], "scan1", + [], ["ch"], [], 1, + "data", tmp_path, False, + condition=None, return_to_initial_values=True, n_repeat=1, + sensor=None, remote_plot=None + ) + + sb.running = True + sb.scan_loop() + + out = capsys.readouterr().out + + assert "Scan step 1 of 3" in out + assert "Scan step 2 of 3" in out + assert "Scan step 3 of 3" in out + assert "All scan steps done" in out + + assert fake.call_count == 3 + + assert adjs[0].get_current_value() == 3 + assert adjs[1].get_current_value() == 30 + + +# repeated_scan_loop + +def test_repeated_scan_loop_fake_only(tmp_path, capsys): + adjs = [DummyAdjustable(name="A", ID="A")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + values = [[1], [2]] + + sb = ScanBackend( + adjs, values, [fake], "rscan", + [], ["ch"], [], 1, + "data", tmp_path, False, + condition=None, + return_to_initial_values=True, n_repeat=3, + sensor=None, remote_plot=None + ) + + sb.running = True + sb.repeated_scan_loop() + + out = capsys.readouterr().out + + assert "Repetition 1 of 3" in out + assert "Repetition 2 of 3" in out + assert "Repetition 3 of 3" in out + + assert fake.call_count == 6 + + assert sb.filename == "rscan" + + +# Full scan end-to-end +# Tests complete scan execution with multiple adjustables and steps +# Verifies all values, readbacks, filenames, and return to initial + +def test_full_multidimensional_scan_end_to_end(tmp_path): + d = 4 + n_steps = 5 + + adjs = [ + DummyAdjustable(name=f"M{i}", ID=f"ID{i}", initial_value=0, process_time=0) + for i in range(d) + ] + + values = [ + list(range(t, t + d)) + for t in range(1, n_steps + 1) + ] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + fake.default_dir = str(tmp_path / "fake_data") + + sb = ScanBackend( + adjs, values, [fake], + filename="multidim_test", + detectors=[], channels=["ch"], pvs=[], + n_pulses=3, + data_base_dir="data", + scan_info_dir=tmp_path, + make_scan_sub_dir=True, + condition=None, + return_to_initial_values=True, + n_repeat=1, + sensor=None, + remote_plot=None + ) + + sb.run() + + assert fake.call_count == n_steps + + assert [a.get_current_value() for a in adjs] == [0] * d + + assert len(sb.scan_info.values) == n_steps + assert len(sb.scan_info_sfdaq.values) == n_steps + + for i in range(n_steps): + step_values = sb.scan_info.values[i] + assert step_values == values[i] + + base = sb.filename + filebase = os.path.basename(base) + for i in range(n_steps): + expected = os.path.join(base, filebase + f"_step{i:04d}") + assert expected.endswith(f"{filebase}_step{i:04d}") + + data_root = tmp_path / "fake_data" / "data" + assert not data_root.exists() + + expected_subfolder = data_root / "multidim_test" + assert not expected_subfolder.exists() + + assert all(t.status == "done" for t in sb.current_tasks) + + all_files = [] + for t in sb.current_tasks: + assert len(t.filenames) > 0 + all_files += t.filenames + + assert len(all_files) >= 1 + + for i in range(n_steps): + target = values[i] + readback = sb.scan_info_sfdaq.readbacks[i] + assert readback == target + + assert sb.scan_info.values[0] == values[0] + assert sb.scan_info.values[-1] == values[-1] + + +def test_scanND_relative_positions_only(tmp_path): + d = 5 + + initial_values = [10 * (i + 1) for i in range(d)] + adjustables = [ + DummyAdjustable(name=f"M{i}", ID=f"ID{i}", initial_value=initial_values[i], process_time=0) + for i in range(d) + ] + + positions_per_dim = [list(range(-1, 2))] * d + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjustables, + values=positions_per_dim, + acquisitions=[fake], + filename="relND", + detectors=[], channels=["ch"], pvs=[], + n_pulses=1, + data_base_dir="data", + scan_info_dir=tmp_path, + make_scan_sub_dir=False, + condition=None, + return_to_initial_values=True, + n_repeat=1, + sensor=None, remote_plot=None + ) + + offset_values = [ + [p + initial_values[i] for p in [-1, 0, 1]] + for i in range(d) + ] + + sb.values = offset_values + + for i in range(d): + expected = [ + initial_values[i] - 1, + initial_values[i], + initial_values[i] + 1, + ] + assert sb.values[i] == expected + + assert len(sb.values) == d + + assert all(len(axis) == 3 for axis in sb.values) + + +# Utility functions + +def test_print_current_values_displays_correct_output(capsys): + class Obj: + def __init__(self, adjustables): + self.adjustables = adjustables + + def print_current_values(self): + print_all_current_values(self.adjustables) + + adjs = [ + DummyAdjustable(ID="A1", name="MotorA", initial_value=10), + DummyAdjustable(ID="B2", name="MotorB", initial_value=20), + ] + + obj = Obj(adjs) + obj.print_current_values() + + captured = capsys.readouterr().out + + assert "Current values" in captured + assert "A1" in captured and "B2" in captured + assert "10" in captured and "20" in captured + + +def test_get_all_current_values_returns_correct_list(): + adjs = [ + DummyAdjustable(ID="M1", name="M1", initial_value=5), + DummyAdjustable(ID="M2", name="M2", initial_value=15), + ] + values = get_all_current_values(adjs) + assert values == [5, 15] + + +def test_set_all_target_values_and_wait_full_chain(): + adjs = [ + DummyAdjustable(ID="1", initial_value=0, process_time=0), + DummyAdjustable(ID="2", initial_value=0, process_time=0), + ] + + set_all_target_values_and_wait(adjs, [10, 20]) + + values = [a.get_current_value() for a in adjs] + assert values == [10, 20] + + tasks = set_all_target_values(adjs, [15, 25]) + assert all(hasattr(t, "wait") for t in tasks) + + wait_for_all(tasks) + + values = [a.get_current_value() for a in adjs] + assert values == [15, 25] + + +def test_wait_for_all_calls_wait_on_all_tasks(monkeypatch): + called = [] + + class DummyTask: + def __init__(self, name): + self.name = name + def wait(self): + called.append(self.name) + + tasks = [DummyTask("t1"), DummyTask("t2"), DummyTask("t3")] + + wait_for_all(tasks) + + assert called == ["t1", "t2", "t3"] + + +def test_stop_all_calls_stop_and_handles_exceptions(capsys): + called = [] + + class WorkingTask: + def __init__(self, name): + self.name = name + self.stopped = False + def stop(self): + self.stopped = True + called.append(self.name) + + class FailingTask: + def stop(self): + raise RuntimeError("boom") + + tasks = [WorkingTask("T1"), FailingTask(), WorkingTask("T2")] + + stop_all(tasks) + + assert all(t.stopped for t in tasks if isinstance(t, WorkingTask)) + + assert called == ["T1", "T2"] + + out = capsys.readouterr().out + assert "Stopping caused" in out + assert "boom" in out + + +# Edge cases + +def test_run_with_exception_handling(tmp_path, capsys): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1], [2]], [fake], "error_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, False, 1, None, None + ) + + original_do_step = sb.do_step + def failing_do_step(*args, **kwargs): + if args[0] == 1: + raise ValueError("Intentional test error") + return original_do_step(*args, **kwargs) + + sb.do_step = failing_do_step + + sb.run() + + out = capsys.readouterr().out + assert "Stopping because of:" in out + assert "Intentional test error" in out + assert "Stopped current DAQ tasks:" in out + + assert not sb.running + + +def test_scan_interrupted_midway(tmp_path, capsys): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1], [2], [3], [4], [5]], [fake], "interrupted_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + original_do_step = sb.do_step + def stopping_do_step(*args, **kwargs): + result = original_do_step(*args, **kwargs) + if args[0] == 1: + sb.running = False + return result + + sb.do_step = stopping_do_step + + sb.run() + + out = capsys.readouterr().out + assert "Stopped during scan step" in out + assert "of 5" in out + + assert len(sb.scan_info.values) == 2 + + +def test_remote_plot_connection_refused_on_new_plot(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + sensor = DummySensor() + remote_plot = DummyRemotePlot(fail=True) + + sb = ScanBackend( + adjs, [[1]], [fake], "plot_fail_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, sensor, remote_plot + ) + + sb.run() + + assert remote_plot.created + assert remote_plot.appended + + +def test_remote_plot_connection_refused_on_append_data(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + sensor = DummySensor() + + class SelectiveFailPlot(DummyRemotePlot): + def __init__(self): + super().__init__(fail=False) + self.append_fail = True + + def append_data(self, filename, data): + self.appended = True + self.last_data = data + self.last_filename = filename + if self.append_fail: + raise ConnectionRefusedError("Connection refused") + + remote_plot = SelectiveFailPlot() + + sb = ScanBackend( + adjs, [[1]], [fake], "append_fail_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, sensor, remote_plot + ) + + sb.run() + + assert remote_plot.created + assert remote_plot.appended + + +def test_step_info_callable(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + call_count = [0] + def step_info_func(): + call_count[0] += 1 + return {"call": call_count[0], "timestamp": "2024-01-01"} + + sb = ScanBackend( + adjs, [[1], [2]], [fake], "callable_info_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = True + sb.scan_loop(step_info=step_info_func) + + assert call_count[0] == 4 + + assert len(sb.scan_info.info) == 2 + assert sb.scan_info.info[0] == {"call": 2, "timestamp": "2024-01-01"} + assert sb.scan_info.info[1] == {"call": 4, "timestamp": "2024-01-01"} + + +def test_repeated_scan_with_very_large_n_repeat(tmp_path, capsys): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1], [2]], [fake], "many_repeat_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 10, None, None + ) + + sb.run() + + assert fake.call_count == 20 + + out = capsys.readouterr().out + assert "Repetition 1 of 10" in out + assert "Repetition 10 of 10" in out + + +def test_repeated_scan_loop_stops_when_running_false(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1]], [fake], "stop_repeat_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 100, None, None + ) + + original_scan_loop = sb.scan_loop + rep_count = [0] + def counting_scan_loop(*args, **kwargs): + result = original_scan_loop(*args, **kwargs) + rep_count[0] += 1 + if rep_count[0] >= 2: + sb.running = False + return result + + sb.scan_loop = counting_scan_loop + + sb.running = True + sb.repeated_scan_loop() + + assert rep_count[0] == 2 + assert fake.call_count == 2 + + +def test_sfdaq_with_spreadsheet_logging(tmp_path): + adjs = [DummyAdjustable(name="M1", ID="M1"), DummyAdjustable(name="M2", ID="M2")] + + logged_data = [] + class MockSpreadsheet: + def add(self, run_number, filename, n_pulses, scanned_adjs, scan_values): + logged_data.append({ + "run_number": run_number, + "filename": filename, + "n_pulses": n_pulses, + "scanned_adjs": scanned_adjs, + "scan_values": scan_values + }) + + fake = FakeAcquisition("test_instrument", "test_pgroup") + fake.spreadsheet = MockSpreadsheet() + + sb = ScanBackend( + adjs, [[1, 2], [3, 4]], [fake], "spreadsheet_scan", + [], ["ch"], [], 5, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.scan_loop() + + assert len(logged_data) == 1 + log = logged_data[0] + assert log["run_number"] == 1 + assert log["filename"] == "spreadsheet_scan" + assert log["n_pulses"] == 5 + assert log["scanned_adjs"] == adjs + assert log["scan_values"] == [[1, 2], [3, 4]] + + +def test_scan_loop_with_single_value(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[42]], [fake], "single_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 1 + assert sb.scan_info.values[0] == [42] + assert fake.call_count == 1 + + +def test_scan_loop_advances_run_number_per_acquisition(tmp_path, capsys): + adjs = [DummyAdjustable(name="M", ID="M")] + fake1 = FakeAcquisition("test_instrument", "test_pgroup") + fake2 = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1]], [fake1, fake2], "multi_acq_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.scan_loop() + + out = capsys.readouterr().out + assert "Advanced run number to 1 for" in out + assert out.count("Advanced run number to") == 2 + + +def test_return_to_initial_values_false(tmp_path): + adjs = [DummyAdjustable(ID="M", initial_value=100, process_time=0)] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1], [2], [3]], [fake], "no_return_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, False, 1, None, None + ) + + sb.run() + + assert adjs[0].get_current_value() == 3 + assert adjs[0].get_current_value() != 100 + + +# N-dimensional grid scans + +def test_true_2D_grid_scan(tmp_path): + N = 2 + adjs = [DummyAdjustable(name=f"Adj{i}", ID=f"Adj{i}") for i in range(N)] + + x_vals = [0, 1, 2] + y_vals = [10, 20] + + values = [[x, y] for x in x_vals for y in y_vals] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, values, [fake], "grid_2d_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 6 + assert fake.call_count == 6 + + assert sb.scan_info.values[0] == [0, 10] + assert sb.scan_info.values[1] == [0, 20] + assert sb.scan_info.values[2] == [1, 10] + assert sb.scan_info.values[3] == [1, 20] + assert sb.scan_info.values[4] == [2, 10] + assert sb.scan_info.values[5] == [2, 20] + + +def test_true_3D_grid_scan(tmp_path): + N = 3 + adjs = [DummyAdjustable(name=f"Adj{i}", ID=f"Adj{i}") for i in range(N)] + + x_vals = [0, 1] + y_vals = [10, 20] + z_vals = [100, 200, 300] + + values = [[x, y, z] for x in x_vals for y in y_vals for z in z_vals] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, values, [fake], "grid_3d_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 12 + assert fake.call_count == 12 + + assert sb.scan_info.values[0] == [0, 10, 100] + assert sb.scan_info.values[1] == [0, 10, 200] + assert sb.scan_info.values[2] == [0, 10, 300] + assert sb.scan_info.values[3] == [0, 20, 100] + assert sb.scan_info.values[10] == [1, 20, 200] + assert sb.scan_info.values[11] == [1, 20, 300] + + +def test_true_4D_grid_scan(tmp_path): + N = 4 + adjs = [DummyAdjustable(name=f"Adj{i}", ID=f"Adj{i}") for i in range(N)] + + dim_vals = [[0, 1], [10, 11], [100, 101], [1000, 1001]] + + values = [] + for v0 in dim_vals[0]: + for v1 in dim_vals[1]: + for v2 in dim_vals[2]: + for v3 in dim_vals[3]: + values.append([v0, v1, v2, v3]) + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, values, [fake], "grid_4d_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 16 + assert fake.call_count == 16 + + assert sb.scan_info.values[0] == [0, 10, 100, 1000] + assert sb.scan_info.values[15] == [1, 11, 101, 1001] + + +def test_parametrized_ND_grid_scan_generator(tmp_path): + import itertools + + def generate_nd_grid(dim_values_list): + return [list(combo) for combo in itertools.product(*dim_values_list)] + + N = 2 + dim_values = [[1, 2, 3], [100, 200]] + adjs = [DummyAdjustable(name=f"A{i}", ID=f"A{i}") for i in range(N)] + values = generate_nd_grid(dim_values) + + fake = FakeAcquisition("test_instrument", "test_pgroup") + sb = ScanBackend( + adjs, values, [fake], "nd_gen_2d", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 6 + assert sb.scan_info.values[0] == [1, 100] + assert sb.scan_info.values[5] == [3, 200] + + N = 3 + dim_values = [[1, 2], [10, 20, 30], [100]] + adjs = [DummyAdjustable(name=f"A{i}", ID=f"A{i}") for i in range(N)] + values = generate_nd_grid(dim_values) + + fake = FakeAcquisition("test_instrument", "test_pgroup") + sb = ScanBackend( + adjs, values, [fake], "nd_gen_3d", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 6 + assert sb.scan_info.values[0] == [1, 10, 100] + assert sb.scan_info.values[5] == [2, 30, 100] + + N = 5 + dim_values = [[1, 2], [10], [100, 200], [1000], [10000, 20000, 30000]] + adjs = [DummyAdjustable(name=f"A{i}", ID=f"A{i}") for i in range(N)] + values = generate_nd_grid(dim_values) + + fake = FakeAcquisition("test_instrument", "test_pgroup") + sb = ScanBackend( + adjs, values, [fake], "nd_gen_5d", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 12 + assert sb.scan_info.values[0] == [1, 10, 100, 1000, 10000] + assert sb.scan_info.values[11] == [2, 10, 200, 1000, 30000] + + +def test_ND_grid_with_readbacks_verification(tmp_path): + import itertools + + N = 3 + dim_values = [[0, 1], [10, 20], [100, 200, 300]] + adjs = [DummyAdjustable(name=f"Motor{i}", ID=f"Motor{i}") for i in range(N)] + values = [list(combo) for combo in itertools.product(*dim_values)] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, values, [fake], "grid_readbacks", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.readbacks) == 12 + + for i, expected_vals in enumerate(values): + assert sb.scan_info.readbacks[i] == expected_vals + + assert sb.scan_info.values[0] == [0, 10, 100] + assert sb.scan_info.readbacks[0] == [0, 10, 100] + + assert sb.scan_info.values[11] == [1, 20, 300] + assert sb.scan_info.readbacks[11] == [1, 20, 300] + + +def test_large_ND_grid_scan(tmp_path): + N = 3 + dim_values = [[0, 1, 2], [10, 20, 30], [100, 200, 300]] + + import itertools + adjs = [DummyAdjustable(name=f"Axis{i}", ID=f"Axis{i}") for i in range(N)] + values = [list(combo) for combo in itertools.product(*dim_values)] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, values, [fake], "large_grid_3d", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 27 + assert fake.call_count == 27 + + assert sb.scan_info.values[0] == [0, 10, 100] + assert sb.scan_info.values[8] == [0, 30, 300] + assert sb.scan_info.values[26] == [2, 30, 300] + + +def test_ND_grid_interrupted_midway(tmp_path): + N = 3 + dim_values = [[0, 1, 2], [10, 20], [100, 200, 300]] + + import itertools + adjs = [DummyAdjustable(name=f"M{i}", ID=f"M{i}") for i in range(N)] + values = [list(combo) for combo in itertools.product(*dim_values)] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, values, [fake], "interrupted_grid", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + original_do_step = sb.do_step + step_count = [0] + def counting_do_step(*args, **kwargs): + result = original_do_step(*args, **kwargs) + step_count[0] += 1 + if step_count[0] >= 5: + sb.running = False + return result + + sb.do_step = counting_do_step + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 5 + assert fake.call_count == 5 + assert step_count[0] == 5 + + +# Non-SFDAQ acquisitions + +def test_nonsfdaq_acquisition_only(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + nonsfdaq = NonSFDAQAcquisition(name="TestNonSFDAQ", default_dir=str(tmp_path)) + + sb = ScanBackend( + adjs, [[1], [2]], [nonsfdaq], "nonsfdaq_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + assert sb.filename != "nonsfdaq_scan" + assert sb.filename_sfdaq == "nonsfdaq_scan" + + sb.running = True + sb.scan_loop() + + assert nonsfdaq.call_count == 2 + + assert len(sb.scan_info.values) == 2 + assert sb.scan_info.values[0] == [1] + assert sb.scan_info.values[1] == [2] + + +def test_nonsfdaq_with_make_scan_sub_dir(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + default_dir = tmp_path / "nonsfdaq_default" + default_dir.mkdir() + + nonsfdaq = NonSFDAQAcquisition(name="TestNonSFDAQ", default_dir=str(default_dir)) + + sb = ScanBackend( + adjs, [[1]], [nonsfdaq], "subdir_scan", + [], ["ch"], [], 1, + "data", tmp_path, True, None, True, 1, None, None + ) + + sb.create_output_dirs() + + expected_subdir = default_dir / "data" / sb.filename + assert expected_subdir.exists() + + +def test_nonsfdaq_acquire_all_different_params(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + nonsfdaq = NonSFDAQAcquisition(name="TestNonSFDAQ", default_dir=str(tmp_path)) + + sb = ScanBackend( + adjs, [[1]], [nonsfdaq], "param_test", + [], ["ch"], [], 5, + "data", tmp_path, False, None, True, 1, None, None + ) + + filename = sb.get_filename(0) + sb.acquire_all(filename) + + assert nonsfdaq.last_filename == filename + assert nonsfdaq.last_n_pulses == 5 + + +def test_mixed_sfdaq_and_nonsfdaq_acquisitions(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + nonsfdaq = NonSFDAQAcquisition(name="Mixed", default_dir=str(tmp_path)) + + sb = ScanBackend( + adjs, [[1], [2]], [fake, nonsfdaq], "mixed_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + assert sb.filename != "mixed_scan" + + sb.running = True + sb.scan_loop() + + assert fake.call_count == 2 + assert nonsfdaq.call_count == 2 + + +def test_nonsfdaq_with_default_dir_none(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + nonsfdaq = NonSFDAQAcquisition(name="NoDir", default_dir=None) + + sb = ScanBackend( + adjs, [[1]], [nonsfdaq], "nodir_scan", + [], ["ch"], [], 1, + "data", tmp_path, True, None, True, 1, None, None + ) + + sb.create_output_dirs() + + sb.running = True + sb.scan_loop() + assert nonsfdaq.call_count == 1 + + +# return_to_initial_values=None + +def test_return_to_initial_values_none_user_says_yes(tmp_path, monkeypatch): + adjs = [DummyAdjustable(ID="M", initial_value=100, process_time=0)] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1], [2]], [fake], "interactive_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, None, 1, None, None + ) + + user_response = [True] + def mock_ask_yes_no(prompt): + return user_response[0] + + monkeypatch.setattr("slic.core.scanner.scanbackend.ask_Yes_no", mock_ask_yes_no) + + sb.run() + + assert adjs[0].get_current_value() == 100 + + +def test_return_to_initial_values_none_user_says_no(tmp_path, monkeypatch): + adjs = [DummyAdjustable(ID="M", initial_value=100, process_time=0)] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1], [2]], [fake], "interactive_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, None, 1, None, None + ) + + def mock_ask_yes_no(prompt): + return False + + monkeypatch.setattr("slic.core.scanner.scanbackend.ask_Yes_no", mock_ask_yes_no) + + sb.run() + + assert adjs[0].get_current_value() == 2 + + +# n_repeat=None (infinite) + +def test_n_repeat_none_infinite_with_early_stop(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1]], [fake], "infinite_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, None, None, None + ) + + original_scan_loop = sb.scan_loop + rep_count = [0] + def counting_scan_loop(*args, **kwargs): + result = original_scan_loop(*args, **kwargs) + rep_count[0] += 1 + if rep_count[0] >= 3: + sb.running = False + return result + + sb.scan_loop = counting_scan_loop + sb.running = True + sb.repeated_scan_loop() + + assert rep_count[0] == 3 + assert fake.call_count == 3 + + +def test_n_repeat_none_printable_summary(tmp_path, capsys): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1]], [fake], "infinite_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, None, None, None + ) + + summary = sb._make_summary() + + assert "repeat the following scan forever" in summary + + +# KeyboardInterrupt + +def test_keyboard_interrupt_handling(tmp_path, capsys): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1], [2], [3]], [fake], "keyboard_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + original_do_step = sb.do_step + step_count = [0] + def interrupt_do_step(*args, **kwargs): + step_count[0] += 1 + if step_count[0] >= 2: + raise KeyboardInterrupt + return original_do_step(*args, **kwargs) + + sb.do_step = interrupt_do_step + sb.run() + + assert step_count[0] == 2 + assert sb.running == False + + out = capsys.readouterr().out + assert "Stopped current DAQ tasks:" in out + + +# sensor without remote_plot (BUG) + +def test_sensor_without_remote_plot(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + sensor = DummySensor(name="TestSensor") + + sb = ScanBackend( + adjs, [[1], [2]], [fake], "sensor_only_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, sensor, None + ) + + sb.run() + + assert sensor.started + assert sensor.stopped + assert fake.call_count == 2 + + +# condition with n_repeat>1 + +def test_condition_with_multiple_repeats(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + condition = DummyCondition(repeats=6) + + sb = ScanBackend( + adjs, [[1]], [fake], "condition_repeat_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, condition, True, 3, None, None + ) + + sb.run() + + assert fake.call_count == 6 + + +# condition.stop() + +def test_stop_calls_condition_stop(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + condition = DummyCondition(repeats=0) + + sb = ScanBackend( + adjs, [[1]], [fake], "stop_condition_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, condition, True, 1, None, None + ) + + assert not condition._stopped + sb.stop() + assert condition._stopped + + +# step_info as static dict + +def test_step_info_as_static_dict(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1], [2]], [fake], "static_info_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + static_info = {"experiment": "test", "user": "testuser"} + + sb.running = True + sb.scan_loop(step_info=static_info) + + assert len(sb.scan_info.info) == 2 + assert sb.scan_info.info[0] == static_info + assert sb.scan_info.info[1] == static_info + + +# Empty values (BUG) + +def test_scan_with_empty_values_list(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [], [fake], "empty_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = True + + sb.scan_loop() + + assert len(sb.scan_info.values) == 0 + assert fake.call_count == 0 + + +def test_scan_with_empty_acquisitions_list(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + + sb = ScanBackend( + adjs, [[1], [2]], [], "no_acq_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = True + sb.scan_loop() + + assert len(sb.scan_info.values) == 2 + + +def test_scan_with_very_large_istep(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1]], [fake], "large_step_scan", + [], ["ch"], [], 1, + "data", tmp_path, True, None, True, 1, None, None + ) + + filename = sb.get_filename(12345) + + assert "step12345" in filename + + +def test_condition_with_minimum_repeats(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + condition = DummyCondition(repeats=1) + + sb = ScanBackend( + adjs, [[1]], [fake], "min_repeat_condition", + [], ["ch"], [], 1, + "data", tmp_path, False, condition, True, 1, None, None + ) + + sb.running = True + sb.scan_loop() + + assert fake.call_count == 1 + + +def test_remote_plot_without_sensor(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + remote_plot = DummyRemotePlot() + + sb = ScanBackend( + adjs, [[1], [2]], [fake], "plot_no_sensor_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, remote_plot + ) + + sb.run() + + assert not remote_plot.created + assert fake.call_count == 2 + + +def test_n_repeat_with_n_pulses_formatting(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb1 = ScanBackend( + adjs, [[1]], [fake], "scan1", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + summary1 = sb1._make_summary() + assert "1 pulse" in summary1 + + sb2 = ScanBackend( + adjs, [[1]], [fake], "scan2", + [], ["ch"], [], 5, + "data", tmp_path, False, None, True, 1, None, None + ) + summary2 = sb2._make_summary() + assert "5 pulses" in summary2 + + +def test_scan_stopped_before_first_iteration(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, [[1], [2], [3]], [fake], "stop_early_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.running = False + sb.scan_loop() + + assert len(sb.scan_info.values) == 0 + assert fake.call_count == 0 + + +# Complete scan verification +# Verifies ALL values, readbacks, acquisitions, and return to initial + +def test_complete_scan_values_readbacks_and_return_to_initial(tmp_path): + adj1 = DummyAdjustable(name="Motor1", ID="M1", initial_value=0, process_time=0) + adj2 = DummyAdjustable(name="Motor2", ID="M2", initial_value=0, process_time=0) + adjs = [adj1, adj2] + + scan_values = [ + [10, 100], + [20, 200], + [30, 300] + ] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + adjs, scan_values, [fake], "complete_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, True, 1, None, None + ) + + initial_before = [adj1.get_current_value(), adj2.get_current_value()] + assert initial_before == [0, 0], "Initial values should be [0, 0]" + + sb.run() + + assert fake.call_count == 3, f"Expected 3 acquisitions, got {fake.call_count}" + + assert len(sb.scan_info.values) == 3, "Should have 3 recorded values" + assert sb.scan_info.values[0] == [10, 100], f"Step 0: expected [10, 100], got {sb.scan_info.values[0]}" + assert sb.scan_info.values[1] == [20, 200], f"Step 1: expected [20, 200], got {sb.scan_info.values[1]}" + assert sb.scan_info.values[2] == [30, 300], f"Step 2: expected [30, 300], got {sb.scan_info.values[2]}" + + assert len(sb.scan_info.readbacks) == 3, "Should have 3 readback records" + assert sb.scan_info.readbacks[0] == [10, 100], f"Readback 0: expected [10, 100], got {sb.scan_info.readbacks[0]}" + assert sb.scan_info.readbacks[1] == [20, 200], f"Readback 1: expected [20, 200], got {sb.scan_info.readbacks[1]}" + assert sb.scan_info.readbacks[2] == [30, 300], f"Readback 2: expected [30, 300], got {sb.scan_info.readbacks[2]}" + + final_values = [adj1.get_current_value(), adj2.get_current_value()] + assert final_values == [0, 0], f"Should return to [0, 0], but got {final_values}" + + assert sb.initial_values == [0, 0], f"Stored initial values should be [0, 0], got {sb.initial_values}" + + +def test_scan_without_return_to_initial_stays_at_last_value(tmp_path): + adj1 = DummyAdjustable(name="M1", ID="M1", initial_value=100, process_time=0) + adj2 = DummyAdjustable(name="M2", ID="M2", initial_value=200, process_time=0) + + scan_values = [ + [10, 20], + [30, 40], + [50, 60] + ] + + fake = FakeAcquisition("test_instrument", "test_pgroup") + + sb = ScanBackend( + [adj1, adj2], scan_values, [fake], "no_return_scan", + [], ["ch"], [], 1, + "data", tmp_path, False, None, False, 1, None, None + ) + + assert adj1.get_current_value() == 100 + assert adj2.get_current_value() == 200 + + sb.run() + + assert adj1.get_current_value() == 50, f"M1 should be at 50, got {adj1.get_current_value()}" + assert adj2.get_current_value() == 60, f"M2 should be at 60, got {adj2.get_current_value()}" + + +def test_acquisition_parameters_are_correct(tmp_path): + adjs = [DummyAdjustable(name="M", ID="M")] + fake = FakeAcquisition("test_instrument", "test_pgroup") + + scan_values = [[1], [2], [3]] + + sb = ScanBackend( + adjs, scan_values, [fake], "param_check_scan", + [], ["test_channel"], [], 5, + "data", tmp_path, False, None, True, 1, None, None + ) + + sb.run() + + assert fake.call_count == 3 + + assert fake.last_n_pulses == 5, f"Expected n_pulses=5, got {fake.last_n_pulses}" + + assert fake.last_channels == ["test_channel"], f"Expected channels=['test_channel'], got {fake.last_channels}" diff --git a/tests/test_scaninfo.py b/tests/test_scaninfo.py new file mode 100644 index 000000000..f3ed93c5c --- /dev/null +++ b/tests/test_scaninfo.py @@ -0,0 +1,488 @@ +import pytest +from slic.core.scanner.scaninfo import ScanInfo + + +class DummyAdjustable: + def __init__(self, name="adj", ID="id", units="u"): + self.name = name + self.ID = ID + self.units = units + + +# ScanInfo init + +@pytest.mark.parametrize( + "adjustables,values,suffix,expected_filename,expected_params", + [ + ( + [DummyAdjustable()], + [1, 2, 3], + "_scan_info.json", + "fileA_scan_info.json", + {"name": ["adj"], "Id": ["id"], "units": ["u"]}, + ), + ( + [DummyAdjustable("motorX", "M1", "mm")], + [10, 20], + ".meta", + "fileB.meta", + {"name": ["motorX"], "Id": ["M1"], "units": ["mm"]}, + ), + ( + [ + DummyAdjustable("motorX", "M1", "mm"), + DummyAdjustable("stageY", "S2", "deg"), + DummyAdjustable("lensZ", "L3", "cm"), + ], + [1, 2, 3], + "_extra.json", + "fileC_extra.json", + { + "name": ["motorX", "stageY", "lensZ"], + "Id": ["M1", "S2", "L3"], + "units": ["mm", "deg", "cm"], + }, + ), + ], +) +def test_init_creates_expected_filename(tmp_path, adjustables, values, suffix, expected_filename, expected_params): + base_dir = tmp_path + filename_base = expected_filename.split("_")[0].split(".")[0] + si = ScanInfo(filename_base, base_dir, adjustables, values, suffix=suffix) + + assert si.filename.endswith(expected_filename) + assert si.parameters == expected_params + assert si.values == [] + assert si.readbacks == [] + assert si.files == [] + assert si.info == [] + + +def test_init_with_empty_adjustables(tmp_path): + si = ScanInfo("empty_scan", tmp_path, [], []) + assert si.names == [] + assert si.IDs == [] + assert si.units == [] + assert si.parameters == {"name": [], "Id": [], "units": []} + + +class PartialAdjustable: + def __init__(self, has_name=True, has_id=True, has_units=True): + if has_name: + self.name = "test_name" + if has_id: + self.ID = "test_id" + if has_units: + self.units = "test_units" + + +@pytest.mark.parametrize( + "adjustable,expected_name,expected_id,expected_units", + [ + (PartialAdjustable(has_name=False, has_id=True, has_units=True), "noName", "test_id", "test_units"), + (PartialAdjustable(has_name=True, has_id=False, has_units=True), "test_name", "noID", "test_units"), + (PartialAdjustable(has_name=True, has_id=True, has_units=False), "test_name", "test_id", "noUnits"), + (PartialAdjustable(has_name=False, has_id=False, has_units=False), "noName", "noID", "noUnits"), + ], +) +def test_init_with_missing_attributes(tmp_path, adjustable, expected_name, expected_id, expected_units): + si = ScanInfo("partial_scan", tmp_path, [adjustable], [0]) + assert si.names == [expected_name] + assert si.IDs == [expected_id] + assert si.units == [expected_units] + + +# append + +def test_append(tmp_path): + si = ScanInfo("fileX", tmp_path, [DummyAdjustable("A", "1", "u")], [0]) + + si.append([1, 2, 3], [10, 20, 30], ["f1.dat", "f2.dat", "f3.dat"], {"note": "phase1"}) + assert si.values == [[1, 2, 3]] + assert si.readbacks == [[10, 20, 30]] + assert si.files == [["f1.dat", "f2.dat", "f3.dat"]] + assert si.info == [{"note": "phase1"}] + + si.append([4, 5], [40, 50], ["f4.dat", "f5.dat"], lambda: {"note": "auto_phase2"}) + + assert si.values == [[1, 2, 3], [4, 5]] + assert si.readbacks == [[10, 20, 30], [40, 50]] + assert si.files == [["f1.dat", "f2.dat", "f3.dat"], ["f4.dat", "f5.dat"]] + assert si.info == [{"note": "phase1"}, {"note": "auto_phase2"}] + + +def test_append_with_empty_lists(tmp_path): + si = ScanInfo("empty", tmp_path, [DummyAdjustable()], [0]) + si.append([], [], [], {}) + + assert si.values == [[]] + assert si.readbacks == [[]] + assert si.files == [[]] + assert si.info == [{}] + + +@pytest.mark.parametrize( + "values_len,readbacks_len,files_len", + [ + (3, 2, 1), + (1, 3, 2), + (5, 1, 5), + ], +) +def test_append_with_mismatched_list_lengths(tmp_path, values_len, readbacks_len, files_len): + si = ScanInfo("mismatch", tmp_path, [DummyAdjustable()], [0]) + + values = list(range(values_len)) + readbacks = list(range(readbacks_len)) + files = [f"f{i}.dat" for i in range(files_len)] + + si.append(values, readbacks, files, {"note": "mismatch"}) + + assert len(si.values[0]) == values_len + assert len(si.readbacks[0]) == readbacks_len + assert len(si.files[0]) == files_len + + +def test_append_info_with_complex_nested_dict(tmp_path): + si = ScanInfo("nested", tmp_path, [DummyAdjustable()], [0]) + + complex_info = { + "level1": { + "level2": { + "level3": { + "data": [1, 2, 3], + "meta": {"key": "value"} + } + } + }, + "list": [{"a": 1}, {"b": 2}] + } + + si.append([1], [1], ["f.dat"], complex_info) + assert si.info[0] == complex_info + + +def test_very_long_lists(tmp_path): + si = ScanInfo("long", tmp_path, [DummyAdjustable()], [0]) + + long_values = list(range(10000)) + long_readbacks = list(range(10000)) + long_files = [f"file_{i}.dat" for i in range(10000)] + + si.append(long_values, long_readbacks, long_files, {"note": "big"}) + + assert len(si.values[0]) == 10000 + assert len(si.readbacks[0]) == 10000 + assert len(si.files[0]) == 10000 + + +# info callable + +def test_callable_info_that_raises_exception(tmp_path): + si = ScanInfo("error_test", tmp_path, [DummyAdjustable()], [0]) + + def bad_info(): + raise ValueError("Intentional error") + + with pytest.raises(ValueError, match="Intentional error"): + si.append([1], [1], ["f.dat"], bad_info) + + +def test_callable_info_returns_none(tmp_path): + si = ScanInfo("none_test", tmp_path, [DummyAdjustable()], [0]) + si.append([1], [1], ["f.dat"], lambda: None) + assert si.info == [None] + + +def test_info_with_none_directly(tmp_path): + si = ScanInfo("none_direct", tmp_path, [DummyAdjustable()], [0]) + si.append([1], [1], ["f.dat"], None) + assert si.info == [None] + + +# write and to_dict + +def test_write_and_to_dict(tmp_path, monkeypatch): + base_dir = tmp_path + si = ScanInfo("scanTest", base_dir, [ + DummyAdjustable("motorX", "M1", "mm"), + DummyAdjustable("stageY", "S2", "deg"), + ], [0], suffix="_info.json") + + si.append([1.0, 2.0], [1.1, 2.1], ["f1.dat", "f2.dat"], {"phase": "init"}) + si.append([3.0, 4.0], [3.1, 4.1], ["f3.dat", "f4.dat"], {"phase": "end"}) + + last_call = {} + def mock_json_save(data, filename): + last_call['data'] = data + last_call['filename'] = filename + + monkeypatch.setattr('slic.core.scanner.scaninfo.json_save', mock_json_save) + + si.write() + + assert last_call['filename'] == si.filename + + expected = si.to_dict() + assert last_call['data'] == expected + + +def test_to_dict_complete_structure(tmp_path): + si = ScanInfo("scan_test", tmp_path, [DummyAdjustable("M", "ID1", "mm")], [0, 1, 2]) + + si.append([1.0], [1.1], ["f1.dat"], {"note": "test"}) + + result = si.to_dict() + + expected_keys = { + "scan_parameters", + "scan_values_all", + "scan_values", + "scan_readbacks", + "scan_files", + "scan_info", + } + assert set(result.keys()) == expected_keys + + assert result["scan_parameters"] == {"name": ["M"], "Id": ["ID1"], "units": ["mm"]} + assert result["scan_values_all"] == [0, 1, 2] + assert result["scan_values"] == [[1.0]] + assert result["scan_readbacks"] == [[1.1]] + assert result["scan_files"] == [["f1.dat"]] + assert result["scan_info"] == [{"note": "test"}] + + +# update + +def test_update_integration(tmp_path, monkeypatch): + si = ScanInfo("scanX", tmp_path, [DummyAdjustable("M", "ID", "mm")], [0], suffix=".json") + + last_call = {} + def mock_json_save(data, filename): + last_call['data'] = data + last_call['filename'] = filename + + monkeypatch.setattr('slic.core.scanner.scaninfo.json_save', mock_json_save) + + si.update([1, 2], [10, 20], ["f1.dat", "f2.dat"], {"phase": "start"}) + + assert si.values == [[1, 2]] + assert si.readbacks == [[10, 20]] + assert si.files == [["f1.dat", "f2.dat"]] + assert si.info == [{"phase": "start"}] + + assert last_call['filename'] == si.filename + assert last_call['data'] == si.to_dict() + + +# to_sfdaq_dict + +def test_to_sfdaq_dict_filled_example(tmp_path): + si = ScanInfo( + filename_base="scanAlpha", + base_dir=tmp_path, + adjustables=[ + DummyAdjustable("motorX", "M1", "mm"), + DummyAdjustable("stageY", "S2", "deg"), + DummyAdjustable("lensZ", "L3", "cm"), + ], + values=[0, 1, 2], + suffix="_scan_info.json", + ) + + result_empty = si.to_sfdaq_dict() + assert result_empty["scan_values"] is None + assert result_empty["scan_readbacks"] is None + + si.append( + [1.0, 2.0, 3.0], + [1.1, 2.1, 3.1], + ["f1.dat"], + {"note": "first run"} + ) + + si.append( + [4.0, 5.0, 6.0], + [4.1, 5.1, 6.1], + ["f2.dat"], + {"note": "second run"} + ) + + result = si.to_sfdaq_dict() + + expected_keys = { + "scan_name", "name", "Id", "units", + "offset", "conversion_factor", + "scan_values", "scan_readbacks", "scan_readbacks_raw", + } + assert set(result.keys()) == expected_keys + + assert result["scan_values"] == [4.0, 5.0, 6.0] + assert result["scan_readbacks"] == [4.1, 5.1, 6.1] + assert result["scan_readbacks_raw"] == [4.1, 5.1, 6.1] + + assert result["scan_name"] == "scanAlpha" + assert result["name"] == ["motorX", "stageY", "lensZ"] + assert result["Id"] == ["M1", "S2", "L3"] + assert result["units"] == ["mm", "deg", "cm"] + assert result["offset"] == [0, 0, 0] + assert result["conversion_factor"] == [1, 1, 1] + + expected_dict = { + "scan_name": "scanAlpha", + "name": ["motorX", "stageY", "lensZ"], + "Id": ["M1", "S2", "L3"], + "units": ["mm", "deg", "cm"], + "offset": [0, 0, 0], + "conversion_factor": [1, 1, 1], + "scan_values": [4.0, 5.0, 6.0], + "scan_readbacks": [4.1, 5.1, 6.1], + "scan_readbacks_raw": [4.1, 5.1, 6.1], + } + + assert result == expected_dict + + +def test_to_sfdaq_dict_with_single_adjustable(tmp_path): + si = ScanInfo("single", tmp_path, [DummyAdjustable("motor", "M1", "deg")], [0]) + si.append([5.0], [5.1], ["data.dat"], {"run": 1}) + + result = si.to_sfdaq_dict() + + assert result["name"] == ["motor"] + assert result["Id"] == ["M1"] + assert result["units"] == ["deg"] + assert result["offset"] == [0] + assert result["conversion_factor"] == [1] + assert result["scan_values"] == [5.0] + assert result["scan_readbacks"] == [5.1] + + +def test_to_sfdaq_dict_filename_with_slash(tmp_path): + si = ScanInfo("path/to/scan", tmp_path, [DummyAdjustable()], [0]) + result = si.to_sfdaq_dict() + assert result["scan_name"] == "path_to_scan" + + +def test_to_sfdaq_dict_with_empty_adjustables_list(tmp_path): + si = ScanInfo("empty_adj", tmp_path, [], []) + result = si.to_sfdaq_dict() + + assert result["name"] == [] + assert result["Id"] == [] + assert result["units"] == [] + assert result["offset"] == [] + assert result["conversion_factor"] == [] + assert result["scan_values"] is None + assert result["scan_readbacks"] is None + + +def test_append_then_to_sfdaq_returns_last_values(tmp_path): + si = ScanInfo("last_test", tmp_path, [DummyAdjustable("X", "X1", "m")], [0]) + + si.append([10.0], [10.5], ["f1.dat"], {"n": 1}) + si.append([20.0], [20.5], ["f2.dat"], {"n": 2}) + si.append([30.0], [30.5], ["f3.dat"], {"n": 3}) + + result = si.to_sfdaq_dict() + + assert result["scan_values"] == [30.0] + assert result["scan_readbacks"] == [30.5] + assert result["scan_readbacks_raw"] == [30.5] + + +# repr + +def test_repr(tmp_path): + si = ScanInfo("test_scan", tmp_path, [DummyAdjustable()], [0], suffix="_info.json") + expected_path = str(tmp_path / "test_scan_info.json") + assert repr(si) == f"Scan info in {expected_path}" + + +# filename edge cases + +def test_filename_base_already_contains_suffix(tmp_path): + si = ScanInfo("scan_scan_info.json", tmp_path, [DummyAdjustable()], [0], suffix="_scan_info.json") + assert si.filename.endswith("scan_scan_info.json_scan_info.json") + + +def test_filename_with_multiple_slashes(tmp_path): + si = ScanInfo("path/to/deep/scan", tmp_path, [DummyAdjustable()], [0]) + result = si.to_sfdaq_dict() + assert result["scan_name"] == "path_to_deep_scan" + assert "/" not in result["scan_name"] + + +def test_suffix_empty_string(tmp_path): + si = ScanInfo("test", tmp_path, [DummyAdjustable()], [0], suffix="") + assert si.filename == str(tmp_path / "test") + assert not si.filename.endswith(".json") + + +def test_suffix_no_extension(tmp_path): + si = ScanInfo("test", tmp_path, [DummyAdjustable()], [0], suffix="_data") + assert si.filename.endswith("test_data") + + +def test_base_dir_with_trailing_slash(tmp_path): + base_with_slash = str(tmp_path) + "/" + si = ScanInfo("test", base_with_slash, [DummyAdjustable()], [0]) + assert "test_scan_info.json" in si.filename + + +def test_filename_base_with_dots(tmp_path): + si = ScanInfo("scan.v1.0.test", tmp_path, [DummyAdjustable()], [0], suffix=".json") + assert "scan.v1.0.test.json" in si.filename + + +def test_multiple_slashes_and_underscores_in_filename_base(tmp_path): + si = ScanInfo("path/to_scan/with_underscores", tmp_path, [DummyAdjustable()], [0]) + result = si.to_sfdaq_dict() + assert result["scan_name"] == "path_to_scan_with_underscores" + + +def test_adjustable_with_special_characters_in_name(tmp_path): + adj = DummyAdjustable("motor:X@123!", "ID#1", "µm/s²") + si = ScanInfo("special", tmp_path, [adj], [0]) + + assert si.names == ["motor:X@123!"] + assert si.IDs == ["ID#1"] + assert si.units == ["µm/s²"] + + +# Integration + +# Tests complete workflow: init -> multiple updates -> verify final JSON structure +# Verifies that each update() call correctly appends data and saves to JSON + +def test_full_integration_workflow(tmp_path, monkeypatch): + saved_data = [] + def mock_json_save(data, filename): + saved_data.append({"data": data, "filename": filename}) + + monkeypatch.setattr('slic.core.scanner.scaninfo.json_save', mock_json_save) + + si = ScanInfo( + "integration_scan", + tmp_path, + [DummyAdjustable("motorA", "MA", "mm"), DummyAdjustable("motorB", "MB", "deg")], + [0, 10, 20], + suffix=".json" + ) + + si.update([1.0, 2.0], [1.1, 2.1], ["file1.dat"], {"step": 1}) + assert len(saved_data) == 1 + + si.update([3.0, 4.0], [3.1, 4.1], ["file2.dat"], {"step": 2}) + assert len(saved_data) == 2 + + final_data = saved_data[-1]["data"] + assert final_data["scan_parameters"] == { + "name": ["motorA", "motorB"], + "Id": ["MA", "MB"], + "units": ["mm", "deg"] + } + assert final_data["scan_values"] == [[1.0, 2.0], [3.0, 4.0]] + assert final_data["scan_readbacks"] == [[1.1, 2.1], [3.1, 4.1]] + assert final_data["scan_files"] == [["file1.dat"], ["file2.dat"]] + assert final_data["scan_info"] == [{"step": 1}, {"step": 2}]