766 lines
28 KiB
Python
766 lines
28 KiB
Python
from __future__ import annotations
|
|
|
|
import enum
|
|
import inspect
|
|
import time as ttime
|
|
from abc import ABC, abstractmethod
|
|
from collections import defaultdict
|
|
from copy import deepcopy
|
|
|
|
import numpy as np
|
|
from bec_lib import bec_logger
|
|
from lmfit import Model, models
|
|
from prettytable import PrettyTable
|
|
|
|
# Module-level logger taken from the bec_logger object imported from bec_lib.
logger = bec_logger.logger
|
|
|
|
|
|
class SimulatedDataException(Exception):
    """Error signalling a problem while handling or computing simulated data."""
|
|
|
|
|
|
class SimulationType2D(str, enum.Enum):
    """Names of the 2D simulation models that steer the simulated data."""

    # String-valued members: each member compares equal to its plain string value.
    CONSTANT = "constant"
    GAUSSIAN = "gaussian"
|
|
|
|
|
|
class NoiseType(str, enum.Enum):
    """Kinds of noise that can be added to simulated data."""

    # String-valued members: each member compares equal to its plain string value.
    NONE = "none"
    UNIFORM = "uniform"
    POISSON = "poisson"
|
|
|
|
|
|
class HotPixelType(str, enum.Enum):
    """Kinds of hot pixels that can be added to simulated data."""

    # String-valued members: each member compares equal to its plain string value.
    CONSTANT = "constant"
    FLUCTUATING = "fluctuating"
|
|
|
|
|
|
# Default values assigned to lmfit model parameters that carry one of these names.
DEFAULT_PARAMS_LMFIT = {
    # c0 ... c4 all default to 1.
    **{f"c{order}": 1 for order in range(5)},
    "c": 100,
    "amplitude": 100,
    "center": 0,
    "sigma": 1,
}
|
|
|
|
# Default noise settings: noise type and its multiplier.
DEFAULT_PARAMS_NOISE = dict(noise=NoiseType.UNIFORM, noise_multiplier=10)

# Default name of the reference motor.
DEFAULT_PARAMS_MOTOR = dict(ref_motor="samx")
|
|
|
|
# Default parameters of the 2D gaussian camera model.
DEFAULT_PARAMS_CAMERA_GAUSSIAN = dict(
    amplitude=100,
    center_offset=np.array([0, 0]),
    covariance=np.array([[400, 100], [100, 400]]),
)

