fix: device sim params can be set through init

This commit is contained in:
appel_c 2024-07-03 08:22:31 +02:00
parent 69105332a4
commit f481c1f812
8 changed files with 141 additions and 114 deletions

View File

@ -63,7 +63,7 @@ class SimCameraSetup(CustomDetectorMixin):
self.parent.h5_writer.prepare( self.parent.h5_writer.prepare(
file_path=self.parent.filepath.get(), h5_entry="/entry/data/data" file_path=self.parent.filepath.get(), h5_entry="/entry/data/data"
) )
self.publish_file_location(done=False) self.publish_file_location(done=False, successful=False)
self.parent.stopped = False self.parent.stopped = False
def on_unstage(self) -> None: def on_unstage(self) -> None:
@ -122,13 +122,15 @@ class SimCamera(PSIDetectorBase):
def __init__( def __init__(
self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs
): ):
self.init_sim_params = sim_init self.sim_init = sim_init
self._registered_proxies = {} self._registered_proxies = {}
self.sim = self.sim_cls(parent=self, **kwargs) self.sim = self.sim_cls(parent=self, **kwargs)
self.h5_writer = H5Writer() self.h5_writer = H5Writer()
super().__init__( super().__init__(
name=name, parent=parent, kind=kind, device_manager=device_manager, **kwargs name=name, parent=parent, kind=kind, device_manager=device_manager, **kwargs
) )
if self.sim_init:
self.sim.set_init(self.sim_init)
@property @property
def registered_proxies(self) -> None: def registered_proxies(self) -> None:

View File

