refactor(mcs-card): adjust mcs card to only have mca channels.

This commit is contained in:
2025-12-08 17:37:39 +01:00
parent 2cf2f4b4e4
commit ef0c31c8dc

View File

@@ -2,16 +2,17 @@
from __future__ import annotations
import enum
import time
from threading import RLock
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal
import numpy as np
from bec_lib.logger import bec_logger
from ophyd import Component as Cpt
from ophyd import Device, EpicsSignalRO, Kind, Signal
from ophyd_devices import CompareStatus, ProgressSignal, TransitionStatus
from ophyd import EpicsSignalRO, Kind, SignalRO
from ophyd_devices import AsyncSignal, CompareStatus, ProgressSignal, TransitionStatus
from ophyd_devices.interfaces.base_classes.psi_device_base import PSIDeviceBase
from ophyd_devices.utils.bec_signals import AsyncMultiSignal
from csaxs_bec.devices.epics.mcs_card.mcs_card import (
ACQUIREMODE,
@@ -24,7 +25,6 @@ from csaxs_bec.devices.epics.mcs_card.mcs_card import (
READMODE,
MCSCard,
)
from csaxs_bec.devices.epics.xbpms import DiffXYSignal, SumSignal
if TYPE_CHECKING: # pragma: no cover
from bec_lib.devicemanager import DeviceManagerBase, ScanInfo
@@ -32,81 +32,289 @@ if TYPE_CHECKING: # pragma: no cover
logger = bec_logger.logger
class READYTOREAD(int, enum.Enum):
PROCESSING = 0
DONE = 1
class BPMDevice(Device):
"""Class for BPM device of the MCSCard."""
current1 = Cpt(Signal, kind=Kind.normal, doc="Normalized current 1")
current2 = Cpt(Signal, kind=Kind.normal, doc="Normalized current 2")
current3 = Cpt(Signal, kind=Kind.normal, doc="Normalized current 3")
current4 = Cpt(Signal, kind=Kind.normal, doc="Normalized current 4")
count_time = Cpt(Signal, kind=Kind.normal, doc="Count time for bpm signal counts")
sum = Cpt(SumSignal, kind="hinted", doc="Sum of all currents")
x = Cpt(
DiffXYSignal,
sum1=["current1", "current2"],
sum2=["current3", "current4"],
doc="X difference signal",
)
y = Cpt(
DiffXYSignal,
sum1=["current1", "current3"],
sum2=["current2", "current4"],
doc="Y difference signal",
)
diag = Cpt(
DiffXYSignal,
sum1=["current1", "current4"],
sum2=["current2", "current3"],
doc="Diagonal difference signal",
)
class MCSRaw(Device):
"""Class for BPM device of the MCSCard with normalized currents."""
mca1 = Cpt(Signal, kind=Kind.normal, doc="Raw counts on mca1 channel")
mca2 = Cpt(Signal, kind=Kind.normal, doc="Raw counts on mca2 channel")
mca3 = Cpt(Signal, kind=Kind.normal, doc="Raw counts on mca3 channel")
mca4 = Cpt(Signal, kind=Kind.normal, doc="Raw counts on mca4 channel")
mca5 = Cpt(Signal, kind=Kind.normal, doc="Raw counts on mca5 channel")
class MCSCardCSAXS(PSIDeviceBase, MCSCard):
"""
Implementation of the MCSCard SIS3820 for CSAXS, prefix 'X12SA-MCS:'.
The basic functionality is inherited from the MCSCard class.
"""
ready_to_read = Cpt(
Signal,
kind=Kind.omitted,
doc="Signal that indicates if mcs card is ready to be read from after triggers. 0 not ready, 1 ready",
)
progress: ProgressSignal = Cpt(ProgressSignal, name="progress")
# Make this an async signal..
mcs = Cpt(
MCSRaw,
name="mcs",
# All counter from the MCS card.
# mca = Cpt(
# AsyncMultiSignal,
# name="counters",
# signals=[
# f"mca{i}" for i in range(1, 33)
# ], # This needs to be in sync with counters DynamicDeviceComponent
# ndim=0,
# async_update={"type": "add", "max_shape": [None]},
# max_size=1000,
# kind=Kind.normal,
# doc="AsyncMultiSignal for MCA card channels 1-32",
# )
mca1 = Cpt(
AsyncSignal,
name="mca1",
kind=Kind.normal,
doc="MCS device with raw current and count time readings",
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 1",
)
bpm = Cpt(
BPMDevice,
name="bpm",
mca2 = Cpt(
AsyncSignal,
name="mca2",
kind=Kind.normal,
doc="BPM device for MCSCard with count times and normalized currents",
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 2",
)
mca3 = Cpt(
AsyncSignal,
name="mca3",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 3",
)
mca4 = Cpt(
AsyncSignal,
name="mca4",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 4",
)
mca5 = Cpt(
AsyncSignal,
name="mca5",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 5",
)
mca6 = Cpt(
AsyncSignal,
name="mca6",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 6",
)
mca7 = Cpt(
AsyncSignal,
name="mca7",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 7",
)
mca8 = Cpt(
AsyncSignal,
name="mca8",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 8",
)
mca9 = Cpt(
AsyncSignal,
name="mca9",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 9",
)
mca10 = Cpt(
AsyncSignal,
name="mca10",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 10",
)
mca11 = Cpt(
AsyncSignal,
name="mca11",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 11",
)
mca12 = Cpt(
AsyncSignal,
name="mca12",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 12",
)
mca13 = Cpt(
AsyncSignal,
name="mca13",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 13",
)
mca14 = Cpt(
AsyncSignal,
name="mca14",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 14",
)
mca15 = Cpt(
AsyncSignal,
name="mca15",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 15",
)
mca16 = Cpt(
AsyncSignal,
name="mca16",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 16",
)
mca17 = Cpt(
AsyncSignal,
name="mca17",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 17",
)
mca18 = Cpt(
AsyncSignal,
name="mca18",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 18",
)
mca19 = Cpt(
AsyncSignal,
name="mca19",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 19",
)
mca20 = Cpt(
AsyncSignal,
name="mca20",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 20",
)
mca21 = Cpt(
AsyncSignal,
name="mca21",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 21",
)
mca22 = Cpt(
AsyncSignal,
name="mca22",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 22",
)
mca23 = Cpt(
AsyncSignal,
name="mca23",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 23",
)
mca24 = Cpt(
AsyncSignal,
name="mca24",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 24",
)
mca25 = Cpt(
AsyncSignal,
name="mca25",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 25",
)
mca26 = Cpt(
AsyncSignal,
name="mca26",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 26",
)
mca27 = Cpt(
AsyncSignal,
name="mca27",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 27",
)
mca28 = Cpt(
AsyncSignal,
name="mca28",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 28",
)
mca29 = Cpt(
AsyncSignal,
name="mca29",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 29",
)
mca30 = Cpt(
AsyncSignal,
name="mca30",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 30",
)
mca31 = Cpt(
AsyncSignal,
name="mca31",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 31",
)
mca32 = Cpt(
AsyncSignal,
name="mca32",
kind=Kind.normal,
async_update={"type": "add", "max_shape": [None]},
max_size=1000,
doc="AsyncSignal for MCA channel 32",
)
# Progress Signal
progress = Cpt(ProgressSignal, doc="ProgressSignal indicating the progress of the device")
def __init__(
self,
name: str,
prefix: str = "",
num_connected_channels: int = 5,
scan_info: ScanInfo | None = None,
device_manager: DeviceManagerBase | None = None,
**kwargs,
@@ -118,16 +326,14 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
name=name, prefix=prefix, scan_info=scan_info, device_manager=device_manager, **kwargs
)
self._mcs_clock = 1e7 # 10MHz clock -> 1e7 Hz
self._pv_timeout = 3 # TODO remove timeout once #129 in ophyd_devices is solved
self._pv_timeout = 2.0 # seconds
self._rlock = RLock() # Needed to ensure thread safety for counter updates
self.counter_mapping = { # Any mca counter that should be updated has to be added here
f"{self.counters.name}_mca1": "current1",
f"{self.counters.name}_mca2": "current2",
f"{self.counters.name}_mca3": "current3",
f"{self.counters.name}_mca4": "current4",
f"{self.counters.name}_mca5": "count_time",
}
self.counter_updated = []
self.num_connected_channels = num_connected_channels
self._received_updates: dict[
str, dict[Literal["value", "timestamp"], list[int] | float]
] = {}
self._acquisition_group: str = "monitored"
self._num_total_triggers: int = 0
def on_connected(self):
"""
@@ -136,14 +342,15 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
# Make sure card is not running
self.stop_all.put(1)
# TODO Check channel1_source !!
# Setup the MCS card settings
self.channel_advance.set(CHANNELADVANCE.EXTERNAL).wait(timeout=self._pv_timeout)
self.channel1_source.set(CHANNEL1SOURCE.EXTERNAL).wait(timeout=self._pv_timeout)
self.prescale.set(1).wait(timeout=self._pv_timeout)
# Set the user LED to off
self.user_led.set(0).wait(timeout=self._pv_timeout)
# Only channel 1-5 are connected so far, adjust if more are needed
self.mux_output.set(5).wait(timeout=self._pv_timeout)
# Set mux_output to number of connected channels. Connect channels in increasing order
self.mux_output.set(self.num_connected_channels).wait(timeout=self._pv_timeout)
# Set the input and output modes & polarities
self.input_mode.set(INPUTMODE.MODE_3).wait(timeout=self._pv_timeout)
self.input_polarity.set(POLARITY.NORMAL).wait(timeout=self._pv_timeout)
@@ -151,90 +358,60 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
self.output_polarity.set(POLARITY.NORMAL).wait(timeout=self._pv_timeout)
self.count_on_start.set(0).wait(timeout=self._pv_timeout)
# Set appropriate read mode
# Set ReadMode to PASSIVE, card will wait for external trigger to be read
self.read_mode.set(READMODE.PASSIVE).wait(timeout=self._pv_timeout)
# Set the acquire mode
self.acquire_mode.set(ACQUIREMODE.MCS).wait(timeout=self._pv_timeout)
# Subscribe the progress signal
# self.current_channel.subscribe(self._progress_update, run=False)
self.current_channel.subscribe(self._progress_update, run=False)
# Subscribe to the mca updates
for name in self.counter_mapping.keys():
sig: EpicsSignalRO = getattr(self.counters, name.split("_")[-1])
sig.subscribe(self._on_counter_update, run=False)
for sig in self.counters.component_names:
sig_obj: EpicsSignalRO = getattr(self.counters, sig)
sig_obj.subscribe(self._on_counter_update, run=False)
def _on_counter_update(self, value, **kwargs) -> None:
"""
Callback for counter updates of the mca channels (1-32).
The raw data is pushed to the mcs sub-device (MCSRaw). We need to ensure that
the MCSRaw device has all signals defined for which we want to push the values.
Data from the mca channels will be pushed to a list, and then forwarded to
the async multi signal 'raw' for readout after the trigger is complete.
As we may receive multiple readings per point, e.g. if frames_per_trigger > 1,
we also create a mean value for the counter signals. These are then pushed to the bpm device
for plotting and further processing. The signal names are defined and mapped in the
self.counter_mapping dictionary & the bpm sub-device.
There are multiple mca channels, each giving individual updates. We want to ensure that
each is updated before we signal that we are ready to read. In future, these signals may
become asynchronous, but we first need to ensure that we can properly combine monitored
signals with async signals for plotting. Until then, we will keep this logic.
"""
with self._rlock:
# Retrieve the signal object which executes this callback
signal = kwargs.get("obj", None)
if signal is None: # This should never happen, but just in case
logger.info(f"Called without 'obj' in kwargs: {kwargs}")
if signal is None:
logger.error(f"Called without 'obj' in kwargs: {kwargs}")
return
# Get the maped signal name from the mapping dictionary
mapped_signal_name = self.counter_mapping.get(signal.name, None)
# If we did not map the signal name in counter_mapping, but receive an update
# we will skip it.
if mapped_signal_name is None:
attr_name = signal.name
mca_channel = getattr(self, attr_name, None)
if mca_channel is None:
logger.error(f"Could not find matching MCA channel for signal {signal.name}")
return
# Push the raw values of the mca channels. The signal name has to be defined
# in the self.mcs sub-device (MCSRaw) to be able to push the values. Otherwise
# we will skip the update.
mca_raw = getattr(self.mcs, signal.name.split("_")[-1], None)
if mca_raw is None:
return
# In case there was more than one value received, i.e. frames_per_trigger > 1,
# we will receive a np.array of values.
if isinstance(value, np.ndarray):
# We push the raw values as a list to the mca_raw signal
# And otherwise compute the mean value for plotting of counter signals
mca_raw.put(value.tolist())
# compute the count_time in seconds
if mapped_signal_name == "count_time":
value = value / self._mcs_clock
value = float(value.mean())
else:
# We received a single value, so we can directly push it
mca_raw.put(value)
# compute the count_time in seconds
if mapped_signal_name == "count_time":
value = value / self._mcs_clock
mca_channel: AsyncSignal
# Get the mapped signal from the bpm device and update it
sig = getattr(self.bpm, mapped_signal_name)
sig.put(value)
self.counter_updated.append(signal.name)
# Once all mca channels have been updated, we can signal that we are ready to read
received_all_updates = set(self.counter_updated) == set(self.counter_mapping.keys())
if received_all_updates:
self.ready_to_read.put(READYTOREAD.DONE)
# The reset of the signal is done in the on_trigger method of ddg1 for the next trigger
self.counter_updated.clear() # Clear the list for the next update cycle
if isinstance(value, np.ndarray):
value = value.tolist() # Convert numpy array to list
else:
value = [value] # Received single value, convert to list
data = {
attr_name: {"value": value, "timestamp": kwargs.get("timestamp") or time.time()}
}
mca_channel.put(data)
# self._received_updates.update(data)
# if len(self._received_updates) == self.num_connected_channels:
# # Send out data on multi async signal
# self.mca.put(self._received_updates, acquisition_group=self._acquisition_group)
# self._received_updates.clear()
def _progress_update(self, value, **kwargs) -> None:
"""Callback for progress updates from ophyd subscription on current_channel."""
# This logic needs to be further refined as this is currently reporting the progress
# of a single trigger from BEC within a burst scan.
frames_per_trigger = self.scan_info.msg.scan_parameters.get("frames_per_trigger", 1)
self.progress.put(
value=value, max_value=frames_per_trigger, done=bool(value == frames_per_trigger)
value=value,
max_value=self._num_total_triggers,
done=bool(value == self._num_total_triggers),
)
def on_stage(self) -> None:
@@ -243,25 +420,21 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
"""
self.erase_all.set(1).wait(timeout=self._pv_timeout)
triggers = self.scan_info.msg.scan_parameters.get("frames_per_trigger", 1)
num_points = self.scan_info.msg.num_points
self._num_total_triggers = triggers * num_points
self._acquisition_group = "monitored" if triggers == 1 else "burst_group"
self.preset_real.set(0).wait(timeout=self._pv_timeout)
self.num_use_all.set(triggers).wait(timeout=self._pv_timeout)
# Reset data
self._received_updates.clear()
def on_unstage(self) -> None:
"""
Called when the device is unstaged.
"""
self.stop_all.put(1)
self.ready_to_read.put(READYTOREAD.DONE)
# TODO why 0?
self.erase_all.set(0).wait(timeout=self._pv_timeout)
def on_trigger(self) -> None:
status = TransitionStatus(
self.ready_to_read, strict=True, transitions=[READYTOREAD.PROCESSING, READYTOREAD.DONE]
)
self.cancel_on_stop(status)
return status
def on_pre_scan(self) -> None:
"""
Called before the scan starts.
@@ -279,6 +452,3 @@ class MCSCardCSAXS(PSIDeviceBase, MCSCard):
Called when the scan is stopped.
"""
self.stop_all.put(1)
self.ready_to_read.put(READYTOREAD.DONE)
# Reset the progress signal
# self.progress.put(0, done=True)