# Default parameters of the constant camera model.
DEFAULT_PARAMS_CAMERA_CONSTANT = dict(amplitude=100)
|
|
|
|
# Default hot pixels: coordinates, types and values are matched entry by entry.
DEFAULT_PARAMS_HOT_PIXEL = dict(
    hot_pixel_coords=np.array([[24, 24], [50, 20], [4, 40]]),
    hot_pixel_types=[HotPixelType.FLUCTUATING, HotPixelType.CONSTANT, HotPixelType.FLUCTUATING],
    hot_pixel_values=np.array([1e4, 1e6, 1e4]),
)
|
|
|
|
|
|
class SimulatedDataBase(ABC):
    """Abstract foundation for the simulated data of a device.

    Subclass it to provide the simulated data of a specific device. The base
    class keeps the simulated state and the parameters of the active model and
    offers access to both.

    ---------------------
    Functionality provided by this class:

    - execute_simulation_method: run a given method, or reroute the call to a registered device proxy
    - select_model: choose the active simulation model
    - params: read or update the parameters of the active simulation model
    - get_models: list the available simulation models
    - update_sim_state: store a value and a timestamp for a signal
    - show_all: print a summary of the active simulation and the available methods
    """

    USER_ACCESS = ["params", "select_model", "get_models", "show_all"]

    def __init__(self, *args, parent=None, **kwargs) -> None:
        """Set up empty state for the simulation.

        Note:
            self._model_params duplicates those entries of self._params that belong to
            the lmfit model in use, so that the model can be evaluated directly with them.

        Args:
            parent: Device owning this simulation; its "registered_proxies" attribute
                is used if present, otherwise an empty dict.
        """
        self.parent = parent
        self._model = {}
        self._model_params = None
        self._params = {}
        self.sim_state = defaultdict(dict)
        self.registered_proxies = getattr(self.parent, "registered_proxies", {})

    def execute_simulation_method(self, *args, method=None, signal_name: str = "", **kwargs) -> any:
        """Run the given method, or the method of a registered and enabled device proxy.

        If proxies are registered and the parent has a device manager, the registered
        proxies are searched for one whose signal equals signal_name, either directly or
        as "<parent name>_<signal>". When that proxy is found in the device manager and
        is enabled, method, args and kwargs are replaced by the proxy's lookup entry for
        the parent device.

        Args:
            method: Callable to execute if no proxy takes over.
            signal_name (str): Name of the signal the computation is for.

        Returns:
            Whatever the executed method returns.

        Raises:
            SimulatedDataException: If no method is available.
        """
        if self.registered_proxies and self.parent.device_manager:
            for proxy_name, signal in self.registered_proxies.items():
                is_match = signal == signal_name or f"{self.parent.name}_{signal}" == signal_name
                if not is_match:
                    continue
                sim_proxy = self.parent.device_manager.devices.get(proxy_name, None)
                if not sim_proxy or sim_proxy.enabled is not True:
                    continue
                # The first enabled proxy for this signal takes over the computation.
                method = sim_proxy.obj.lookup[self.parent.name]["method"]
                args = sim_proxy.obj.lookup[self.parent.name]["args"]
                kwargs = sim_proxy.obj.lookup[self.parent.name]["kwargs"]
                break

        if method is None:
            raise SimulatedDataException(f"Method {method} is not available for {self.parent.name}")
        return method(*args, **kwargs)

    # NOTE: this docstring is displayed to users by _get_table_method_information.
    def select_model(self, model: str) -> None:
        """
        Method to select the active simulation model.
        It will initiate the model_cls and parameters for the model.

        Args:
            model (str): Name of the simulation model to select.

        """
        candidate = self.get_model_cls(model)
        # Callables are instantiated; anything else is stored unchanged.
        self._model = candidate() if callable(candidate) else candidate
        self._params = self.get_params_for_model_cls()
        additional = self._get_additional_params()
        self._params.update(additional)

    # NOTE: this docstring is displayed to users by _get_table_method_information.
    @property
    def params(self) -> dict:
        """
        Property that returns the parameters for the active simulation model. It can also
        be used to set the parameters for the active simulation updating the parameters of the model.

        Returns:
            dict: Parameters for the active simulation model.

        The following example shows how to update the noise parameter of the current simulation.
        >>> dev.<device>.sim.params = {"noise": "poisson"}
        """
        return self._params

    @params.setter
    def params(self, params: dict):
        """Update known parameters of the active simulation model.

        "noise" and "hot_pixel_types" are converted to their enum types. If the active
        model is an lmfit Model, matching entries of self._model_params are updated too.

        Raises:
            SimulatedDataException: If a key is not an existing parameter. Keys handled
                before the failing one stay applied.
        """
        for key, new_value in params.items():
            if key not in self.params:
                raise SimulatedDataException(f"Parameter {key} not found in {self.params}.")

            if key == "noise":
                self._params[key] = NoiseType(new_value)
            elif key == "hot_pixel_types":
                self._params[key] = [HotPixelType(entry) for entry in new_value]
            else:
                self._params[key] = new_value

            if isinstance(self._model, Model) and key in self._model_params:
                self._model_params[key].value = new_value

    # NOTE: this docstring is displayed to users by _get_table_method_information.
    def get_models(self) -> list:
        """
        Method to get the all available simulation models.
        """
        return self.get_all_sim_models()

    def update_sim_state(self, signal_name: str, value: any) -> None:
        """Store a value for a signal together with the current time.

        Args:
            signal_name (str): Name of the signal to update.
            value (any): Value to store in the simulated state.
        """
        state = self.sim_state[signal_name]
        state["value"] = value
        state["timestamp"] = ttime.time()

    @abstractmethod
    def _get_additional_params(self) -> dict:
        """Return extra parameters that select_model merges into the model parameters."""

    @abstractmethod
    def get_model_cls(self, model: str) -> any:
        """Return the class (or object) that represents the simulation model `model`."""

    @abstractmethod
    def get_params_for_model_cls(self) -> dict:
        """Return the parameters of the active simulation model."""

    @abstractmethod
    def get_all_sim_models(self) -> list[str]:
        """Return the names of all available simulation models.

        Returns:
            list: List of available simulation models.
        """

    @abstractmethod
    def compute_sim_state(self, signal_name: str, compute_readback: bool) -> None:
        """Compute the simulated state of the device for the given signal."""

    def _get_table_active_simulation(self, width: int = 140) -> PrettyTable:
        """Build a table listing the active model and each parameter with value and type."""
        table = PrettyTable()
        table.title = f"Currently active model: {self._model}"
        table.field_names = ["Parameter", "Value", "Type"]
        for name, current in self.params.items():
            table.add_row([name, f"{current}", f"{type(current)}"])

        # Narrow tables get a third of the width for the outer columns.
        for column in ("Parameter", "Type"):
            table._min_width[column] = 25 if width > 75 else width // 3
        table.max_table_width = width
        table._min_table_width = width

        return table

    def _get_table_method_information(self, width: int = 140) -> PrettyTable:
        """Build a table with the user-facing methods and their docstrings."""
        table = PrettyTable()
        # NOTE(review): this table has no "Value" column (fields are "Method" and
        # "Docstring"); possibly "Docstring" was intended -- confirm before changing.
        table.max_width["Value"] = 120
        table.hrules = 1
        table.title = "Available methods within the simulation module"
        table.field_names = ["Method", "Docstring"]

        rows = [
            [self.get_models.__name__, f"{self.get_models.__doc__}"],
            [self.select_model.__name__, self.select_model.__doc__],
            ["params", self.__class__.params.__doc__],
        ]
        for row in rows:
            table.add_row(row)
        table.max_table_width = width
        table._min_table_width = width
        table.align["Docstring"] = "l"

        return table

    def show_all(self):
        """Print a summary of the active simulation, the available methods and all models."""
        width = 150
        print(self._get_table_active_simulation(width=width))
        print(self._get_table_method_information(width=width))

        models_table = PrettyTable()
        models_table.title = "Simulation module for current device"
        models_table.field_names = ["All available models"]
        models_table.add_row([", ".join(self.get_all_sim_models())])
        models_table.max_table_width = width
        models_table._min_table_width = width
        print(models_table)

    def set_init(self, sim_init: dict["model", "params"]) -> None:
        """Select the initial model and apply the initial parameters.

        Args:
            sim_init (dict): Dictionary with the key "model" (name of the model to
                select) and optionally "params" (parameters to set afterwards).
        """
        initial_model = sim_init.get("model")
        self.select_model(initial_model)
        self.params = sim_init.get("params", {})
