Files
camserver_sf/configuration/user_scripts/swissfel_spectral_processing.py
2025-04-28 14:35:09 +02:00

235 lines
7.9 KiB
Python

from logging import getLogger
from cam_server.pipeline.data_processing import functions
from cam_server.utils import create_thread_pvs, epics_lock
from collections import deque
import json
import numpy as np
import scipy.signal
import numba
import time
import sys
from threading import Thread
# Configure Numba to use multiple threads
# (only effective for jitted functions compiled with parallel=True)
numba.set_num_threads(4)
_logger = getLogger(__name__)
# Shared state globals — shared between process_image() (pipeline thread)
# and the PV writer thread started in initialize().
global_roi = [0, 0]  # [ymin, ymax] row crop, refreshed from EPICS ROI PVs each frame
initialized = False  # set True after the first process_image() call runs initialize()
sent_pid = -1  # highest pulse_id already queued for PV write (monotonic guard)
buffer = deque(maxlen=5)  # records for the PV writer thread; oldest dropped when full
channel_pv_names = None  # [ymin_pv, ymax_pv, axis_pv] names, set in initialize()
base_pv_names = []  # result PV names in wire order, set in initialize()
all_pv_names = []  # base PVs followed by their -RAVG counterparts
global_ravg_length = 100  # running-average window; params['RAVG_length'] overrides
ravg_buffers = {}  # per-PV deque of recent values backing the running averages
@numba.njit(parallel=True)
def get_spectrum(image, background):
    """Compute the background-subtracted spectrum via row-wise summation.

    Parameters
    ----------
    image, background : 2-D arrays of identical shape (rows, columns).

    Returns
    -------
    1-D float64 array of length ``image.shape[1]``: per-column sums of
    ``image - background``.

    Fix: the original used ``numba.prange`` under ``parallel=False``, which
    Numba silently lowers to a plain serial ``range`` — so the module's
    ``numba.set_num_threads(4)`` had no effect.  Enabling ``parallel=True``
    on the original loop order would race on ``profile[j]`` (all rows write
    every column), so the parallel loop is over columns instead: each
    ``j`` owns ``profile[j]`` exclusively, and the per-column summation
    order over rows is unchanged, keeping results bit-identical.
    """
    y, x = image.shape
    profile = np.zeros(x, dtype=np.float64)
    for j in numba.prange(x):
        # Accumulate in a float64 local, same addition order as the
        # original (i ascending), then store once.
        total = 0.0
        for i in range(y):
            total += image[i, j] - background[i, j]
        profile[j] = total
    return profile
def update_PVs(buffer, *pv_names):
    """Poll *buffer* forever and push each popped record to its EPICS PVs.

    Records are tuples aligned element-for-element with *pv_names*.
    A ``None`` value, a missing PV, or a disconnected PV is skipped.
    Runs as a daemon thread; never returns.
    """
    channels = create_thread_pvs(list(pv_names))
    while True:
        time.sleep(0.1)
        try:
            record = buffer.popleft()
        except IndexError:
            # Nothing queued this cycle — keep polling.
            continue
        try:
            for channel, value in zip(channels, record):
                writable = channel and channel.connected and (value is not None)
                if writable:
                    channel.put(value)
        except Exception:
            _logger.exception("Error updating channels")
