fix: device sim params can be set through init

This commit is contained in:
appel_c 2024-07-03 08:22:31 +02:00
parent 69105332a4
commit f481c1f812
8 changed files with 141 additions and 114 deletions

View File

@ -63,7 +63,7 @@ class SimCameraSetup(CustomDetectorMixin):
self.parent.h5_writer.prepare(
file_path=self.parent.filepath.get(), h5_entry="/entry/data/data"
)
self.publish_file_location(done=False)
self.publish_file_location(done=False, successful=False)
self.parent.stopped = False
def on_unstage(self) -> None:
@ -122,13 +122,15 @@ class SimCamera(PSIDetectorBase):
def __init__(
self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs
):
self.init_sim_params = sim_init
self.sim_init = sim_init
self._registered_proxies = {}
self.sim = self.sim_cls(parent=self, **kwargs)
self.h5_writer = H5Writer()
super().__init__(
name=name, parent=parent, kind=kind, device_manager=device_manager, **kwargs
)
if self.sim_init:
self.sim.set_init(self.sim_init)
@property
def registered_proxies(self) -> None:

View File

@ -82,13 +82,13 @@ class SimulatedDataBase(ABC):
The class provides the following methods:
- execute_simulation_method: execute a method from the simulated data class or reroute execution to device proxy class
- sim_select_model: select the active simulation model
- sim_params: get the parameters for the active simulation mdoel
- select_model: select the active simulation model
- params: get the parameters for the active simulation mdoel
- sim_models: get the available simulation models
- update_sim_state: update the simulated state of the device
"""
USER_ACCESS = ["sim_params", "sim_select_model", "sim_get_models", "sim_show_all"]
USER_ACCESS = ["params", "select_model", "get_models", "show_all"]
def __init__(self, *args, parent=None, **kwargs) -> None:
"""
@ -122,7 +122,12 @@ class SimulatedDataBase(ABC):
return method(*args, **kwargs)
raise SimulatedDataException(f"Method {method} is not available for {self.parent.name}")
# TODO remove after refactoring code in main
def sim_select_model(self, model: str) -> None:
"""Select the active simulation model."""
self.select_model(model)
def select_model(self, model: str) -> None:
"""
Method to select the active simulation model.
It will initiate the model_cls and parameters for the model.
@ -138,7 +143,7 @@ class SimulatedDataBase(ABC):
print(self._get_table_active_simulation())
@property
def sim_params(self) -> dict:
def params(self) -> dict:
"""
Property that returns the parameters for the active simulation model. It can also
be used to set the parameters for the active simulation updating the parameters of the model.
@ -147,17 +152,17 @@ class SimulatedDataBase(ABC):
dict: Parameters for the active simulation model.
The following example shows how to update the noise parameter of the current simulation.
>>> dev.<device>.sim.sim_params = {"noise": "poisson"}
>>> dev.<device>.sim.params = {"noise": "poisson"}
"""
return self._params
@sim_params.setter
def sim_params(self, params: dict):
@params.setter
def params(self, params: dict):
"""
Method to set the parameters for the active simulation model.
"""
for k, v in params.items():
if k in self.sim_params:
if k in self.params:
if k == "noise":
self._params[k] = NoiseType(v)
elif k == "hot_pixel_types":
@ -167,9 +172,9 @@ class SimulatedDataBase(ABC):
if isinstance(self._model, Model) and k in self._model_params:
self._model_params[k].value = v
else:
raise SimulatedDataException(f"Parameter {k} not found in {self.sim_params}.")
raise SimulatedDataException(f"Parameter {k} not found in {self.params}.")
def sim_get_models(self) -> list:
def get_models(self) -> list:
"""
Method to get the all available simulation models.
"""
@ -221,7 +226,7 @@ class SimulatedDataBase(ABC):
table = PrettyTable()
table.title = f"Currently active model: {self._model}"
table.field_names = ["Parameter", "Value", "Type"]
for k, v in self.sim_params.items():
for k, v in self.params.items():
table.add_row([k, f"{v}", f"{type(v)}"])
table._min_width["Parameter"] = 25 if width > 75 else width // 3
table._min_width["Type"] = 25 if width > 75 else width // 3
@ -238,16 +243,16 @@ class SimulatedDataBase(ABC):
table.title = "Available methods within the simulation module"
table.field_names = ["Method", "Docstring"]
table.add_row([self.sim_get_models.__name__, f"{self.sim_get_models.__doc__}"])
table.add_row([self.sim_select_model.__name__, self.sim_select_model.__doc__])
table.add_row(["sim_params", self.__class__.sim_params.__doc__])
table.add_row([self.get_models.__name__, f"{self.get_models.__doc__}"])
table.add_row([self.select_model.__name__, self.select_model.__doc__])
table.add_row(["params", self.__class__.params.__doc__])
table.max_table_width = width
table._min_table_width = width
table.align["Docstring"] = "l"
return table
def sim_show_all(self):
def show_all(self):
"""Returns a summary about the active simulation and available methods."""
width = 150
print(self._get_table_active_simulation(width=width))
@ -260,6 +265,15 @@ class SimulatedDataBase(ABC):
table._min_table_width = width
print(table)
def set_init(self, sim_init: dict["model", "params"]) -> None:
"""Set the initial simulation parameters.
Args:
sim_init (dict["model"]): Dictionary to initiate parameters of the simulation.
"""
self.select_model(sim_init.get("model"))
self.params = sim_init.get("params", {})
class SimulatedPositioner(SimulatedDataBase):
"""Simulated data class for a positioner."""
@ -317,7 +331,7 @@ class SimulatedDataMonitor(SimulatedDataBase):
def _init_default(self) -> None:
"""Initialize the default parameters for the simulated data."""
self.sim_select_model("ConstantModel")
self.select_model("ConstantModel")
def get_model_cls(self, model: str) -> any:
"""Get the class for the active simulation model."""
@ -411,14 +425,14 @@ class SimulatedDataMonitor(SimulatedDataBase):
Returns:
float: Value computed by the active model.
"""
mot_name = self.sim_params["ref_motor"]
mot_name = self.params["ref_motor"]
if self.parent.device_manager and mot_name in self.parent.device_manager.devices:
motor_pos = self.parent.device_manager.devices[mot_name].obj.read()[mot_name]["value"]
else:
motor_pos = 0
method = self._model
value = int(method.eval(params=self._model_params, x=motor_pos))
return self._add_noise(value, self.sim_params["noise"], self.sim_params["noise_multiplier"])
return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])
def _add_noise(self, v: int, noise: NoiseType, noise_multiplier: float) -> int:
"""
@ -478,8 +492,8 @@ class SimulatedDataWaveform(SimulatedDataMonitor):
size = size[0] if isinstance(size, tuple) else size
method = self._model
value = method.eval(params=self._model_params, x=np.array(range(size)))
value *= self.sim_params["amplitude"] / np.max(value)
return self._add_noise(value, self.sim_params["noise"], self.sim_params["noise_multiplier"])
value *= self.params["amplitude"] / np.max(value)
return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])
def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray:
"""Add noise to the simulated data.
@ -515,7 +529,7 @@ class SimulatedDataCamera(SimulatedDataBase):
Use the default model "Gaussian".
"""
self.sim_select_model(SimulationType2D.GAUSSIAN)
self.select_model(SimulationType2D.GAUSSIAN)
def init_2D_models(self) -> dict:
"""
@ -609,13 +623,13 @@ class SimulatedDataCamera(SimulatedDataBase):
"""Compute a return value for SimulationType2D constant."""
try:
shape = self.parent.image_shape.get()
v = self.sim_params.get("amplitude") * np.ones(shape, dtype=np.float32)
v = self._add_noise(v, self.sim_params["noise"], self.sim_params["noise_multiplier"])
v = self.params.get("amplitude") * np.ones(shape, dtype=np.float32)
v = self._add_noise(v, self.params["noise"], self.params["noise_multiplier"])
return self._add_hot_pixel(
v,
coords=self.sim_params["hot_pixel_coords"],
hot_pixel_types=self.sim_params["hot_pixel_types"],
values=self.sim_params["hot_pixel_values"],
coords=self.params["hot_pixel_coords"],
hot_pixel_types=self.params["hot_pixel_types"],
values=self.params["hot_pixel_values"],
)
except SimulatedDataException as exc:
raise SimulatedDataException(
@ -634,9 +648,9 @@ class SimulatedDataCamera(SimulatedDataBase):
"""
try:
amp = self.sim_params.get("amplitude")
cov = self.sim_params.get("covariance")
cen_off = self.sim_params.get("center_offset")
amp = self.params.get("amplitude")
cov = self.params.get("covariance")
cen_off = self.params.get("center_offset")
shape = self.sim_state[self.parent.image_shape.name]["value"]
pos, offset, cov, amp = self._prepare_params_gauss(
amp=amp, cov=cov, offset=cen_off, shape=shape
@ -644,15 +658,13 @@ class SimulatedDataCamera(SimulatedDataBase):
v = self._compute_multivariate_gaussian(pos=pos, cen_off=offset, cov=cov, amp=amp)
v = self._add_noise(
v,
noise=self.sim_params["noise"],
noise_multiplier=self.sim_params["noise_multiplier"],
v, noise=self.params["noise"], noise_multiplier=self.params["noise_multiplier"]
)
return self._add_hot_pixel(
v,
coords=self.sim_params["hot_pixel_coords"],
hot_pixel_types=self.sim_params["hot_pixel_types"],
values=self.sim_params["hot_pixel_values"],
coords=self.params["hot_pixel_coords"],
hot_pixel_types=self.params["hot_pixel_types"],
values=self.params["hot_pixel_values"],
)
except SimulatedDataException as exc:
raise SimulatedDataException(

View File

@ -48,6 +48,7 @@ class SimFlyer(Device, FlyerInterface):
parent=None,
kind=None,
device_manager=None,
sim_init: dict = None,
# TODO remove after refactoring config
delay: int = 1,
update_frequency: int = 100,
@ -55,6 +56,7 @@ class SimFlyer(Device, FlyerInterface):
):
self.sim = self.sim_cls(parent=self, **kwargs)
self.sim_init = sim_init
self.precision = precision
self.device_manager = device_manager
self._registered_proxies = {}
@ -62,6 +64,8 @@ class SimFlyer(Device, FlyerInterface):
super().__init__(name=name, parent=parent, kind=kind, **kwargs)
self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None)
self.readback.name = self.name
if self.sim_init:
self.sim.set_init(self.sim_init)
@property
def registered_proxies(self) -> None:

View File

@ -2,6 +2,9 @@ from abc import ABC, abstractmethod
from collections import defaultdict
import h5py
# Necessary import to allow h5py to open compressed h5files.
# pylint: disable=unused-import
import hdf5plugin # noqa: F401
import numpy as np
from ophyd import Kind, Staged
@ -87,20 +90,7 @@ class SlitProxy(DeviceProxy):
To update for instance the pixel_size directly, you can directly access the DeviceConfig via
`dev.eiger.get_device_config()` or update it `dev.eiger.get_device_config({'eiger' : {'pixel_size': 0.1}})`
slit_sim:
readoutPriority: baseline
deviceClass: SlitProxy
deviceConfig:
eiger:
signal_name: image
center_offset: [0, 0] # [x,y]
covariance: [[1000, 500], [200, 1000]] # [[x,x],[y,y]]
pixel_size: 0.01
ref_motors: [samx, samy]
slit_width: [1, 1]
motor_dir: [0, 1] # x:0 , y:1, z:2 coordinates
enabled: true
readOnly: false
An example for the configuration of this is device is in ophyd_devices.configs.ophyd_devices_simulation.yaml
"""
USER_ACCESS = ["enabled", "lookup", "help"]
@ -126,7 +116,7 @@ class SlitProxy(DeviceProxy):
np.ndarray: Lookup table for the simulated camera.
"""
device_obj = self.device_manager.devices.get(device_name).obj
params = device_obj.sim.sim_params
params = device_obj.sim.params
shape = device_obj.image_shape.get()
params.update(
{
@ -197,18 +187,10 @@ class SlitProxy(DeviceProxy):
class H5ImageReplayProxy(DeviceProxy):
"""This Proxy class can be used to replay images from an h5 file.
If the number of requested images is larger than the number of available iamges, the images will be replayed from the beginning.
If the number of requested images is larger than the number of available iamges,
the images will be replayed from the beginning.
h5_image_sim:
readoutPriority: baseline
deviceClass: H5ImageReplayProxy
deviceConfig:
eiger:
signal_name: image
file_source: /path/to/h5file.h5
h5_entry: /entry/data
enabled: true
readOnly: false
An example for the configuration of this is device is in ophyd_devices.configs.ophyd_devices_simulation.yaml
"""
USER_ACCESS = ["file_source", "h5_entry"]

View File

@ -1,12 +1,9 @@
from typing import Literal
import numpy as np
from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger
from ophyd import Component as Cpt
from ophyd import Device, Kind
from typeguard import typechecked
from ophyd_devices.interfaces.base_classes.psi_detector_base import (
CustomDetectorMixin,
@ -60,7 +57,7 @@ class SimMonitor(Device):
**kwargs,
):
self.precision = precision
self.init_sim_params = sim_init
self.sim_init = sim_init
self.device_manager = device_manager
self.sim = self.sim_cls(parent=self, **kwargs)
self._registered_proxies = {}
@ -68,6 +65,8 @@ class SimMonitor(Device):
super().__init__(name=name, parent=parent, kind=kind, **kwargs)
self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None)
self.readback.name = self.name
if self.sim_init:
self.sim.set_init(self.sim_init)
@property
def registered_proxies(self) -> None:
@ -112,7 +111,7 @@ class SimMonitorAsyncPrepare(CustomDetectorMixin):
if self.parent.scaninfo.scan_msg is None:
return
metadata = self.parent.scaninfo.scan_msg.metadata
metadata.update({"async_update": self.parent.async_update})
metadata.update({"async_update": self.parent.async_update.get()})
msg = messages.DeviceMessage(
signals={self.parent.readback.name: self.parent.data_buffer},
@ -163,6 +162,7 @@ class SimMonitorAsync(PSIDetectorBase):
readback = Cpt(ReadOnlySignal, value=BIT_DEPTH(0), kind=Kind.hinted, compute_readback=True)
current_trigger = Cpt(SetableSignal, value=BIT_DEPTH(0), kind=Kind.config)
async_update = Cpt(SetableSignal, value="extend", kind=Kind.config)
SUB_READBACK = "readback"
SUB_PROGRESS = "progress"
@ -171,7 +171,7 @@ class SimMonitorAsync(PSIDetectorBase):
def __init__(
self, name, *, sim_init: dict = None, parent=None, kind=None, device_manager=None, **kwargs
):
self.init_sim_params = sim_init
self.sim_init = sim_init
self.device_manager = device_manager
self.sim = self.sim_cls(parent=self, **kwargs)
self._registered_proxies = {}
@ -182,7 +182,8 @@ class SimMonitorAsync(PSIDetectorBase):
self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None)
self.readback.name = self.name
self._data_buffer = {"value": [], "timestamp": []}
self._async_update = "extend"
if self.sim_init:
self.sim.set_init(self.sim_init)
@property
def data_buffer(self) -> list:
@ -193,18 +194,3 @@ class SimMonitorAsync(PSIDetectorBase):
def registered_proxies(self) -> None:
"""Dictionary of registered signal_names and proxies."""
return self._registered_proxies
@property
def async_update(self) -> str:
"""Update method for the asynchronous monitor."""
return self._async_update
@async_update.setter
@typechecked
def async_update(self, value: Literal["extend", "append"]) -> None:
"""Set the update method for the asynchronous monitor.
Args:
value (str): Can only be "extend" or "append".
"""
self._async_update = value

View File

@ -78,7 +78,7 @@ class SimPositioner(Device, PositionerBase):
self.delay = delay
self.device_manager = device_manager
self.precision = precision
self.init_sim_params = sim_init
self.sim_init = sim_init
self._registered_proxies = {}
self.update_frequency = update_frequency
@ -94,11 +94,8 @@ class SimPositioner(Device, PositionerBase):
assert len(limits) == 2
self.low_limit_travel.put(limits[0])
self.high_limit_travel.put(limits[1])
# @property
# def connected(self):
# """Return the connected state of the simulated device."""
# return self.dummy_controller.connected
if self.sim_init:
self.sim.set_init(self.sim_init)
@property
def limits(self):

View File

@ -7,6 +7,7 @@ from ophyd import Component as Cpt
from ophyd import Device, DeviceStatus, Kind
from ophyd_devices.sim.sim_data import SimulatedDataWaveform
from ophyd_devices.sim.sim_exception import DeviceStop
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin
@ -60,7 +61,7 @@ class SimWaveform(Device):
self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs
):
self.device_manager = device_manager
self.init_sim_params = sim_init
self.sim_init = sim_init
self._registered_proxies = {}
self.sim = self.sim_cls(parent=self, **kwargs)
@ -69,6 +70,8 @@ class SimWaveform(Device):
self._staged = False
self.scaninfo = None
self._update_scaninfo()
if self.sim_init:
self.sim.set_init(self.sim_init)
@property
def registered_proxies(self) -> None:

View File

@ -96,6 +96,46 @@ def flyer(name="flyer"):
yield fly
def test_camera_with_sim_init():
"""Test to see if the sim init parameters are passed to the device"""
dm = DMMock()
sim = SimCamera(name="sim", device_manager=dm)
assert sim.sim._model.value == "gaussian"
model = "constant"
params = {
"amplitude": 300,
"noise": "uniform",
"noise_multiplier": 1,
"hot_pixel_coords": [[0, 0], [50, 50]],
"hot_pixel_types": ["fluctuating", "constant"],
"hot_pixel_values": [2.0, 2.0],
}
sim = SimCamera(name="sim", device_manager=dm, sim_init={"model": model, "params": params})
assert sim.sim._model.value == model
assert sim.sim.params == params
def test_monitor_with_sim_init():
"""Test to see if the sim init parameters are passed to the device"""
dm = DMMock()
sim = SimMonitor(name="sim", device_manager=dm)
assert sim.sim._model._name == "constant"
model = "GaussianModel"
params = {
"amplitude": 500,
"center": 5,
"sigma": 4,
"noise": "uniform",
"noise_multiplier": 1,
"ref_motor": "samy",
}
sim = SimMonitor(name="sim", device_manager=dm, sim_init={"model": model, "params": params})
assert sim.sim._model._name == model.strip("Model").lower()
diff_keys = set(sim.sim.params.keys()) - set(params.keys())
for k in params:
assert sim.sim.params[k] == params[k]
def test_signal__init__(signal):
"""Test the BECProtocol class"""
assert isinstance(signal, BECDeviceProtocol)
@ -144,41 +184,41 @@ def test_monitor_readback(monitor, center):
"""Test the readback method of SimMonitor."""
motor_pos = 0
monitor.device_manager.add_device("samx", value=motor_pos)
for model_name in monitor.sim.sim_get_models():
monitor.sim.sim_select_model(model_name)
monitor.sim.sim_params["noise_multipler"] = 10
monitor.sim.sim_params["ref_motor"] = "samx"
if "c" in monitor.sim.sim_params:
monitor.sim.sim_params["c"] = center
elif "center" in monitor.sim.sim_params:
monitor.sim.sim_params["center"] = center
for model_name in monitor.sim.get_models():
monitor.sim.select_model(model_name)
monitor.sim.params["noise_multipler"] = 10
monitor.sim.params["ref_motor"] = "samx"
if "c" in monitor.sim.params:
monitor.sim.params["c"] = center
elif "center" in monitor.sim.params:
monitor.sim.params["center"] = center
assert isinstance(monitor.read()[monitor.name]["value"], monitor.BIT_DEPTH)
expected_value = monitor.sim._model.eval(monitor.sim._model_params, x=motor_pos)
print(expected_value, monitor.read()[monitor.name]["value"])
tolerance = (
monitor.sim.sim_params["noise_multipler"] + 1
monitor.sim.params["noise_multipler"] + 1
) # due to ceiling in calculation, but maximum +1int
assert np.isclose(
monitor.read()[monitor.name]["value"],
expected_value,
atol=monitor.sim.sim_params["noise_multipler"] + 1,
atol=monitor.sim.params["noise_multipler"] + 1,
)
@pytest.mark.parametrize("amplitude, noise_multiplier", [(0, 1), (100, 10), (1000, 50)])
def test_camera_readback(camera, amplitude, noise_multiplier):
"""Test the readback method of SimMonitor."""
for model_name in camera.sim.sim_get_models():
camera.sim.sim_select_model(model_name)
camera.sim.sim_params = {"noise_multiplier": noise_multiplier}
camera.sim.sim_params = {"amplitude": amplitude}
camera.sim.sim_params = {"noise": "poisson"}
for model_name in camera.sim.get_models():
camera.sim.select_model(model_name)
camera.sim.params = {"noise_multiplier": noise_multiplier}
camera.sim.params = {"amplitude": amplitude}
camera.sim.params = {"noise": "poisson"}
assert camera.image.get().shape == camera.SHAPE
assert isinstance(camera.image.get()[0, 0], camera.BIT_DEPTH)
camera.sim.sim_params = {"noise": "uniform"}
camera.sim.sim_params = {"hot_pixel_coords": []}
camera.sim.sim_params = {"hot_pixel_values": []}
camera.sim.sim_params = {"hot_pixel_types": []}
camera.sim.params = {"noise": "uniform"}
camera.sim.params = {"hot_pixel_coords": []}
camera.sim.params = {"hot_pixel_values": []}
camera.sim.params = {"hot_pixel_types": []}
assert camera.image.get().shape == camera.SHAPE
assert isinstance(camera.image.get()[0, 0], camera.BIT_DEPTH)
assert (camera.image.get() <= (amplitude + noise_multiplier + 1)).all()
@ -243,8 +283,9 @@ def test_h5proxy(h5proxy_fixture, camera):
{camera.name: {"signal_name": "image", "file_source": fname, "h5_entry": h5entry}}
)
camera._registered_proxies.update({h5proxy.name: camera.image.name})
camera.sim.sim_params = {"noise": "none", "noise_multiplier": 0}
camera.sim.params = {"noise": "none", "noise_multiplier": 0}
camera.scaninfo.sim_mode = True
# pylint: disable=no-member
camera.image_shape.set(data.shape[1:])
camera.stage()
img = camera.image.get()
@ -288,7 +329,7 @@ def test_slitproxy(slitproxy_fixture):
mock_camera.obj = camera
mock_samx.obj = samx
mock_proxy.obj = proxy
camera.sim.sim_params = {"noise": "none", "noise_multiplier": 0, "hot_pixel_values": [0, 0, 0]}
camera.sim.params = {"noise": "none", "noise_multiplier": 0, "hot_pixel_values": [0, 0, 0]}
samx.delay = 0
samx_pos = 0
samx.move(samx_pos)
@ -421,7 +462,7 @@ def test_async_mon_send_data_to_bec(async_monitor):
async_monitor.custom_prepare._send_data_to_bec()
dev_msg = messages.DeviceMessage(
signals={async_monitor.readback.name: async_monitor.data_buffer},
metadata={"async_update": async_monitor.async_update},
metadata={"async_update": async_monitor.async_update.get()},
)
call = [