|
|
|
|
|
|
class SimulatedPositioner(SimulatedDataBase):
    """Simulated data class for a positioner."""

    def _init_default_additional_params(self) -> None:
        """Intentionally empty: a positioner has no additional parameters to initialise."""

    def get_model_cls(self, model: str) -> any:
        """Return None; positioners currently have no simulation models."""
        return None

    def get_params_for_model_cls(self) -> dict:
        """Return an empty dict; positioners currently have no simulation models."""
        return {}

    def get_all_sim_models(self) -> list[str]:
        """Return the available simulation models, which is none for positioners.

        Returns:
            list: Empty list.
        """
        return []

    def _get_additional_params(self) -> dict:
        """Return an empty dict; positioners have no additional parameters."""
        return {}

    def compute_sim_state(self, signal_name: str, compute_readback: bool) -> None:
        """Refresh the timestamp of a signal and optionally compute its value.

        Positioners do not use a computed signal of their own. According to the
        original notes, the position is updated by the parent device, and
        readback/setpoint jitter is introduced in the parent class (SimPositioner).

        With compute_readback set, no local method is passed to
        execute_simulation_method, so a value is only obtained when a registered
        proxy provides the method; otherwise that call raises.

        Args:
            signal_name (str): Name of the signal to update.
            compute_readback (bool): Whether to compute and store a value.
        """
        self.sim_state[signal_name]["timestamp"] = ttime.time()
        if not compute_readback:
            return
        value = self.execute_simulation_method(method=None, signal_name=signal_name)
        self.update_sim_state(signal_name, value)
