from __future__ import annotations import numpy as np from string import Template from queue import Queue import time # import bec_lib from data_processing.stream_processor import StreamProcessor from bec_lib.core import BECMessage, MessageEndpoints from bec_lib.core.redis_connector import MessageObject, RedisConnector from typing import Optional, Tuple class StreamProcessorPx(StreamProcessor): def __init__(self, connector: RedisConnector, config: dict) -> None: """ Initialize the LmfitProcessor class. Args: connector (RedisConnector): Redis connector. config (dict): Configuration for the processor. """ super().__init__(connector, config) self.metadata_consumer = None self.metadata = {} self.num_received_msgs = 0 self.queue = Queue() self.start_metadata_consumer() def start_data_consumer(self): pass def start_metadata_consumer(self): if self.metadata_consumer and self.metadata_consumer.is_alive(): self.metadata_consumer.shutdown() self.metadata_consumer = self._connector.consumer( pattern="px_stream/projection_*/metadata", cb=self._update_metadata, parent=self ) self.metadata_consumer.start() @staticmethod def _update_metadata(msg: MessageObject, parent: StreamProcessorPx) -> None: msg_raw = BECMessage.DeviceMessage.loads(msg.value) parent._metadata_msg_handler(msg_raw, msg.topic.decode()) def _metadata_msg_handler(self, msg, topic): if len(self.metadata) > 10: first_key = next(iter(self.metadata)) self.metadata.pop(first_key) proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0]) self.metadata.update({proj_nr: msg.content["signals"]}) self.queue.put((proj_nr, msg.content["signals"])) def _run_forever(self): """""" # TODO: Check if should skip entries in queue at beginning proj_nr, metadata = self.queue.get() self.num_received_msgs = 0 data = [] while self.queue.empty(): start = time.time() data_msgs = self._get_data(proj_nr) data.extend([msg.content["signals"]["data"] for msg in data_msgs if msg is not None]) #if len(data) > 80: # out = np.asarray(data) # result = ({0: {"z": np.sum(out, axis=(-1, -2))}}, {1: {}}) #else: # continue print(f"Loading took {time.time() - start}") result = self.process(data, metadata) print(f"Processing took {time.time() - start}") if not result: continue print(f"Length of data is {result[0][0]['z'].shape}") msg = BECMessage.ProcessedDataMessage( data=result[0][0], metadata=result[1]).dumps() print("Publishing result") self._publish_result(msg) def _get_data(self, proj_nr: int) -> list: msgs = self.producer.lrange( f"px_stream/projection_{proj_nr}/data", self.num_received_msgs, -1 ) if not msgs: return [] self.num_received_msgs += len(msgs) return [BECMessage.DeviceMessage.loads(msg) for msg in msgs] def process(self, data: list, metadata: dict) -> Optional[Tuple[dict, dict]]: if not data: return None # get the event data, hard coded azint_data = np.asarray(data) norm_sum = metadata["norm_sum"] q = metadata["q"] out = [] ##################################### # Pick contrast 0:f1amp, 1:f2amp, 2:f2phase contrast = self.config["parameters"]["contrast"] # user input/LinearRegionROI, maybe move to metadata qranges = self.config["parameters"]["qranges"] ##################################### f1amp, f2amp, f2phase = self._colorfulplot( qranges=qranges, q=q, norm_sum=norm_sum, data=azint_data, aziangles=None, percentile_value=96, ) if contrast == 0: out = f1amp elif contrast == 1: out = f2amp elif contrast == 2: out = f2phase stream_output = { # 0: {"x": np.asarray(x), "y": np.asarray(y), "z": np.asarray(out)}, 0: {"z": np.asarray(out)}, # "input": self.config["input_xy"], } metadata["grid_scan"] = out.shape return (stream_output, metadata) def _bin_qrange(self, qranges, q, norm_sum, data): """ Args: q ranges: list with indices of q edges, data is binned between neighbors. q: all q norm_sum: weights for q data: full data """ output = np.zeros((*data.shape[:-1], len(qranges) - 1)) output_norm = np.zeros((data.shape[-2], len(qranges) - 1)) with np.errstate(divide="ignore", invalid="ignore"): for ii, qval in enumerate(qranges[:-1]): q_mask = np.logical_and(q >= q[qranges[ii]], q < q[qranges[ii + 1]]) output_norm[..., ii] = np.nansum(norm_sum[..., q_mask], axis=-1) output[..., ii] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1) output[..., ii] = np.divide( output[..., ii], output_norm[..., ii], out=np.zeros_like(output[..., ii]) ) return output, output_norm def _sym_data(self, data, norm_sum): n_directions = norm_sum.shape[0] // 2 output = np.divide( data[..., :n_directions, :] * norm_sum[:n_directions, :] + data[..., n_directions:, :] * norm_sum[n_directions:, :], norm_sum[:n_directions, :] + norm_sum[n_directions:, :], out=np.zeros_like(data[..., :n_directions, :]), ) return output def _colorfulplot(self, qranges, q, norm_sum, data, aziangles=None, percentile_value=96): """ Args: q ranges: list with 2 indices for q edges q: all q norm_sum: weights for q data: full data """ output, output_norm = self._bin_qrange(qranges=qranges, q=q, norm_sum=norm_sum, data=data) output_sym = self._sym_data(data=output, norm_sum=output_norm) output_sym = output_sym shape = output_sym.shape[0:2] fft_data = np.fft.rfft(output_sym.reshape((-1, output_sym.shape[-2])), axis=1) if aziangles is None: azi_angle = 0 else: azi_angle = aziangles[0] f1amp = np.abs(fft_data[:, 0]) / output_sym.shape[2] f2amp = 2 * np.abs(fft_data[:, 1]) / output_sym.shape[2] # This still slightly confused me and to get mapping to colorwheel correct it needs to match f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle) # Unwrap phaseand normalize between 0...1, output of rfft is in between -pi and pi, it can be larger then 1 because of addition of zero angle! f2phase = (f2angle + np.pi) / (2 * np.pi) f2phase[f2phase > 1] = f2phase[f2phase > 1] - 1 f1amp = f1amp.reshape(shape) f2amp = f2amp.reshape(shape) f2angle = f2angle.reshape(shape) f2phase = f2phase.reshape(shape) # hsv output h = f2phase max_scale = np.percentile(f2amp, percentile_value) s = f2amp / max_scale s[s > 1] = 1 max_scale = np.percentile(f1amp, percentile_value) v = f1amp v = v / max_scale v[v > 1] = 1 hsv = np.stack((h, s, v), axis=2) # comb_all = colors.hsv_to_rgb(hsv) return f1amp, f2amp, f2phase # , comb_all if __name__ == "__main__": config = { "template": Template("px_stream/projection_$proj/$channel"), "channels": ["data", "metadata", "q", "norm_sum"], "output": "px_dap_worker", "parameters": { "qranges": [20, 50], # TODO this will be signal from ROI selector "contrast": 0, # "contrast_stream" : 'px_contrast_stream', }, } dap_process = StreamProcessorPx.run(config=config, connector_host=["localhost:6379"]) # dap_process = StreamProcessorPx(config=config, connector_host=["localhost:6379"]) # dap_process.start_met