refactor: mcs_card inherits from base class psi_detector_base

This commit is contained in:
2023-11-20 18:06:25 +01:00
parent 3a37de9eda
commit d77e8e255d

View File

@ -1,33 +1,30 @@
import enum import enum
import threading import threading
import time from typing import List
from typing import Any, List from collections import defaultdict
import numpy as np import numpy as np
from ophyd import EpicsSignal, EpicsSignalRO from ophyd import EpicsSignal, EpicsSignalRO, Device, Component as Cpt
from ophyd import EpicsSignal, EpicsSignalRO, Component as Cpt, Device
from ophyd.mca import EpicsMCARecord
from ophyd.scaler import ScalerCH
from ophyd_devices.epics.devices.bec_scaninfo_mixin import BecScaninfoMixin
from ophyd_devices.utils import bec_utils from ophyd_devices.utils import bec_utils
from ophyd_devices.epics.devices.psi_detector_base import PSIDetectorBase, CustomDetectorMixin
from bec_lib import messages, MessageEndpoints, bec_logger, threadlocked from bec_lib import messages, MessageEndpoints, bec_logger, threadlocked
from bec_lib.file_utils import FileWriterMixin
from collections import defaultdict
logger = bec_logger.logger logger = bec_logger.logger
class McsError(Exception): class MCSError(Exception):
pass """Base class for exceptions in this module."""
class McsTimeoutError(Exception): class MCSTimeoutError(MCSError):
pass """Raise when MCS card runs into a timeout"""
class TriggerSource(int, enum.Enum): class TriggerSource(int, enum.Enum):
"""Trigger source for mcs card - see manual for more information"""
MODE0 = 0 MODE0 = 0
MODE1 = 1 MODE1 = 1
MODE2 = 2 MODE2 = 2
@ -38,11 +35,15 @@ class TriggerSource(int, enum.Enum):
class ChannelAdvance(int, enum.Enum): class ChannelAdvance(int, enum.Enum):
"""Channel advance pixel mode for mcs card - see manual for more information"""
INTERNAL = 0 INTERNAL = 0
EXTERNAL = 1 EXTERNAL = 1
class ReadoutMode(int, enum.Enum): class ReadoutMode(int, enum.Enum):
"""Readout mode for mcs card - see manual for more information"""
PASSIVE = 0 PASSIVE = 0
EVENT = 1 EVENT = 1
IO_INTR = 2 IO_INTR = 2
@ -56,20 +57,192 @@ class ReadoutMode(int, enum.Enum):
FREQ_100HZ = 10 FREQ_100HZ = 10
class SIS38XX(Device): class MCSSetup(CustomDetectorMixin):
"""SIS38XX control""" """Setup mixin class for the MCS card"""
def __init__(self, *args, parent: Device = None, **kwargs) -> None:
self._lock = threading.RLock()
self._stream_ttl = 1800
self.acquisition_done = False
self.counter = 0
self.n_points = 0
self.mca_names = [
signal for signal in self.parent.component_names if signal.startswith("mca")
]
self.mca_data = defaultdict(lambda: [])
super().__init__(*args, parent=parent, **kwargs)
def initialize_detector(self) -> None:
"""Initialize detector"""
# External trigger for pixel advance
self.parent.channel_advance.set(ChannelAdvance.EXTERNAL)
# Use internal clock for channel 1
self.parent.channel1_source.set(ChannelAdvance.INTERNAL)
self.parent.user_led.set(0)
# Set number of channels to 5
self.parent.mux_output.set(5)
# Trigger Mode used for cSAXS
self.parent.set_trigger(TriggerSource.MODE3)
# specify polarity of trigger signals
self.parent.input_polarity.set(0)
self.parent.output_polarity.set(1)
# do not start counting on start
self.parent.count_on_start.set(0)
self.parent.stop_all.set(1)
def initialize_detector_backend(self) -> None:
"""Initialize detector backend"""
for mca in self.mca_names:
signal = getattr(self.parent, mca)
signal.subscribe(self._on_mca_data, run=False)
self.parent.current_channel.subscribe(self._progress_update, run=False)
def _progress_update(self, value, **kwargs) -> None:
"""Progress update on the scan"""
num_lines = self.parent.num_lines.get()
max_value = self.parent.scaninfo.num_points
# self.counter seems to be a deprecated variable from a former implementation of the mcs card
# pylint: disable=protected-access
self.parent._run_subs(
sub_type=self.parent.SUB_PROGRESS,
value=self.counter * int(self.parent.scaninfo.num_points / num_lines) + value,
max_value=max_value,
done=bool(max_value == self.counter),
)
@threadlocked
def _on_mca_data(self, *args, obj=None, value=None, **kwargs) -> None:
"""Callback function for scan progress"""
if not isinstance(value, (list, np.ndarray)):
return
self.mca_data[obj.attr_name] = value
if len(self.mca_names) != len(self.mca_data):
return
self.acquisition_done = True
self._send_data_to_bec()
self.mca_data = defaultdict(lambda: [])
def _send_data_to_bec(self) -> None:
"""Sends bundled data to BEC"""
if self.parent.scaninfo.scan_msg is None:
return
metadata = self.parent.scaninfo.scan_msg.metadata
metadata.update(
{
"async_update": "append",
"num_lines": self.parent.num_lines.get(),
}
)
msg = messages.DeviceMessage(
signals=dict(self.mca_data),
metadata=self.parent.scaninfo.scan_msg.metadata,
).dumps()
self.parent.producer.xadd(
topic=MessageEndpoints.device_async_readback(
scanID=self.parent.scaninfo.scanID, device=self.parent.name
),
msg={"data": msg},
expire=self._stream_ttl,
)
def prepare_detector(self) -> None:
"""Prepare detector for scan"""
self.set_acquisition_params()
self.parent.set_trigger(TriggerSource.MODE3)
def set_acquisition_params(self) -> None:
"""Set acquisition parameters for scan"""
if self.parent.scaninfo.scan_type == "step":
self.n_points = int(self.parent.scaninfo.frames_per_trigger) * int(
self.parent.scaninfo.num_points
)
elif self.parent.scaninfo.scan_type == "fly":
self.n_points = int(self.parent.scaninfo.num_points) # / int(self.num_lines.get()))
else:
raise MCSError(f"Scantype {self.parent.scaninfo} not implemented for MCS card")
if self.n_points > 10000:
raise MCSError(
f"Requested number of points N={self.n_points} exceeds hardware limit of mcs card"
" 10000 (N-1)"
)
self.parent.num_use_all.set(self.n_points)
self.parent.preset_real.set(0)
def prepare_detector_backend(self) -> None:
"""Prepare detector backend for scan"""
self.parent.erase_all.set(1)
self.parent.read_mode.set(ReadoutMode.EVENT)
def arm_acquisition(self) -> None:
"""Arm detector for acquisition"""
self.counter = 0
self.parent.erase_start.set(1)
def finished(self) -> None:
"""Check if acquisition is finished, if not successful, rais MCSTimeoutError"""
signal_conditions = [
(
self.acquisition_done,
True,
),
(
self.parent.acquiring.get,
0, # Considering making a enum.Int class for this state
),
]
if not self.wait_for_signals(
signal_conditions=signal_conditions,
timeout=self.parent.timeout,
check_stopped=True,
all_signals=True,
):
total_frames = self.counter * int(
self.parent.scaninfo.num_points / self.num_lines.get()
) + max(self.parent.current_channel.get(), 0)
raise MCSTimeoutError(
f"Reached timeout with mcs in state {self.parent.acquiring.get()} and"
f" {total_frames} frames arriving at the mcs card"
)
def stop_detector(self) -> None:
"""Stop detector"""
self.parent.stop_all.set(1)
return super().stop_detector()
def stop_detector_backend(self) -> None:
"""Stop acquisition of data"""
self.acquisition_done = True
class MCScSAXS(PSIDetectorBase):
"""MCS card for cSAXS for implementation at cSAXS beamline"""
USER_ACCESS = ["describe", "_init_mcs"]
SUB_PROGRESS = "progress"
SUB_VALUE = "value"
_default_sub = SUB_VALUE
# specify Setup class
custom_prepare_cls = MCSSetup
# specify minimum readout time for detector
MIN_READOUT = 3e-3
# PV access to MCA signals
mca1 = Cpt(EpicsSignalRO, "mca1.VAL", auto_monitor=True)
mca3 = Cpt(EpicsSignalRO, "mca3.VAL", auto_monitor=True)
mca4 = Cpt(EpicsSignalRO, "mca4.VAL", auto_monitor=True)
current_channel = Cpt(EpicsSignalRO, "CurrentChannel", auto_monitor=True)
# PV access to SISS38XX card
# Acquisition # Acquisition
erase_all = Cpt(EpicsSignal, "EraseAll") erase_all = Cpt(EpicsSignal, "EraseAll")
erase_start = Cpt(EpicsSignal, "EraseStart", trigger_value=1) erase_start = Cpt(EpicsSignal, "EraseStart", trigger_value=1)
start_all = Cpt(EpicsSignal, "StartAll") start_all = Cpt(EpicsSignal, "StartAll")
stop_all = Cpt(EpicsSignal, "StopAll") stop_all = Cpt(EpicsSignal, "StopAll")
acquiring = Cpt(EpicsSignal, "Acquiring") acquiring = Cpt(EpicsSignal, "Acquiring")
preset_real = Cpt(EpicsSignal, "PresetReal") preset_real = Cpt(EpicsSignal, "PresetReal")
elapsed_real = Cpt(EpicsSignal, "ElapsedReal") elapsed_real = Cpt(EpicsSignal, "ElapsedReal")
read_mode = Cpt(EpicsSignal, "ReadAll.SCAN") read_mode = Cpt(EpicsSignal, "ReadAll.SCAN")
read_all = Cpt(EpicsSignal, "DoReadAll.VAL", trigger_value=1) read_all = Cpt(EpicsSignal, "DoReadAll.VAL", trigger_value=1)
num_use_all = Cpt(EpicsSignal, "NuseAll") num_use_all = Cpt(EpicsSignal, "NuseAll")
@ -93,48 +266,7 @@ class SIS38XX(Device):
firmware = Cpt(EpicsSignalRO, "Firmware") firmware = Cpt(EpicsSignalRO, "Firmware")
max_channels = Cpt(EpicsSignalRO, "MaxChannels") max_channels = Cpt(EpicsSignalRO, "MaxChannels")
# Custom signal readout from device config
class McsCsaxs(SIS38XX):
USER_ACCESS = ["_init_mcs"]
SUB_PROGRESS = "progress"
SUB_VALUE = "value"
_default_sub = SUB_VALUE
# scaler = Cpt(ScalerCH, "scaler1")
# mca2 = Cpt(EpicsMCARecord, "mca2")
mca1 = Cpt(EpicsSignalRO, "mca1.VAL", auto_monitor=True)
mca3 = Cpt(EpicsSignalRO, "mca3.VAL", auto_monitor=True)
mca4 = Cpt(EpicsSignalRO, "mca4.VAL", auto_monitor=True)
# mca5 = Cpt(EpicsMCARecord, "mca5")
# mca6 = Cpt(EpicsMCARecord, "mca6")
# mca7 = Cpt(EpicsMCARecord, "mca7")
# mca8 = Cpt(EpicsMCARecord, "mca8")
# mca9 = Cpt(EpicsMCARecord, "mca9")
# mca10 = Cpt(EpicsMCARecord, "mca10")
# mca11 = Cpt(EpicsMCARecord, "mca11")
# mca12 = Cpt(EpicsMCARecord, "mca12")
# mca13 = Cpt(EpicsMCARecord, "mca13")
# mca14 = Cpt(EpicsMCARecord, "mca14")
# mca15 = Cpt(EpicsMCARecord, "mca15")
# mca16 = Cpt(EpicsMCARecord, "mca16")
# mca17 = Cpt(EpicsMCARecord, "mca17")
# mca18 = Cpt(EpicsMCARecord, "mca18")
# mca19 = Cpt(EpicsMCARecord, "mca19")
# mca20 = Cpt(EpicsMCARecord, "mca20")
# mca21 = Cpt(EpicsMCARecord, "mca21")
# mca22 = Cpt(EpicsMCARecord, "mca22")
# mca23 = Cpt(EpicsMCARecord, "mca23")
# mca24 = Cpt(EpicsMCARecord, "mca24")
# mca25 = Cpt(EpicsMCARecord, "mca25")
# mca26 = Cpt(EpicsMCARecord, "mca26")
# mca27 = Cpt(EpicsMCARecord, "mca27")
# mca28 = Cpt(EpicsMCARecord, "mca28")
# mca29 = Cpt(EpicsMCARecord, "mca29")
# mca30 = Cpt(EpicsMCARecord, "mca30")
# mca31 = Cpt(EpicsMCARecord, "mca31")
# mca32 = Cpt(EpicsMCARecord, "mca32")
current_channel = Cpt(EpicsSignalRO, "CurrentChannel", auto_monitor=True)
num_lines = Cpt( num_lines = Cpt(
bec_utils.ConfigSignal, bec_utils.ConfigSignal,
name="num_lines", name="num_lines",
@ -160,6 +292,7 @@ class McsCsaxs(SIS38XX):
f"{name}_num_lines": 1, f"{name}_num_lines": 1,
} }
if mcs_config is not None: if mcs_config is not None:
# pylint: disable=expression-not-assigned
[self.mcs_config.update({f"{name}_{key}": value}) for key, value in mcs_config.items()] [self.mcs_config.update({f"{name}_{key}": value}) for key, value in mcs_config.items()]
super().__init__( super().__init__(
@ -169,246 +302,23 @@ class McsCsaxs(SIS38XX):
read_attrs=read_attrs, read_attrs=read_attrs,
configuration_attrs=configuration_attrs, configuration_attrs=configuration_attrs,
parent=parent, parent=parent,
device_manager=device_manager,
sim_mode=sim_mode,
**kwargs, **kwargs,
) )
if device_manager is None and not sim_mode:
raise McsError("Add DeviceManager to initialization or init with sim_mode=True")
self.name = name def set_trigger(self, trigger_source: TriggerSource) -> None:
self._stream_ttl = 1800 """Set trigger mode from TriggerSource"""
self.wait_for_connection() # Make sure to be connected before talking to PVs
if not sim_mode:
self.device_manager = device_manager
self._producer = self.device_manager.producer
else:
self._producer = bec_utils.MockProducer()
self.device_manager = bec_utils.MockDeviceManager()
# TODO mack mock connector class
# self._consumer = self.device_manager.connector.consumer
self.scaninfo = BecScaninfoMixin(device_manager, sim_mode)
# TODO
self.scaninfo.username = "e21206"
self.service_cfg = {"base_path": f"/sls/X12SA/data/{self.scaninfo.username}/Data10/"}
self.filewriter = FileWriterMixin(self.service_cfg)
self._stopped = False
self._acquisition_done = False
self._lock = threading.RLock()
self.counter = 0
self.n_points = 0
self._init_mcs()
self.mca_names = [signal for signal in self.component_names if signal.startswith("mca")]
self.mca_data = defaultdict(lambda: [])
for mca in self.mca_names:
signal = getattr(self, mca)
signal.subscribe(self._on_mca_data, run=False)
self.current_channel.subscribe(self._progress_update, run=False)
def _init_mcs(self) -> None:
"""Init parameters for mcs card 9m
channel_advance: 0/1 -> internal / external
channel1_source: 0/1 -> int clock / external source
user_led: 0/1 -> off/on
max_output : num of channels 0...32, uncomment top for more than 5
input_mode: operation mode -> Mode 3 for external trigger, check manual for more info
input_polarity: triggered between falling and falling edge -> use inverted signal from ddg
"""
self.channel_advance.set(ChannelAdvance.EXTERNAL)
self.channel1_source.set(ChannelAdvance.INTERNAL)
self.user_led.set(0)
self.mux_output.set(5)
self._set_trigger(TriggerSource.MODE3)
self.input_polarity.set(0)
self.output_polarity.set(1)
self.count_on_start.set(0)
self.stop_all.set(1)
def _progress_update(self, value, **kwargs) -> None:
num_lines = self.num_lines.get()
max_value = self.scaninfo.num_points
self._run_subs(
sub_type=self.SUB_PROGRESS,
value=self.counter * int(self.scaninfo.num_points / num_lines) + value,
max_value=max_value,
done=bool(max_value == self.counter),
)
@threadlocked
def _on_mca_data(self, *args, obj=None, **kwargs) -> None:
if not isinstance(kwargs["value"], (list, np.ndarray)):
return
# self.mca_data[obj.attr_name] = kwargs["value"][1:]
self.mca_data[obj.attr_name] = kwargs["value"]
if len(self.mca_names) != len(self.mca_data):
return
# logger.info("Entered _on_mca_data")
# self._updated = True
# self.counter += 1
# logger.info(f'data from mca {self.mca_data["mca1"]} and {self.mca_data["mca4"]}')
# if (self.scaninfo.scan_type == "fly" and self.counter == self.num_lines.get()) or (
# self.scaninfo.scan_type == "step" and self.counter == self.scaninfo.num_points
# ):
# self._acquisition_done = True
# self.stop_all.put(1, use_complete=False)
# self._send_data_to_bec()
# self.erase_all.put(1)
# #logger.info("Entered _on_mca_data, acquisition finished")
# # Require wait for
# # time.sleep(0.01)
# self.mca_data = defaultdict(lambda: [])
# self.counter = 0
# return
# self.erase_start.set(1)
# self._send_data_to_bec()
# self.mca_data = defaultdict(lambda: [])
self._acquisition_done = True
self._send_data_to_bec()
self.mca_data = defaultdict(lambda: [])
def _send_data_to_bec(self) -> None:
if self.scaninfo.scan_msg is None:
return
metadata = self.scaninfo.scan_msg.metadata
metadata.update(
{
"async_update": "append",
"num_lines": self.num_lines.get(),
}
)
msg = messages.DeviceMessage(
signals=dict(self.mca_data),
metadata=self.scaninfo.scan_msg.metadata,
).dumps()
self._producer.xadd(
topic=MessageEndpoints.device_async_readback(
scanID=self.scaninfo.scanID, device=self.name
),
msg={"data": msg},
expire=self._stream_ttl,
)
def _prep_det(self) -> None:
self._set_acquisition_params()
self._set_trigger(TriggerSource.MODE3)
def _set_acquisition_params(self) -> None:
if self.scaninfo.scan_type == "step":
self.n_points = int(self.scaninfo.frames_per_trigger) * int(self.scaninfo.num_points)
elif self.scaninfo.scan_type == "fly":
self.n_points = int(self.scaninfo.num_points) # / int(self.num_lines.get()))
else:
raise McsError(f"Scantype {self.scaninfo} not implemented for MCS card")
if self.n_points > 10000:
raise McsError(
f"Requested number of points N={self.n_points} exceeds hardware limit of mcs card 10000 (N-1)"
)
self.num_use_all.set(self.n_points)
self.preset_real.set(0)
def _set_trigger(self, trigger_source: TriggerSource) -> None:
"""7 Modes, see TriggerSource
Mode3 for cSAXS"""
value = int(trigger_source) value = int(trigger_source)
self.input_mode.set(value) self.input_mode.set(value)
def _prep_readout(self) -> None:
"""Set readout mode of mcs card
Check ReadoutMode class for more information about options
"""
# self.read_mode.set(ReadoutMode.EVENT)
self.erase_all.set(1)
self.read_mode.set(ReadoutMode.EVENT)
def _force_readout_mcs_card(self) -> None:
self.read_all.put(1, use_complete=False)
def stage(self) -> List[object]: def stage(self) -> List[object]:
"""stage the detector and file writer""" """stage the detector for upcoming acquisition"""
self._stopped = False rtr = super().stage()
self._acquisition_done = False self.custom_prepare.arm_acquisition()
logger.info("Stage mcs") return rtr
self.scaninfo.load_scan_metadata()
self._prep_det()
self._prep_readout()
# msg = messages.FileMessage(file_path=self.filepath, done=False)
# self._producer.set_and_publish(
# MessageEndpoints.public_file(self.scaninfo.scanID, "mcs_csaxs"),
# msg.dumps(),
# )
self.arm_acquisition()
logger.info("Waiting for mcs to be armed")
while True:
det_ctrl = self.acquiring.read()[self.acquiring.name]["value"]
if det_ctrl == 1:
break
time.sleep(0.005)
logger.info("mcs is ready and running")
# time.sleep(5)
return super().stage()
def unstage(self) -> List[object]:
"""unstage"""
logger.info("Waiting for mcs to finish acquisition")
old_scanID = self.scaninfo.scanID
self.scaninfo.load_scan_metadata()
logger.info(f"Old scanID: {old_scanID}, ")
if self.scaninfo.scanID != old_scanID:
self._stopped = True
if self._stopped is True:
logger.info("Entered unstage _stopped =True")
return super().unstage()
self._mcs_finished()
self._acquisition_done = False
self._stopped = False
logger.info("mcs done")
return super().unstage()
def _mcs_finished(self):
"""Function with 10s timeout"""
timer = 0
logger.info("Entered _mcs_finished loop")
while True:
if self._acquisition_done == True and self.acquiring.get() == 0:
break
if self._stopped == True:
break
time.sleep(0.1)
timer += 0.1
if timer > 8:
total_frames = self.counter * int(
self.scaninfo.num_points / self.num_lines.get()
) + max(self.current_channel.get(), 0)
raise McsTimeoutError(
f"Reached timeout with mcs in state {self.acquiring.get()} and {total_frames} frames arriving at the mcs card"
)
logger.info("Finished _mcs_finished loop")
def arm_acquisition(self) -> None:
"""Arm acquisition
Options:
Start: start_all
Erase/Start: erase_start
"""
logger.info("Entered mcs arm_acquisition")
self.counter = 0
self.erase_start.set(1)
# self.start_all.set(1)
def stop(self, *, success=False) -> None:
"""Stop acquisition
Stop or Stop and Erase
"""
logger.info("Entered mcs stop")
self.stop_all.set(1)
# self.erase_all.set(1)
self._stopped = True
self._acquisition_done = True
super().stop(success=success)
# Automatically connect to test environmenr if directly invoked # Automatically connect to test environmenr if directly invoked
if __name__ == "__main__": if __name__ == "__main__":
mcs = McsCsaxs(name="mcs", prefix="X12SA-MCS:", sim_mode=True) mcs = MCScSAXS(name="mcs", prefix="X12SA-MCS:", sim_mode=True)
mcs.stage()
mcs.unstage()