|
|
|
|
|
|
class SimulatedDataMonitor(SimulatedDataBase):
    """Simulated data class for a monitor."""

    def __init__(self, *args, parent=None, **kwargs) -> None:
        """Collect the lmfit models, initialise the base class and select a default model.

        Args:
            parent: Device owning this simulation; must provide BIT_DEPTH.
        """
        self._model_lookup = self.init_lmfit_models()
        super().__init__(*args, parent=parent, **kwargs)
        self.bit_depth = self.parent.BIT_DEPTH
        self._init_default()

    def _get_additional_params(self) -> dict:
        """Return copies of the default noise and reference-motor parameters."""
        params = deepcopy(DEFAULT_PARAMS_NOISE)
        params.update(deepcopy(DEFAULT_PARAMS_MOTOR))
        return params

    def _init_default(self) -> None:
        """Select "ConstantModel" if available, otherwise the first available model."""
        # Local name chosen so that the imported lmfit `models` module is not shadowed.
        available = self.get_all_sim_models()
        if "ConstantModel" in available:
            self.select_model("ConstantModel")
        else:
            self.select_model(available[0])

    def get_model_cls(self, model: str) -> any:
        """Get the class for the simulation model with the given name.

        Raises:
            SimulatedDataException: If the name is not among the available models.
        """
        if model not in self._model_lookup:
            raise SimulatedDataException(f"Model {model} not found in {self._model_lookup.keys()}.")
        return self._model_lookup[model]

    def get_all_sim_models(self) -> list[str]:
        """Get the names of all simulation models collected from lmfit.models.

        Returns:
            list: List of available simulation models.
        """
        return list(self._model_lookup.keys())

    def get_params_for_model_cls(self) -> dict:
        """Get the parameters for the active simulation model.

        Parameters with a name in DEFAULT_PARAMS_LMFIT receive that default. Other
        parameters keep their lmfit value unless it is NaN or infinite, in which
        case they are set to 1. The lmfit parameters are stored in self._model_params.

        Returns:
            dict: {name: value} for the active simulation model.
        """
        values = {}
        lmfit_params = self._model.make_params()
        for name, parameter in lmfit_params.items():
            if name in DEFAULT_PARAMS_LMFIT:
                values[name] = DEFAULT_PARAMS_LMFIT[name]
                parameter.value = values[name]
            elif np.isnan(parameter.value) or np.isinf(parameter.value):
                values[name] = 1
                parameter.value = 1
            else:
                values[name] = parameter.value
        self._model_params = lmfit_params
        return values

    def model_lookup(self):
        """Get the mapping of model name to model class collected from lmfit.models."""
        return self._model_lookup

    def init_lmfit_models(self) -> dict:
        """Collect the available models from lmfit.models.

        ComplexConstantModel, Gaussian2dModel, ExpressionModel, Model and SplineModel
        are excluded.

        Returns:
            dictionary of model name : model class pairs for available models from LMFit.
        """
        excluded = (
            "ComplexConstantModel",
            "Gaussian2dModel",
            "ExpressionModel",
            "Model",
            "SplineModel",
        )
        model_lookup = {}
        for name, model_cls in inspect.getmembers(models):
            try:
                is_model = issubclass(model_cls, Model)
            except TypeError:
                # Members that are not classes cannot be models.
                is_model = False
            if is_model and name not in excluded:
                model_lookup[name] = model_cls

        return model_lookup

    def compute_sim_state(self, signal_name: str, compute_readback: bool) -> None:
        """Update the simulated state of the device.

        If compute_readback is set, the value computed by the active model (or by a
        registered proxy) is clamped at zero, converted with self.bit_depth and stored
        in self.sim_state. Otherwise nothing happens.

        Args:
            signal_name (str): Name of the signal to update.
            compute_readback (bool): Whether to compute and store a new value.
        """
        if compute_readback:
            value = self.execute_simulation_method(method=self._compute, signal_name=signal_name)
            # Clamp negatives to zero before the bit-depth conversion. np.max(value, 0)
            # would interpret the 0 as an axis and leave negative values untouched.
            value = self.bit_depth(np.maximum(value, 0))
            self.update_sim_state(signal_name, value)

    def _compute(self, *args, **kwargs) -> int:
        """Compute the value for the current reference-motor position and active model.

        The position of the motor named by params["ref_motor"] is read from the device
        manager; if there is no device manager or the motor is unknown, 0 is used.

        Returns:
            int: Value of the active model at the motor position, with noise added.
        """
        mot_name = self.params["ref_motor"]
        device_manager = self.parent.device_manager
        if device_manager and mot_name in device_manager.devices:
            motor_pos = device_manager.devices[mot_name].obj.read()[mot_name]["value"]
        else:
            motor_pos = 0
        value = int(self._model.eval(params=self._model_params, x=motor_pos))
        return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])

    def _add_noise(self, v: int, noise: NoiseType, noise_multiplier: float) -> int:
        """Add the currently activated noise to the simulated value.

        Args:
            v (int): Value to add noise to.
            noise (NoiseType): Type of noise to add. For any type other than POISSON
                or UNIFORM the value is returned unchanged.
            noise_multiplier (float): Upper bound for the magnitude of uniform noise.

        Returns:
            int: Value with added noise; never negative for POISSON and UNIFORM.
        """
        if noise == NoiseType.POISSON:
            # The Poisson rate must not be negative; treat negative model values as 0,
            # in line with the clamping done for uniform noise.
            return np.random.poisson(max(v, 0))
        if noise == NoiseType.UNIFORM:
            magnitude = np.ceil(np.random.uniform(0, 1) * noise_multiplier).astype(int)
            # randint(0, 2) * 2 - 1 yields -1 or +1, i.e. a random sign.
            v += magnitude * (np.random.randint(0, 2) * 2 - 1)
            return v if v > 0 else 0
        return v