@ -82,13 +82,13 @@ class SimulatedDataBase(ABC):
The class provides the following methods: The class provides the following methods:
- execute_simulation_method: execute a method from the simulated data class or reroute execution to device proxy class - execute_simulation_method: execute a method from the simulated data class or reroute execution to device proxy class
- sim_select_model: select the active simulation model - select_model: select the active simulation model
- sim_params: get the parameters for the active simulation mdoel - params: get the parameters for the active simulation mdoel
- sim_models: get the available simulation models - sim_models: get the available simulation models
- update_sim_state: update the simulated state of the device - update_sim_state: update the simulated state of the device
""" """
USER_ACCESS = ["sim_params", "sim_select_model", "sim_get_models", "sim_show_all"] USER_ACCESS = ["params", "select_model", "get_models", "show_all"]
def __init__(self, *args, parent=None, **kwargs) -> None: def __init__(self, *args, parent=None, **kwargs) -> None:
""" """
@ -122,7 +122,12 @@ class SimulatedDataBase(ABC):
return method(*args, **kwargs) return method(*args, **kwargs)
raise SimulatedDataException(f"Method {method} is not available for {self.parent.name}") raise SimulatedDataException(f"Method {method} is not available for {self.parent.name}")
# TODO remove after refactoring code in main
def sim_select_model(self, model: str) -> None: def sim_select_model(self, model: str) -> None:
"""Select the active simulation model."""
self.select_model(model)
def select_model(self, model: str) -> None:
""" """
Method to select the active simulation model. Method to select the active simulation model.
It will initiate the model_cls and parameters for the model. It will initiate the model_cls and parameters for the model.
@ -138,7 +143,7 @@ class SimulatedDataBase(ABC):
print(self._get_table_active_simulation()) print(self._get_table_active_simulation())
@property @property
def sim_params(self) -> dict: def params(self) -> dict:
""" """
Property that returns the parameters for the active simulation model. It can also Property that returns the parameters for the active simulation model. It can also
be used to set the parameters for the active simulation updating the parameters of the model. be used to set the parameters for the active simulation updating the parameters of the model.
@ -147,17 +152,17 @@ class SimulatedDataBase(ABC):
dict: Parameters for the active simulation model. dict: Parameters for the active simulation model.
The following example shows how to update the noise parameter of the current simulation. The following example shows how to update the noise parameter of the current simulation.
>>> dev.<device>.sim.sim_params = {"noise": "poisson"} >>> dev.<device>.sim.params = {"noise": "poisson"}
""" """
return self._params return self._params
@sim_params.setter @params.setter
def sim_params(self, params: dict): def params(self, params: dict):
""" """
Method to set the parameters for the active simulation model. Method to set the parameters for the active simulation model.
""" """
for k, v in params.items(): for k, v in params.items():
if k in self.sim_params: if k in self.params:
if k == "noise": if k == "noise":
self._params[k] = NoiseType(v) self._params[k] = NoiseType(v)
elif k == "hot_pixel_types": elif k == "hot_pixel_types":
@ -167,9 +172,9 @@ class SimulatedDataBase(ABC):
if isinstance(self._model, Model) and k in self._model_params: if isinstance(self._model, Model) and k in self._model_params:
self._model_params[k].value = v self._model_params[k].value = v
else: else:
raise SimulatedDataException(f"Parameter {k} not found in {self.sim_params}.") raise SimulatedDataException(f"Parameter {k} not found in {self.params}.")
def sim_get_models(self) -> list: def get_models(self) -> list:
""" """
Method to get the all available simulation models. Method to get the all available simulation models.
""" """
@ -221,7 +226,7 @@ class SimulatedDataBase(ABC):
table = PrettyTable() table = PrettyTable()
table.title = f"Currently active model: {self._model}" table.title = f"Currently active model: {self._model}"
table.field_names = ["Parameter", "Value", "Type"] table.field_names = ["Parameter", "Value", "Type"]
for k, v in self.sim_params.items(): for k, v in self.params.items():
table.add_row([k, f"{v}", f"{type(v)}"]) table.add_row([k, f"{v}", f"{type(v)}"])
table._min_width["Parameter"] = 25 if width > 75 else width // 3 table._min_width["Parameter"] = 25 if width > 75 else width // 3
table._min_width["Type"] = 25 if width > 75 else width // 3 table._min_width["Type"] = 25 if width > 75 else width // 3
@ -238,16 +243,16 @@ class SimulatedDataBase(ABC):
table.title = "Available methods within the simulation module" table.title = "Available methods within the simulation module"
table.field_names = ["Method", "Docstring"] table.field_names = ["Method", "Docstring"]
table.add_row([self.sim_get_models.__name__, f"{self.sim_get_models.__doc__}"]) table.add_row([self.get_models.__name__, f"{self.get_models.__doc__}"])
table.add_row([self.sim_select_model.__name__, self.sim_select_model.__doc__]) table.add_row([self.select_model.__name__, self.select_model.__doc__])
table.add_row(["sim_params", self.__class__.sim_params.__doc__]) table.add_row(["params", self.__class__.params.__doc__])
table.max_table_width = width table.max_table_width = width
table._min_table_width = width table._min_table_width = width
table.align["Docstring"] = "l" table.align["Docstring"] = "l"
return table return table
def sim_show_all(self): def show_all(self):
"""Returns a summary about the active simulation and available methods.""" """Returns a summary about the active simulation and available methods."""
width = 150 width = 150
print(self._get_table_active_simulation(width=width)) print(self._get_table_active_simulation(width=width))
@ -260,6 +265,15 @@ class SimulatedDataBase(ABC):
table._min_table_width = width table._min_table_width = width
print(table) print(table)
def set_init(self, sim_init: dict["model", "params"]) -> None:
"""Set the initial simulation parameters.
Args:
sim_init (dict["model"]): Dictionary to initiate parameters of the simulation.
"""
self.select_model(sim_init.get("model"))
self.params = sim_init.get("params", {})
class SimulatedPositioner(SimulatedDataBase): class SimulatedPositioner(SimulatedDataBase):
"""Simulated data class for a positioner.""" """Simulated data class for a positioner."""
@ -317,7 +331,7 @@ class SimulatedDataMonitor(SimulatedDataBase):
def _init_default(self) -> None: def _init_default(self) -> None:
"""Initialize the default parameters for the simulated data.""" """Initialize the default parameters for the simulated data."""
self.sim_select_model("ConstantModel") self.select_model("ConstantModel")
def get_model_cls(self, model: str) -> any: def get_model_cls(self, model: str) -> any:
"""Get the class for the active simulation model.""" """Get the class for the active simulation model."""
@ -411,14 +425,14 @@ class SimulatedDataMonitor(SimulatedDataBase):
Returns: Returns:
float: Value computed by the active model. float: Value computed by the active model.
""" """
mot_name = self.sim_params["ref_motor"] mot_name = self.params["ref_motor"]
if self.parent.device_manager and mot_name in self.parent.device_manager.devices: if self.parent.device_manager and mot_name in self.parent.device_manager.devices:
motor_pos = self.parent.device_manager.devices[mot_name].obj.read()[mot_name]["value"] motor_pos = self.parent.device_manager.devices[mot_name].obj.read()[mot_name]["value"]
else: else:
motor_pos = 0 motor_pos = 0
method = self._model method = self._model
value = int(method.eval(params=self._model_params, x=motor_pos)) value = int(method.eval(params=self._model_params, x=motor_pos))
return self._add_noise(value, self.sim_params["noise"], self.sim_params["noise_multiplier"]) return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])
def _add_noise(self, v: int, noise: NoiseType, noise_multiplier: float) -> int: def _add_noise(self, v: int, noise: NoiseType, noise_multiplier: float) -> int:
""" """
@ -478,8 +492,8 @@ class SimulatedDataWaveform(SimulatedDataMonitor):
size = size[0] if isinstance(size, tuple) else size size = size[0] if isinstance(size, tuple) else size
method = self._model method = self._model
value = method.eval(params=self._model_params, x=np.array(range(size))) value = method.eval(params=self._model_params, x=np.array(range(size)))
value *= self.sim_params["amplitude"] / np.max(value) value *= self.params["amplitude"] / np.max(value)
return self._add_noise(value, self.sim_params["noise"], self.sim_params["noise_multiplier"]) return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])
def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray: def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray:
"""Add noise to the simulated data. """Add noise to the simulated data.
@ -515,7 +529,7 @@ class SimulatedDataCamera(SimulatedDataBase):
Use the default model "Gaussian". Use the default model "Gaussian".
""" """
self.sim_select_model(SimulationType2D.GAUSSIAN) self.select_model(SimulationType2D.GAUSSIAN)
def init_2D_models(self) -> dict: def init_2D_models(self) -> dict:
""" """
@ -609,13 +623,13 @@ class SimulatedDataCamera(SimulatedDataBase):
"""Compute a return value for SimulationType2D constant.""" """Compute a return value for SimulationType2D constant."""
try: try:
shape = self.parent.image_shape.get() shape = self.parent.image_shape.get()
v = self.sim_params.get("amplitude") * np.ones(shape, dtype=np.float32) v = self.params.get("amplitude") * np.ones(shape, dtype=np.float32)
v = self._add_noise(v, self.sim_params["noise"], self.sim_params["noise_multiplier"]) v = self._add_noise(v, self.params["noise"], self.params["noise_multiplier"])
return self._add_hot_pixel( return self._add_hot_pixel(
v, v,
coords=self.sim_params["hot_pixel_coords"], coords=self.params["hot_pixel_coords"],
hot_pixel_types=self.sim_params["hot_pixel_types"], hot_pixel_types=self.params["hot_pixel_types"],
values=self.sim_params["hot_pixel_values"], values=self.params["hot_pixel_values"],
) )
except SimulatedDataException as exc: except SimulatedDataException as exc:
raise SimulatedDataException( raise SimulatedDataException(
@ -634,9 +648,9 @@ class SimulatedDataCamera(SimulatedDataBase):
""" """
try: try:
amp = self.sim_params.get("amplitude") amp = self.params.get("amplitude")
cov = self.sim_params.get("covariance") cov = self.params.get("covariance")
cen_off = self.sim_params.get("center_offset") cen_off = self.params.get("center_offset")
shape = self.sim_state[self.parent.image_shape.name]["value"] shape = self.sim_state[self.parent.image_shape.name]["value"]
pos, offset, cov, amp = self._prepare_params_gauss( pos, offset, cov, amp = self._prepare_params_gauss(
amp=amp, cov=cov, offset=cen_off, shape=shape amp=amp, cov=cov, offset=cen_off, shape=shape
@ -644,15 +658,13 @@ class SimulatedDataCamera(SimulatedDataBase):
v = self._compute_multivariate_gaussian(pos=pos, cen_off=offset, cov=cov, amp=amp) v = self._compute_multivariate_gaussian(pos=pos, cen_off=offset, cov=cov, amp=amp)
v = self._add_noise( v = self._add_noise(
v, v, noise=self.params["noise"], noise_multiplier=self.params["noise_multiplier"]
noise=self.sim_params["noise"],
noise_multiplier=self.sim_params["noise_multiplier"],
) )
return self._add_hot_pixel( return self._add_hot_pixel(
v, v,
coords=self.sim_params["hot_pixel_coords"], coords=self.params["hot_pixel_coords"],
hot_pixel_types=self.sim_params["hot_pixel_types"], hot_pixel_types=self.params["hot_pixel_types"],
values=self.sim_params["hot_pixel_values"], values=self.params["hot_pixel_values"],
) )
except SimulatedDataException as exc: except SimulatedDataException as exc:
raise SimulatedDataException( raise SimulatedDataException(

View File

@ -48,6 +48,7 @@ class SimFlyer(Device, FlyerInterface):
parent=None, parent=None,
kind=None, kind=None,
device_manager=None, device_manager=None,
sim_init: dict = None,
# TODO remove after refactoring config # TODO remove after refactoring config
delay: int = 1, delay: int = 1,
update_frequency: int = 100, update_frequency: int = 100,
@ -55,6 +56,7 @@ class SimFlyer(Device, FlyerInterface):
): ):
self.sim = self.sim_cls(parent=self, **kwargs) self.sim = self.sim_cls(parent=self, **kwargs)
self.sim_init = sim_init
self.precision = precision self.precision = precision
self.device_manager = device_manager self.device_manager = device_manager
self._registered_proxies = {} self._registered_proxies = {}
@ -62,6 +64,8 @@ class SimFlyer(Device, FlyerInterface):
super().__init__(name=name, parent=parent, kind=kind, **kwargs) super().__init__(name=name, parent=parent, kind=kind, **kwargs)
self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None) self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None)
self.readback.name = self.name self.readback.name = self.name
if self.sim_init:
self.sim.set_init(self.sim_init)
@property @property
def registered_proxies(self) -> None: def registered_proxies(self) -> None:

View File

@ -2,6 +2,9 @@ from abc import ABC, abstractmethod
from collections import defaultdict from collections import defaultdict
import h5py import h5py
# Necessary import to allow h5py to open compressed h5files.
# pylint: disable=unused-import
import hdf5plugin # noqa: F401 import hdf5plugin # noqa: F401
import numpy as np import numpy as np
from ophyd import Kind, Staged from ophyd import Kind, Staged
@ -87,20 +90,7 @@ class SlitProxy(DeviceProxy):
To update for instance the pixel_size directly, you can directly access the DeviceConfig via To update for instance the pixel_size directly, you can directly access the DeviceConfig via
`dev.eiger.get_device_config()` or update it `dev.eiger.get_device_config({'eiger' : {'pixel_size': 0.1}})` `dev.eiger.get_device_config()` or update it `dev.eiger.get_device_config({'eiger' : {'pixel_size': 0.1}})`
slit_sim: An example for the configuration of this is device is in ophyd_devices.configs.ophyd_devices_simulation.yaml
readoutPriority: baseline
deviceClass: SlitProxy
deviceConfig:
eiger:
signal_name: image
center_offset: [0, 0] # [x,y]
covariance: [[1000, 500], [200, 1000]] # [[x,x],[y,y]]
pixel_size: 0.01
ref_motors: [samx, samy]
slit_width: [1, 1]
motor_dir: [0, 1] # x:0 , y:1, z:2 coordinates
enabled: true
readOnly: false
""" """
USER_ACCESS = ["enabled", "lookup", "help"] USER_ACCESS = ["enabled", "lookup", "help"]
@ -126,7 +116,7 @@ class SlitProxy(DeviceProxy):
np.ndarray: Lookup table for the simulated camera. np.ndarray: Lookup table for the simulated camera.
""" """
device_obj = self.device_manager.devices.get(device_name).obj device_obj = self.device_manager.devices.get(device_name).obj
params = device_obj.sim.sim_params params = device_obj.sim.params
shape = device_obj.image_shape.get() shape = device_obj.image_shape.get()
params.update( params.update(
{ {
@ -197,18 +187,10 @@ class SlitProxy(DeviceProxy):
class H5ImageReplayProxy(DeviceProxy): class H5ImageReplayProxy(DeviceProxy):
"""This Proxy class can be used to replay images from an h5 file. """This Proxy class can be used to replay images from an h5 file.
If the number of requested images is larger than the number of available iamges, the images will be replayed from the beginning. If the number of requested images is larger than the number of available iamges,
the images will be replayed from the beginning.
h5_image_sim: An example for the configuration of this is device is in ophyd_devices.configs.ophyd_devices_simulation.yaml
readoutPriority: baseline
deviceClass: H5ImageReplayProxy
deviceConfig:
eiger:
signal_name: image
file_source: /path/to/h5file.h5
h5_entry: /entry/data
enabled: true
readOnly: false
""" """
USER_ACCESS = ["file_source", "h5_entry"] USER_ACCESS = ["file_source", "h5_entry"]

View File

@ -1,12 +1,9 @@
from typing import Literal
import numpy as np import numpy as np
from bec_lib import messages from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints from bec_lib.endpoints import MessageEndpoints
from bec_lib.logger import bec_logger from bec_lib.logger import bec_logger
from ophyd import Component as Cpt from ophyd import Component as Cpt
from ophyd import Device, Kind from ophyd import Device, Kind
from typeguard import typechecked
from ophyd_devices.interfaces.base_classes.psi_detector_base import ( from ophyd_devices.interfaces.base_classes.psi_detector_base import (
CustomDetectorMixin, CustomDetectorMixin,
@ -60,7 +57,7 @@ class SimMonitor(Device):
**kwargs, **kwargs,
): ):
self.precision = precision self.precision = precision
self.init_sim_params = sim_init self.sim_init = sim_init
self.device_manager = device_manager self.device_manager = device_manager
self.sim = self.sim_cls(parent=self, **kwargs) self.sim = self.sim_cls(parent=self, **kwargs)
self._registered_proxies = {} self._registered_proxies = {}
@ -68,6 +65,8 @@ class SimMonitor(Device):
super().__init__(name=name, parent=parent, kind=kind, **kwargs) super().__init__(name=name, parent=parent, kind=kind, **kwargs)
self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None) self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None)
self.readback.name = self.name self.readback.name = self.name
if self.sim_init:
self.sim.set_init(self.sim_init)
@property @property
def registered_proxies(self) -> None: def registered_proxies(self) -> None:
@ -112,7 +111,7 @@ class SimMonitorAsyncPrepare(CustomDetectorMixin):
if self.parent.scaninfo.scan_msg is None: if self.parent.scaninfo.scan_msg is None:
return return
metadata = self.parent.scaninfo.scan_msg.metadata metadata = self.parent.scaninfo.scan_msg.metadata
metadata.update({"async_update": self.parent.async_update}) metadata.update({"async_update": self.parent.async_update.get()})
msg = messages.DeviceMessage( msg = messages.DeviceMessage(
signals={self.parent.readback.name: self.parent.data_buffer}, signals={self.parent.readback.name: self.parent.data_buffer},
@ -163,6 +162,7 @@ class SimMonitorAsync(PSIDetectorBase):
readback = Cpt(ReadOnlySignal, value=BIT_DEPTH(0), kind=Kind.hinted, compute_readback=True) readback = Cpt(ReadOnlySignal, value=BIT_DEPTH(0), kind=Kind.hinted, compute_readback=True)
current_trigger = Cpt(SetableSignal, value=BIT_DEPTH(0), kind=Kind.config) current_trigger = Cpt(SetableSignal, value=BIT_DEPTH(0), kind=Kind.config)
async_update = Cpt(SetableSignal, value="extend", kind=Kind.config)
SUB_READBACK = "readback" SUB_READBACK = "readback"
SUB_PROGRESS = "progress" SUB_PROGRESS = "progress"
@ -171,7 +171,7 @@ class SimMonitorAsync(PSIDetectorBase):
def __init__( def __init__(
self, name, *, sim_init: dict = None, parent=None, kind=None, device_manager=None, **kwargs self, name, *, sim_init: dict = None, parent=None, kind=None, device_manager=None, **kwargs
): ):
self.init_sim_params = sim_init self.sim_init = sim_init
self.device_manager = device_manager self.device_manager = device_manager
self.sim = self.sim_cls(parent=self, **kwargs) self.sim = self.sim_cls(parent=self, **kwargs)
self._registered_proxies = {} self._registered_proxies = {}
@ -182,7 +182,8 @@ class SimMonitorAsync(PSIDetectorBase):
self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None) self.sim.sim_state[self.name] = self.sim.sim_state.pop(self.readback.name, None)
self.readback.name = self.name self.readback.name = self.name
self._data_buffer = {"value": [], "timestamp": []} self._data_buffer = {"value": [], "timestamp": []}
self._async_update = "extend" if self.sim_init:
self.sim.set_init(self.sim_init)
@property @property
def data_buffer(self) -> list: def data_buffer(self) -> list:
@ -193,18 +194,3 @@ class SimMonitorAsync(PSIDetectorBase):
def registered_proxies(self) -> None: def registered_proxies(self) -> None:
"""Dictionary of registered signal_names and proxies.""" """Dictionary of registered signal_names and proxies."""
return self._registered_proxies return self._registered_proxies
@property
def async_update(self) -> str:
"""Update method for the asynchronous monitor."""
return self._async_update
@async_update.setter
@typechecked
def async_update(self, value: Literal["extend", "append"]) -> None:
"""Set the update method for the asynchronous monitor.
Args:
value (str): Can only be "extend" or "append".
"""
self._async_update = value

View File

@ -78,7 +78,7 @@ class SimPositioner(Device, PositionerBase):
self.delay = delay self.delay = delay
self.device_manager = device_manager self.device_manager = device_manager
self.precision = precision self.precision = precision
self.init_sim_params = sim_init self.sim_init = sim_init
self._registered_proxies = {} self._registered_proxies = {}
self.update_frequency = update_frequency self.update_frequency = update_frequency
@ -94,11 +94,8 @@ class SimPositioner(Device, PositionerBase):
assert len(limits) == 2 assert len(limits) == 2
self.low_limit_travel.put(limits[0]) self.low_limit_travel.put(limits[0])
self.high_limit_travel.put(limits[1]) self.high_limit_travel.put(limits[1])
if self.sim_init:
# @property self.sim.set_init(self.sim_init)
# def connected(self):
# """Return the connected state of the simulated device."""
# return self.dummy_controller.connected
@property @property
def limits(self): def limits(self):

View File

@ -7,6 +7,7 @@ from ophyd import Component as Cpt
from ophyd import Device, DeviceStatus, Kind from ophyd import Device, DeviceStatus, Kind
from ophyd_devices.sim.sim_data import SimulatedDataWaveform from ophyd_devices.sim.sim_data import SimulatedDataWaveform
from ophyd_devices.sim.sim_exception import DeviceStop
from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal from ophyd_devices.sim.sim_signals import ReadOnlySignal, SetableSignal
from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin from ophyd_devices.utils.bec_scaninfo_mixin import BecScaninfoMixin
@ -60,7 +61,7 @@ class SimWaveform(Device):
self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs self, name, *, kind=None, parent=None, sim_init: dict = None, device_manager=None, **kwargs
): ):
self.device_manager = device_manager self.device_manager = device_manager
self.init_sim_params = sim_init self.sim_init = sim_init
self._registered_proxies = {} self._registered_proxies = {}
self.sim = self.sim_cls(parent=self, **kwargs) self.sim = self.sim_cls(parent=self, **kwargs)
@ -69,6 +70,8 @@ class SimWaveform(Device):
self._staged = False self._staged = False
self.scaninfo = None self.scaninfo = None
self._update_scaninfo() self._update_scaninfo()
if self.sim_init:
self.sim.set_init(self.sim_init)
@property @property
def registered_proxies(self) -> None: def registered_proxies(self) -> None:

View File

@ -96,6 +96,46 @@ def flyer(name="flyer"):
yield fly yield fly
def test_camera_with_sim_init():
"""Test to see if the sim init parameters are passed to the device"""
dm = DMMock()
sim = SimCamera(name="sim", device_manager=dm)
assert sim.sim._model.value == "gaussian"
model = "constant"
params = {
"amplitude": 300,
"noise": "uniform",
"noise_multiplier": 1,
"hot_pixel_coords": [[0, 0], [50, 50]],
"hot_pixel_types": ["fluctuating", "constant"],
"hot_pixel_values": [2.0, 2.0],
}
sim = SimCamera(name="sim", device_manager=dm, sim_init={"model": model, "params": params})
assert sim.sim._model.value == model
assert sim.sim.params == params
def test_monitor_with_sim_init():
"""Test to see if the sim init parameters are passed to the device"""
dm = DMMock()
sim = SimMonitor(name="sim", device_manager=dm)
assert sim.sim._model._name == "constant"
model = "GaussianModel"
params = {
"amplitude": 500,
"center": 5,
"sigma": 4,
"noise": "uniform",
"noise_multiplier": 1,
"ref_motor": "samy",
}
sim = SimMonitor(name="sim", device_manager=dm, sim_init={"model": model, "params": params})
assert sim.sim._model._name == model.strip("Model").lower()
diff_keys = set(sim.sim.params.keys()) - set(params.keys())
for k in params:
assert sim.sim.params[k] == params[k]
def test_signal__init__(signal): def test_signal__init__(signal):
"""Test the BECProtocol class""" """Test the BECProtocol class"""
assert isinstance(signal, BECDeviceProtocol) assert isinstance(signal, BECDeviceProtocol)
@ -144,41 +184,41 @@ def test_monitor_readback(monitor, center):
"""Test the readback method of SimMonitor.""" """Test the readback method of SimMonitor."""
motor_pos = 0 motor_pos = 0
monitor.device_manager.add_device("samx", value=motor_pos) monitor.device_manager.add_device("samx", value=motor_pos)
for model_name in monitor.sim.sim_get_models(): for model_name in monitor.sim.get_models():
monitor.sim.sim_select_model(model_name) monitor.sim.select_model(model_name)
monitor.sim.sim_params["noise_multipler"] = 10 monitor.sim.params["noise_multipler"] = 10
monitor.sim.sim_params["ref_motor"] = "samx" monitor.sim.params["ref_motor"] = "samx"
if "c" in monitor.sim.sim_params: if "c" in monitor.sim.params:
monitor.sim.sim_params["c"] = center monitor.sim.params["c"] = center
elif "center" in monitor.sim.sim_params: elif "center" in monitor.sim.params:
monitor.sim.sim_params["center"] = center monitor.sim.params["center"] = center
assert isinstance(monitor.read()[monitor.name]["value"], monitor.BIT_DEPTH) assert isinstance(monitor.read()[monitor.name]["value"], monitor.BIT_DEPTH)
expected_value = monitor.sim._model.eval(monitor.sim._model_params, x=motor_pos) expected_value = monitor.sim._model.eval(monitor.sim._model_params, x=motor_pos)
print(expected_value, monitor.read()[monitor.name]["value"]) print(expected_value, monitor.read()[monitor.name]["value"])
tolerance = ( tolerance = (
monitor.sim.sim_params["noise_multipler"] + 1 monitor.sim.params["noise_multipler"] + 1
) # due to ceiling in calculation, but maximum +1int ) # due to ceiling in calculation, but maximum +1int
assert np.isclose( assert np.isclose(
monitor.read()[monitor.name]["value"], monitor.read()[monitor.name]["value"],
expected_value, expected_value,
atol=monitor.sim.sim_params["noise_multipler"] + 1, atol=monitor.sim.params["noise_multipler"] + 1,
) )
@pytest.mark.parametrize("amplitude, noise_multiplier", [(0, 1), (100, 10), (1000, 50)]) @pytest.mark.parametrize("amplitude, noise_multiplier", [(0, 1), (100, 10), (1000, 50)])
def test_camera_readback(camera, amplitude, noise_multiplier): def test_camera_readback(camera, amplitude, noise_multiplier):
"""Test the readback method of SimMonitor.""" """Test the readback method of SimMonitor."""
for model_name in camera.sim.sim_get_models(): for model_name in camera.sim.get_models():
camera.sim.sim_select_model(model_name) camera.sim.select_model(model_name)
camera.sim.sim_params = {"noise_multiplier": noise_multiplier} camera.sim.params = {"noise_multiplier": noise_multiplier}
camera.sim.sim_params = {"amplitude": amplitude} camera.sim.params = {"amplitude": amplitude}
camera.sim.sim_params = {"noise": "poisson"} camera.sim.params = {"noise": "poisson"}
assert camera.image.get().shape == camera.SHAPE assert camera.image.get().shape == camera.SHAPE
assert isinstance(camera.image.get()[0, 0], camera.BIT_DEPTH) assert isinstance(camera.image.get()[0, 0], camera.BIT_DEPTH)
camera.sim.sim_params = {"noise": "uniform"} camera.sim.params = {"noise": "uniform"}
camera.sim.sim_params = {"hot_pixel_coords": []} camera.sim.params = {"hot_pixel_coords": []}
camera.sim.sim_params = {"hot_pixel_values": []} camera.sim.params = {"hot_pixel_values": []}
camera.sim.sim_params = {"hot_pixel_types": []} camera.sim.params = {"hot_pixel_types": []}
assert camera.image.get().shape == camera.SHAPE assert camera.image.get().shape == camera.SHAPE
assert isinstance(camera.image.get()[0, 0], camera.BIT_DEPTH) assert isinstance(camera.image.get()[0, 0], camera.BIT_DEPTH)
assert (camera.image.get() <= (amplitude + noise_multiplier + 1)).all() assert (camera.image.get() <= (amplitude + noise_multiplier + 1)).all()
@ -243,8 +283,9 @@ def test_h5proxy(h5proxy_fixture, camera):
{camera.name: {"signal_name": "image", "file_source": fname, "h5_entry": h5entry}} {camera.name: {"signal_name": "image", "file_source": fname, "h5_entry": h5entry}}
) )
camera._registered_proxies.update({h5proxy.name: camera.image.name}) camera._registered_proxies.update({h5proxy.name: camera.image.name})
camera.sim.sim_params = {"noise": "none", "noise_multiplier": 0} camera.sim.params = {"noise": "none", "noise_multiplier": 0}
camera.scaninfo.sim_mode = True camera.scaninfo.sim_mode = True
# pylint: disable=no-member
camera.image_shape.set(data.shape[1:]) camera.image_shape.set(data.shape[1:])
camera.stage() camera.stage()
img = camera.image.get() img = camera.image.get()
@ -288,7 +329,7 @@ def test_slitproxy(slitproxy_fixture):
mock_camera.obj = camera mock_camera.obj = camera
mock_samx.obj = samx mock_samx.obj = samx
mock_proxy.obj = proxy mock_proxy.obj = proxy
camera.sim.sim_params = {"noise": "none", "noise_multiplier": 0, "hot_pixel_values": [0, 0, 0]} camera.sim.params = {"noise": "none", "noise_multiplier": 0, "hot_pixel_values": [0, 0, 0]}
samx.delay = 0 samx.delay = 0
samx_pos = 0 samx_pos = 0
samx.move(samx_pos) samx.move(samx_pos)
@ -421,7 +462,7 @@ def test_async_mon_send_data_to_bec(async_monitor):
async_monitor.custom_prepare._send_data_to_bec() async_monitor.custom_prepare._send_data_to_bec()
dev_msg = messages.DeviceMessage( dev_msg = messages.DeviceMessage(
signals={async_monitor.readback.name: async_monitor.data_buffer}, signals={async_monitor.readback.name: async_monitor.data_buffer},
metadata={"async_update": async_monitor.async_update}, metadata={"async_update": async_monitor.async_update.get()},
) )
call = [ call = [