def initialize(params):
    """Build the PV name lists, read running-average config, start the writer.

    Populates the module globals ``channel_pv_names``, ``base_pv_names``,
    ``all_pv_names`` and ``global_ravg_length`` from *params*
    ("camera_name", "e_int_name", "e_axis_name", optional "RAVG_length"),
    then launches the daemon thread that drains ``buffer`` into EPICS.
    """
    global channel_pv_names, base_pv_names, all_pv_names, global_ravg_length

    camera = params["camera_name"]
    e_int = params["e_int_name"]
    e_axis = params["e_axis_name"]

    # PVs re-read on every frame: dynamic ROI bounds plus the energy axis.
    channel_pv_names = [
        f"{camera}:SPC_ROI_YMIN",
        f"{camera}:SPC_ROI_YMAX",
        e_axis,
    ]

    # Result PVs handed to the writer thread — the order here is the wire
    # order of each queued record, so it must match process_image().
    fit_suffixes = ("FIT-COM", "FIT-FWHM", "FIT-RMS", "FIT-RES")
    spect_suffixes = ("SPECT-COM", "SPECT-RMS", "SPECT-SKEW",
                      "SPECT-IQR", "SPECT-RES")  # SPECT-RES uses the IQR-based calc
    base_pv_names = [e_int]
    base_pv_names += [f"{camera}:{suffix}" for suffix in fit_suffixes + spect_suffixes]

    # Running-average window, overridable from the pipeline parameters.
    global_ravg_length = params.get('RAVG_length', global_ravg_length)

    # Running-average PVs: every scalar base PV gets a "-RAVG" twin
    # (arrays and the JSON parameters channel are excluded).
    skipped = {e_int, e_axis, f"{camera}:processing_parameters"}
    ravg_pv_names = [f"{name}-RAVG" for name in base_pv_names if name not in skipped]

    all_pv_names = base_pv_names + ravg_pv_names

    # Daemon thread: dies with the process, drains `buffer` into EPICS.
    Thread(target=update_PVs, args=(buffer, *all_pv_names), daemon=True).start()