|
|
|
|
|
|
class SimulatedDataWaveform(SimulatedDataMonitor):
    """Simulated data class for a waveform.

    The class inherits from SimulatedDataMonitor and overrides the relevant
    methods so that the active model is evaluated for every point of a waveform.
    """

    def _get_additional_params(self) -> dict:
        """Return a copy of the default noise parameters."""
        return deepcopy(DEFAULT_PARAMS_NOISE)

    def compute_sim_state(self, signal_name: str, compute_readback: bool) -> None:
        """Update the simulated state of the device.

        If compute_readback is set, the waveform computed by the active model (or by
        a registered proxy) is converted with self.bit_depth and stored in
        self.sim_state. Otherwise nothing happens.

        Args:
            signal_name (str): Name of the signal to update.
            compute_readback (bool): Whether to compute and store a new value.
        """
        if compute_readback:
            value = self.execute_simulation_method(method=self._compute, signal_name=signal_name)
            value = self.bit_depth(value)
            self.update_sim_state(signal_name, value)

    def _compute(self, *args, **kwargs) -> np.ndarray:
        """Evaluate the active model at the points 0 .. size-1 and add noise.

        The size is read from the parent's waveform_shape signal; for a tuple the
        first entry is used.

        Returns:
            np.ndarray: Values computed for the active model.
        """
        size = self.parent.waveform_shape.get()
        size = size[0] if isinstance(size, tuple) else size
        model = self._model
        value = model.eval(params=self._model_params, x=np.array(range(size)))
        # Rescale models with an amplitude parameter so the maximum equals params["amplitude"].
        if "amplitude" in model.param_names:
            peak = np.max(value)
            # A zero peak cannot be rescaled; skipping avoids a division by zero.
            if peak != 0:
                value *= self.params["amplitude"] / peak
        return self._add_noise(value, self.params["noise"], self.params["noise_multiplier"])

    def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray:
        """Add noise to the simulated data.

        Args:
            v (np.ndarray): Simulated data. Modified in place for uniform noise.
            noise (NoiseType): Type of noise to add. For any type other than POISSON
                or UNIFORM the data is returned unchanged.
            noise_multiplier (float): Half-width of the uniform noise interval.

        Returns:
            np.ndarray: Data with noise added.
        """
        if noise == NoiseType.POISSON:
            return np.random.poisson(np.round(v), v.shape)
        if noise == NoiseType.UNIFORM:
            v += np.random.uniform(-noise_multiplier, noise_multiplier, v.shape)
            v[v <= 0] = 0
            return v
        # NoiseType.NONE, and any unmatched value, leaves the data as it is
        # instead of implicitly returning None.
        return v
