initial hkl mesh scan

This commit is contained in:
gac-x04sa
2025-04-25 15:07:35 +02:00
parent 6e0a3ac65e
commit 25cd9330dd
2 changed files with 268 additions and 1 deletions

View File

@@ -1,2 +1,3 @@
from .hkl_scan import HklScan
from .fly_scan import HklFlyScan
from .fly_scan import HklFlyScan
from .mesh_scan import HklMeshScan, HklMeshFlyScan

View File

@@ -0,0 +1,266 @@
import itertools
import numpy
# from bec_lib.endpoints import MessageEndpoints
from typing import Literal
from bec_lib import messages
from bec_lib.endpoints import MessageEndpoints
from bec_lib.devicemanager import DeviceManagerBase
from bec_lib.logger import bec_logger
# from bec_lib import messages
# from bec_server.scan_server.errors import ScanAbortion
from bec_server.scan_server.scans import ScanAbortion, ScanArgType, ScanBase, AsyncFlyScanBase
logger = bec_logger.logger
class HklMeshScan(ScanBase):
scan_name = 'hklmesh_scan'
arg_input = {
'index': ScanArgType.DEVICE,
'start': ScanArgType.FLOAT,
'stop': ScanArgType.FLOAT,
'points': ScanArgType.INT
}
arg_bundle_size = {"bundle": len(arg_input), "min": 2, "max": 2}
required_kwargs = ['diffract', 'exp_time']
def __init__(self, *args, diffract: str, **kwargs):
self.diffract = diffract
super().__init__(**kwargs)
if any(m not in ['h', 'k', 'l'] for m in self.caller_args):
raise ValueError("Invalid name. Must be 'h', 'k', or 'l'.")
if len(self.caller_args) != 2:
raise ValueError("Only 2 names can be given.")
self.scan_report_devices = ['h', 'k', 'l'] + self.scan_motors + self.readout_priority['monitored']
def update_scan_motors(self):
self.scan_motors = self.device_manager.devices[self.diffract].real_axes.get()
def prepare_positions(self):
"""
Override base method to yield from _calculate_position method
"""
yield from self._calculate_positions()
self._optimize_trajectory()
self.num_pos = len(self.positions) * self.burst_at_each_point
yield from self._set_position_offset()
self._check_limits()
def _calculate_positions(self):
inner_name, outer_name = [x for x in self.caller_args if x in ['h','k','l']]
fixed_name = [x for x in ['h', 'k', 'l'] if x not in self.caller_args][0]
fixed_value = yield from self.stubs.send_rpc_and_wait(self.diffract, f'{fixed_name}.position')
print(f"{inner_name=}, {outer_name=}, {fixed_name=}")
inner_ind, outer_ind, fixed_ind = (['h', 'k', 'l'].index(x) for x in [inner_name, outer_name, fixed_name])
hkls = []
outer_start, outer_stop, outer_points = self.caller_args[outer_name]
inner_start, inner_stop, inner_points = self.caller_args[inner_name]
outer_vect = numpy.linspace(outer_start, outer_stop, outer_points, dtype=float)
for i, outer in enumerate(outer_vect):
if i % 2 == 0:
inner_vect = numpy.linspace(inner_start, inner_stop, inner_points, dtype=float)
else:
inner_vect = numpy.linspace(inner_stop, inner_start, inner_points, dtype=float)
for inner in inner_vect:
hkl = [0, 0, 0]
hkl[inner_ind] = inner
hkl[outer_ind] = outer
hkl[fixed_ind] = fixed_value
hkls.append(hkl)
positions = yield from self.stubs.send_rpc_and_wait(self.diffract, 'angles_from_hkls', hkls)
# the last two positions 'betaIn' and 'betaOut' are not real motors
self.positions = []
for position in positions:
self.positions.append(position[:-2])
class HklMeshFlyScan(AsyncFlyScanBase):
scan_name = 'hklmesh_flyscan'
scan_type = 'fly'
arg_input = {
'index': ScanArgType.DEVICE,
'start': ScanArgType.FLOAT,
'stop': ScanArgType.FLOAT,
'points': ScanArgType.INT
}
arg_bundle_size = {"bundle": len(arg_input), "min": 2, "max": 2}
required_kwargs = ['diffract', 'controller']
use_scan_progress_report = False
def __init__(self, *args, diffract: str, controller: str, **kwargs):
self.diffract = diffract
self.controller = controller
super().__init__(**kwargs)
if any(m not in ['h', 'k', 'l'] for m in self.caller_args):
raise ValueError("Invalid name. Must be 'h', 'k', or 'l'.")
if len(self.caller_args) != 2:
raise ValueError("Only 2 names can be given.")
self.scan_report_devices = ['h', 'k', 'l'] + self.scan_motors + self.readout_priority['monitored']
@property
def monitor_sync(self):
return self.diffract
def update_scan_motors(self):
self.scan_motors = self.device_manager.devices[self.diffract].real_axes.get()
def update_readout_priority(self):
self.readout_priority['async'].extend(['h', 'k', 'l'])
self.readout_priority['async'].extend(self.scan_motors)
def prepare_positions(self):
"""
Override base method to yield from _calculate_position method
"""
yield from self._calculate_positions()
self._optimize_trajectory()
self.num_pos = len(self.positions) * self.burst_at_each_point
yield from self._set_position_offset()
self._check_limits()
def _calculate_positions(self):
inner_name, outer_name = [x for x in self.caller_args if x in ['h','k','l']]
fixed_name = [x for x in ['h', 'k', 'l'] if x not in self.caller_args][0]
fixed_value = yield from self.stubs.send_rpc_and_wait(self.diffract, f'{fixed_name}.position')
inner_ind, outer_ind, fixed_ind = (['h', 'k', 'l'].index(x) for x in [inner_name, outer_name, fixed_name])
hkls = []
outer_start, outer_stop, outer_points = self.caller_args[outer_name]
inner_start, inner_stop, inner_points = self.caller_args[inner_name]
outer_vect = numpy.linspace(outer_start, outer_stop, outer_points, dtype=float)
for i, outer in enumerate(outer_vect):
if i % 2 == 0:
inner_vect = numpy.linspace(inner_start, inner_stop, inner_points, dtype=float)
else:
inner_vect = numpy.linspace(inner_stop, inner_start, inner_points, dtype=float)
for inner in inner_vect:
hkl = [0, 0, 0]
hkl[inner_ind] = inner
hkl[outer_ind] = outer
hkl[fixed_ind] = fixed_value
hkls.append(hkl)
positions = yield from self.stubs.send_rpc_and_wait(self.diffract, 'angles_from_hkls', hkls)
# the last two positions 'betaIn' and 'betaOut' are not real motors
self.positions = []
for position in positions:
self.positions.append(position[:-2])
def scan_core(self):
all_motors = self.device_manager.devices[self.controller].axes.get()
inner_name, outer_name = [x for x in self.caller_args if x in ['h','k','l']]
fixed_name = [x for x in ['h', 'k', 'l'] if x not in self.caller_args][0]
fixed_value = yield from self.stubs.send_rpc_and_wait(self.diffract, f'{fixed_name}.position')
inner_ind, outer_ind, fixed_ind = (['h', 'k', 'l'].index(x) for x in [inner_name, outer_name, fixed_name])
num_pos = 0
hkls = []
outer_start, outer_stop, outer_points = self.caller_args[outer_name]
inner_start, inner_stop, inner_points = self.caller_args[inner_name]
outer_vect = numpy.linspace(outer_start, outer_stop, outer_points, dtype=float)
for i, outer in enumerate(outer_vect):
if i % 2 == 0:
start, stop = inner_start, inner_stop
else:
start, stop = inner_stop, inner_start
hkls = [[0, 0, 0], [0, 0, 0]]
hkls[0][inner_ind] = start
hkls[0][outer_ind] = outer
hkls[0][fixed_ind] = fixed_value
hkls[1][inner_ind] = stop
hkls[1][outer_ind] = outer
hkls[1][fixed_ind] = fixed_value
positions = yield from self.stubs.send_rpc_and_wait(self.diffract, 'angles_from_hkls', hkls)
# Move motors to start position, the last two positions 'betaIn' and 'betaOut' are not real motors
yield from self.stubs.set(device=self.scan_motors, value=positions[0][:-2])
inner_time = inner_points * self.exp_time
yield from self.stubs.send_rpc_and_wait(self.controller, 'num_points.put', 2)
yield from self.stubs.send_rpc_and_wait(self.controller, 'num_pulses.put', inner_points)
yield from self.stubs.send_rpc_and_wait(self.controller, 'start_pulses.put', 1)
yield from self.stubs.send_rpc_and_wait(self.controller, 'end_pulses.put', 2)
yield from self.stubs.send_rpc_and_wait(self.controller, 'time_mode.put', 1)
print(f"{inner_time=} {inner_points=} {self.exp_time=}")
yield from self.stubs.send_rpc_and_wait(self.controller, 'times.put', [inner_time, inner_time])
for axis_name in all_motors:
if axis_name not in self.scan_motors:
yield from self.stubs.send_rpc_and_wait(self.controller, f'{axis_name}.use_axis.put', 0)
else:
index = self.scan_motors.index(axis_name)
yield from self.stubs.send_rpc_and_wait(self.controller,
f'{axis_name}.positions.put',
(positions[0][index], positions[-1][index]))
yield from self.stubs.send_rpc_and_wait(self.controller, f'{axis_name}.use_axis.put', 1)
yield from self.stubs.send_rpc_and_wait(self.controller, 'build_profile')
build_status = yield from self.stubs.send_rpc_and_wait(self.controller, 'build_status.get')
if build_status != 1:
raise ScanAbortion('Profile build failed')
yield from self.stubs.send_rpc_and_wait(self.controller, 'execute_profile')
execute_status = yield from self.stubs.send_rpc_and_wait(self.controller, 'execute_status.get')
if execute_status != 1:
raise ScanAbortion('Profile execute failed')
yield from self.stubs.send_rpc_and_wait(self.controller, 'readback_profile')
readback_status = yield from self.stubs.send_rpc_and_wait(self.controller,'readback_status.get')
if readback_status != 1:
raise ScanAbortion('Profile readback failed')
angle_readbacks = []
for index, axis_name in enumerate(self.scan_motors):
readbacks = yield from self.stubs.send_rpc_and_wait(self.controller, f'{axis_name}.readbacks.get')
self._publish_readbacks(axis_name, readbacks)
angle_readbacks.append(readbacks)
# motor readbacks are aranged column-wise
angle_readbacks = [list(x) for x in zip(*angle_readbacks)]
hkls = yield from self.stubs.send_rpc_and_wait(self.diffract, 'hkls_from_angles', angle_readbacks)
hkls = numpy.array(hkls)
for index, name in enumerate(['h', 'k', 'l', 'betaIn', 'betaOut']):
self._publish_readbacks(name, hkls[:, index])
# motor readbacks have more points than generated pulses
num_pos += len(angle_readbacks)
self.num_pos = num_pos
logger.success(f'{self.scan_name} finished')
def _publish_readbacks(self, device, readbacks):
metadata = {"async_update": "append", "max_shape": [None, None]}
msg = messages.DeviceMessage(
signals={device: {'value': readbacks} }, metadata=metadata
)
self.stubs.connector.xadd(
topic=MessageEndpoints.device_async_readback(
scan_id=self.scan_id, device=device
),
msg_dict={"data": msg},
expire=1800,
)