def process_image(image, pulse_id, timestamp, x_axis, y_axis, parameters, bsdata=None, background=None):
    """
    Main entrypoint: subtract background, crop ROI, smooth, fit Gaussian,
    compute metrics, queue PV updates (with running averages).

    Parameters
    ----------
    image : 2-D array, the camera frame.
    pulse_id : pulse identifier; only strictly increasing ids are queued
        for PV write.
    timestamp, x_axis, y_axis, bsdata, background : part of the pipeline
        signature; unused in this implementation.
    parameters : pipeline config dict. Reads "camera_name", "e_int_name",
        "e_axis_name", optional "pixel_bkg" (scalar), optional
        "background_data" (popped — consumed once per frame it appears in),
        and optional "RAVG_length" (via initialize()).

    Returns
    -------
    dict of channel-name -> value including the "-RAVG" running-average
    entries; None when the energy-axis PV is unavailable or too short;
    {} on any processing error (logged).
    """
    global initialized, sent_pid, channel_pv_names, global_ravg_length, ravg_buffers
    try:
        # One-time setup: PV name lists + background writer thread.
        if not initialized:
            initialize(parameters)
            initialized = True
        # Dynamic ROI and axis PV read (re-resolved each call;
        # create_thread_pvs presumably caches per-thread PV objects — TODO confirm).
        ymin_pv, ymax_pv, axis_pv = create_thread_pvs(channel_pv_names)
        if ymin_pv and ymin_pv.connected:
            global_roi[0] = ymin_pv.value
        if ymax_pv and ymax_pv.connected:
            global_roi[1] = ymax_pv.value
        if not (axis_pv and axis_pv.connected):
            _logger.warning("Energy axis not connected")
            return None
        axis = axis_pv.value
        if len(axis) < image.shape[1]:
            _logger.warning("Energy axis length %d < image width %d", len(axis), image.shape[1])
            return None
        # Axis may be longer than the frame; trim to one entry per column.
        axis = axis[:image.shape[1]]
        # Preprocess image: promote to float32, subtract scalar pixel background.
        proc_img = image.astype(np.float32) - np.float32(parameters.get("pixel_bkg", 0))
        nrows, _ = proc_img.shape
        # Background image: pop (not get) so it is consumed at most once;
        # ignored unless it is an ndarray matching the frame shape.
        bg_img = parameters.pop('background_data', None)
        if not (isinstance(bg_img, np.ndarray) and bg_img.shape == proc_img.shape):
            bg_img = None
        else:
            bg_img = bg_img.astype(np.float32)
        # Crop ROI rows only when the PV-supplied bounds are sane.
        ymin, ymax = int(global_roi[0]), int(global_roi[1])
        if 0 <= ymin < ymax <= nrows:
            proc_img = proc_img[ymin:ymax, :]
            if bg_img is not None:
                bg_img = bg_img[ymin:ymax, :]
        # Extract spectrum: per-pixel background-subtracted row sum when a
        # background image is available, plain row sum otherwise.
        spectrum = get_spectrum(proc_img, bg_img) if bg_img is not None else np.sum(proc_img, axis=0)
        # Smooth (Savitzky-Golay, 51-sample window, cubic). NOTE(review):
        # raises if the spectrum has fewer than 51 samples; that lands in
        # the outer except and returns {}.
        smoothed = scipy.signal.savgol_filter(spectrum, 51, 3)
        # Noise check and fit Gaussian: skip the fit when peak-to-peak
        # amplitude is under the noise heuristic (1.5 counts per summed row).
        minimum, maximum = smoothed.min(), smoothed.max()
        amplitude = maximum - minimum
        skip = amplitude <= nrows * 1.5
        # Fit on every 2nd sample to halve cost; maxfev=10 bounds fit time.
        offset, amp_fit, center, sigma = functions.gauss_fit_psss(
            smoothed[::2], axis[::2], offset=minimum,
            amplitude=amplitude, skip=skip, maxfev=10
        )
        # Compute normalized spectrum weights (sum to 1; a zero-sum spectrum
        # produces NaNs that propagate into the moments below).
        sm_norm = smoothed / np.sum(smoothed)
        # Statistical moments along the energy axis.
        spect_com = np.sum(axis * sm_norm)
        spect_std = np.sqrt(np.sum((axis - spect_com)**2 * sm_norm))
        spect_skew = np.sum((axis - spect_com)**3 * sm_norm) / (spect_std**3)
        # Interquartile width (IQR) from the cumulative distribution.
        cum = np.cumsum(sm_norm)
        e25 = np.interp(0.25, cum, axis)
        e75 = np.interp(0.75, cum, axis)
        spect_iqr = e75 - e25
        spect_sum = np.sum(spectrum)
        camera = parameters["camera_name"]
        # Original result dict (keys are channel/PV names).
        result = {
            parameters["e_int_name"]: spectrum,
            parameters["e_axis_name"]: axis,
            f"{camera}:SPECTRUM_Y_SUM": spect_sum,
            f"{camera}:FIT-COM": np.float64(center),
            f"{camera}:FIT-FWHM": np.float64(2.355 * sigma),  # FWHM ≈ 2.355·sigma for a Gaussian
            f"{camera}:FIT-RMS": np.float64(sigma),
            f"{camera}:FIT-RES": np.float64(2.355 * sigma / center * 1000),
            f"{camera}:SPECT-COM": spect_com,
            f"{camera}:SPECT-RMS": spect_std,
            f"{camera}:SPECT-SKEW": spect_skew,
            f"{camera}:SPECT-IQR": spect_iqr,
            # Use IQR for relative spread instead of std
            f"{camera}:SPECT-RES": np.float64(spect_iqr / spect_com * 1000),
            f"{camera}:processing_parameters": json.dumps({"roi": global_roi})
        }
        # Prepare full values for PV update (including running averages);
        # array and JSON channels are excluded from averaging.
        exclude = {
            parameters["e_int_name"],
            parameters["e_axis_name"],
            f"{camera}:processing_parameters"
        }
        ravg_results = {}
        for base_pv in (pv for pv in base_pv_names if pv not in exclude):
            buf = ravg_buffers.setdefault(base_pv, deque(maxlen=global_ravg_length))
            buf.append(result.get(base_pv))
            ravg_results[f"{base_pv}-RAVG"] = np.mean(buf)
        # Merge for PV write
        full_results = {**result, **ravg_results}
        # Queue PV update if new pulse; non-blocking acquire so a contended
        # lock just skips this frame's PV write instead of stalling.
        if epics_lock.acquire(False):
            try:
                if pulse_id > sent_pid:
                    sent_pid = pulse_id
                    # Record order must match all_pv_names (the writer zips them).
                    entry = tuple(full_results.get(pv) for pv in all_pv_names)
                    buffer.append(entry)
            finally:
                epics_lock.release()
        return full_results
    except Exception as ex:
        # Boundary handler: the pipeline must keep running, so log and return {}.
        _logger.warning("Processing error: %s", ex)
        return {}