|
|
|
|
|
|
class SimulatedDataCamera(SimulatedDataBase):
    """Simulated class to compute data for a 2D camera."""

    def __init__(self, *args, parent=None, **kwargs) -> None:
        """Prepare the 2D models and their defaults, then select the default model.

        Args:
            parent: Device owning this simulation; must provide BIT_DEPTH.
        """
        self._model_lookup = self.init_2D_models()
        self._all_default_model_params = defaultdict(dict)
        self._init_default_camera_params()
        super().__init__(*args, parent=parent, **kwargs)
        self.bit_depth = self.parent.BIT_DEPTH
        self._init_default()

    def _init_default(self) -> None:
        """Initialize the default model for a simulated camera.

        Use the default model "Gaussian".
        """
        self.select_model(SimulationType2D.GAUSSIAN)

    def init_2D_models(self) -> dict:
        """Get the available models for 2D camera simulations.

        Returns:
            dict: Mapping of model name (enum value) to SimulationType2D member.
        """
        return {member.value: member for member in SimulationType2D}

    def _get_additional_params(self) -> dict:
        """Return copies of the default noise and hot-pixel parameters."""
        params = deepcopy(DEFAULT_PARAMS_NOISE)
        params.update(deepcopy(DEFAULT_PARAMS_HOT_PIXEL))
        return params

    def _init_default_camera_params(self) -> None:
        """Store a copy of the default parameters for each 2D model."""
        self._all_default_model_params.update(
            {
                self._model_lookup[SimulationType2D.CONSTANT.value]: deepcopy(
                    DEFAULT_PARAMS_CAMERA_CONSTANT
                ),
                self._model_lookup[SimulationType2D.GAUSSIAN.value]: deepcopy(
                    DEFAULT_PARAMS_CAMERA_GAUSSIAN
                ),
            }
        )

    def get_model_cls(self, model: str) -> any:
        """Get the SimulationType2D member for the model with the given name.

        Raises:
            SimulatedDataException: If the name is not among the available models.
        """
        if model not in self._model_lookup:
            raise SimulatedDataException(f"Model {model} not found in {self._model_lookup.keys()}.")
        return self._model_lookup[model]

    def get_params_for_model_cls(self) -> dict:
        """Get a fresh copy of the default parameters for the active 2D model.

        A copy is returned because the caller merges additional parameters into the
        result and the params setter modifies it; handing out the stored defaults
        would let those changes leak into later model selections.
        """
        return deepcopy(self._all_default_model_params[self._model.value])

    def get_all_sim_models(self) -> list[str]:
        """Get the names of all available 2D simulation models.

        Returns:
            list: List of available simulation models.
        """
        return [entry.value for entry in self._model_lookup.values()]

    def compute_sim_state(self, signal_name: str, compute_readback: bool) -> None:
        """Update the simulated state of the device.

        The image is computed by the active model (or by a registered proxy) if
        compute_readback is set, otherwise an empty image is used. The result is
        converted with self.bit_depth and stored in self.sim_state.

        Args:
            signal_name (str) : Name of the signal to update.
            compute_readback (bool) : Flag whether to compute readback based on function hosted in SimulatedData

        Raises:
            SimulatedDataException: If the active model is not a known 2D model.
        """
        if compute_readback:
            if self._model == SimulationType2D.CONSTANT:
                method = self._compute_constant
            elif self._model == SimulationType2D.GAUSSIAN:
                method = self._compute_gaussian
            else:
                raise SimulatedDataException(
                    f"Model {self._model} not found in {self._model_lookup.keys()}."
                )
            value = self.execute_simulation_method(signal_name=signal_name, method=method)
        else:
            value = self._compute_empty_image()
        value = self.bit_depth(value)
        self.update_sim_state(signal_name, value)

    def _compute_empty_image(self) -> np.ndarray:
        """Compute an empty image.

        Returns:
            np.ndarray: Zeros with the shape read from the parent's image_shape signal.

        Raises:
            SimulatedDataException: Re-raised with context if one occurs while computing.
        """
        try:
            shape = self.parent.image_shape.get()
            return np.zeros(shape)
        except SimulatedDataException as exc:
            raise SimulatedDataException(
                f"Could not compute empty image for {self.parent.name} with {exc} raised. Deactivate eiger to continue."
            ) from exc

    def _compute_constant(self) -> np.ndarray:
        """Compute the image for SimulationType2D constant.

        Returns:
            np.ndarray: Image filled with params["amplitude"], with noise and hot
            pixels added.

        Raises:
            SimulatedDataException: Re-raised with context if one occurs while computing.
        """
        try:
            shape = self.parent.image_shape.get()
            v = self.params.get("amplitude") * np.ones(shape, dtype=np.float32)
            v = self._add_noise(v, self.params["noise"], self.params["noise_multiplier"])
            return self._add_hot_pixel(
                v,
                coords=self.params["hot_pixel_coords"],
                hot_pixel_types=self.params["hot_pixel_types"],
                values=self.params["hot_pixel_values"],
            )
        except SimulatedDataException as exc:
            raise SimulatedDataException(
                f"Could not compute constant for {self.parent.name} with {exc} raised. Deactivate eiger to continue."
            ) from exc

    def _compute_gaussian(self) -> np.ndarray:
        """Compute the image for SimulationType2D gaussian.

        The image is based on the gaussian parameters in self.params (amplitude,
        covariance, center_offset); noise and hot pixels are added afterwards.

        Returns:
            np.ndarray: Simulated image.

        Raises:
            SimulatedDataException: Re-raised with context if one occurs while computing.
        """
        try:
            amp = self.params.get("amplitude")
            cov = self.params.get("covariance")
            cen_off = self.params.get("center_offset")
            # NOTE(review): the shape is read from sim_state here, whereas the other
            # _compute_* methods call self.parent.image_shape.get() -- confirm intended.
            shape = self.sim_state[self.parent.image_shape.name]["value"]
            pos, offset, cov, amp = self._prepare_params_gauss(
                amp=amp, cov=cov, offset=cen_off, shape=shape
            )

            v = self._compute_multivariate_gaussian(pos=pos, cen_off=offset, cov=cov, amp=amp)
            v = self._add_noise(
                v, noise=self.params["noise"], noise_multiplier=self.params["noise_multiplier"]
            )
            return self._add_hot_pixel(
                v,
                coords=self.params["hot_pixel_coords"],
                hot_pixel_types=self.params["hot_pixel_types"],
                values=self.params["hot_pixel_values"],
            )
        except SimulatedDataException as exc:
            raise SimulatedDataException(
                f"Could not compute gaussian for {self.parent.name} with {exc} raised. Deactivate eiger to continue."
            ) from exc

    def _compute_multivariate_gaussian(
        self, pos: np.ndarray | list, cen_off: np.ndarray | list, cov: np.ndarray | list, amp: float
    ) -> np.ndarray:
        """Compute a multivariate Gaussian distribution scaled to a peak of `amp`.

        Args:
            pos (np.ndarray): Positions at which the gaussian is evaluated; the last
                axis holds the coordinates.
            cen_off (np.ndarray): Offset from center of image for the gaussian.
            cov (np.ndarray): Covariance matrix of the gaussian.
            amp (float): Value of the maximum of the returned distribution.

        Returns:
            np.ndarray: Multivariate Gaussian distribution.
        """
        if isinstance(pos, list):
            pos = np.array(pos)
        if isinstance(cen_off, list):
            cen_off = np.array(cen_off)
        if isinstance(cov, list):
            cov = np.array(cov)
        dim = cen_off.shape[0]
        cov_det = np.linalg.det(cov)
        cov_inv = np.linalg.inv(cov)
        norm = np.sqrt((2 * np.pi) ** dim * cov_det)
        # This einsum call calculates (x-mu)T.Sigma-1.(x-mu) in a vectorized
        # way across all the input variables.
        delta = pos - cen_off
        fac = np.einsum("...k,kl,...l->...", delta, cov_inv, delta)
        v = np.exp(-fac / 2) / norm
        # Rescale so that the largest value equals amp.
        v *= amp / np.max(v)
        return v

    def _prepare_params_gauss(
        self, amp: float, cov: np.ndarray, offset: np.ndarray, shape: tuple
    ) -> tuple:
        """Prepare the positions for the gaussian.

        The position grid has shape (shape[0], shape[1], 2), matching the images
        created with np.ones(shape) / np.zeros(shape) by the other compute methods.
        Coordinate 0 varies along the second image axis and coordinate 1 along the
        first one; both are centred on the image.

        Args:
            amp (float): Amplitude of the gaussian.
            cov (np.ndarray): Covariance matrix of the gaussian.
            offset (np.ndarray): Offset from the center of the image.
            shape (tuple): Shape of the image.

        Returns:
            tuple: Positions, offset, covariance matrix and amplitude for the gaussian.
        """
        n_rows, n_cols = shape[0], shape[1]
        # With the default "xy" indexing, meshgrid returns arrays of shape
        # (len(second), len(first)); the column axis therefore has to come first.
        x, y = np.meshgrid(
            np.linspace(-n_cols / 2, n_cols / 2, n_cols),
            np.linspace(-n_rows / 2, n_rows / 2, n_rows),
        )
        pos = np.empty((*x.shape, 2))
        pos[:, :, 0] = x
        pos[:, :, 1] = y

        return pos, offset, cov, amp

    def _add_noise(self, v: np.ndarray, noise: NoiseType, noise_multiplier: float) -> np.ndarray:
        """Add noise to the simulated data.

        Args:
            v (np.ndarray): Simulated data. Modified in place for uniform noise.
            noise (NoiseType): Type of noise to add. For any type other than POISSON
                or UNIFORM the data is returned unchanged.
            noise_multiplier (float): Half-width of the uniform noise interval.

        Returns:
            np.ndarray: Data with noise added.
        """
        if noise == NoiseType.POISSON:
            return np.random.poisson(np.round(v), v.shape)
        if noise == NoiseType.UNIFORM:
            v += np.random.uniform(-noise_multiplier, noise_multiplier, v.shape)
            v[v <= 0] = 0
            return v
        # NoiseType.NONE, and any unmatched value, leaves the data as it is
        # instead of implicitly returning None.
        return v

    def _add_hot_pixel(
        self, v: np.ndarray, coords: list, hot_pixel_types: list, values: list
    ) -> np.ndarray:
        """Add hot pixels to the simulated data.

        Coordinates, types and values are matched entry by entry. Coordinates at or
        beyond the image size are skipped. A CONSTANT hot pixel is always set to its
        value; a FLUCTUATING one only if the pixel currently exceeds half of the image
        maximum.

        Args:
            v (np.ndarray): Simulated data; modified in place.
            coords (list): Pixel coordinates, one pair of indices per hot pixel.
            hot_pixel_types (list): HotPixelType of each hot pixel.
            values (list): Value assigned to each hot pixel.

        Returns:
            np.ndarray: The data with hot pixels applied.
        """
        for coord, hot_pixel_type, value in zip(coords, hot_pixel_types, values):
            row, col = coord[0], coord[1]
            if not (row < v.shape[0] and col < v.shape[1]):
                continue
            if hot_pixel_type == HotPixelType.CONSTANT:
                v[row, col] = value
            elif hot_pixel_type == HotPixelType.FLUCTUATING:
                # Recomputed for every pixel because earlier hot pixels change the maximum.
                peak = np.max(v)
                maximum = peak if peak != 0 else 1
                if v[row, col] / maximum > 0.5:
                    v[row, col] = value
        return v
|