feat(signals): enhance type hints for signal data and validation

This commit is contained in:
2025-12-12 14:13:13 +01:00
parent 6dbe1e1f57
commit ef0757d856

View File

@@ -29,6 +29,9 @@ __all__ = [
"AsyncMultiSignal", "AsyncMultiSignal",
] ]
SupportedSignalTypes = (int, float, str, bool, list, np.ndarray)
SupportedSignalTypesUnion = int | float | str | bool | list | np.ndarray
class SignalInfo(BaseModel): class SignalInfo(BaseModel):
""" """
@@ -716,7 +719,10 @@ class DynamicSignal(BECMessageSignal):
@typechecked @typechecked
def put( def put(
self, self,
value: messages.DeviceMessage | dict[str, dict[Literal["value", "timestamp"], Any]], value: (
messages.DeviceMessage
| dict[str, dict[Literal["value", "timestamp"], SupportedSignalTypesUnion]]
),
*, *,
metadata: dict | None = None, metadata: dict | None = None,
async_update: dict[Literal["type", "max_shape", "index"], Any] | None = None, async_update: dict[Literal["type", "max_shape", "index"], Any] | None = None,
@@ -755,6 +761,16 @@ class DynamicSignal(BECMessageSignal):
elif self.acquisition_group is not None: elif self.acquisition_group is not None:
metadata["acquisition_group"] = self.acquisition_group metadata["acquisition_group"] = self.acquisition_group
# verify that signal data is of supported type
for signal_name, signal_data in value.items():
if "value" not in signal_data:
raise ValueError(f"Signal data for {signal_name} must contain 'value' key.")
if not isinstance(signal_data["value"], SupportedSignalTypes):
raise ValueError(
f"Signal data for {signal_name} must be of type {SupportedSignalTypes}, "
f"got {type(signal_data['value']).__name__}."
)
msg = messages.DeviceMessage(signals=value, metadata=metadata) msg = messages.DeviceMessage(signals=value, metadata=metadata)
except ValidationError as exc: except ValidationError as exc:
raise ValueError(f"Error setting signal {self.name}: {exc}") from exc raise ValueError(f"Error setting signal {self.name}: {exc}") from exc
@@ -973,7 +989,7 @@ class AsyncSignal(DynamicSignal):
def put( def put(
self, self,
value: Any, value: SupportedSignalTypesUnion,
timestamp: float | None = None, timestamp: float | None = None,
async_update: dict[Literal["type", "max_shape", "index"], Any] | None = None, async_update: dict[Literal["type", "max_shape", "index"], Any] | None = None,
acquisition_group: str | None = None, acquisition_group: str | None = None,
@@ -998,7 +1014,7 @@ class AsyncSignal(DynamicSignal):
def set( def set(
self, self,
value: Any, value: SupportedSignalTypesUnion,
timestamp: float | None = None, timestamp: float | None = None,
async_update: dict[Literal["type", "max_shape", "index"], Any] | None = None, async_update: dict[Literal["type", "max_shape", "index"], Any] | None = None,
acquisition_group: str | None = None, acquisition_group: str | None = None,