"""Stream processor that turns azimuthally integrated scanning-SAXS data into 2D maps."""

from __future__ import annotations

import time
from queue import Queue
from string import Template
from typing import Optional, Tuple

import numpy as np

# import bec_lib
from bec_lib.core import BECMessage, MessageEndpoints
from bec_lib.core.redis_connector import MessageObject, RedisConnector
from data_processing.stream_processor import StreamProcessor


class StreamProcessorPx(StreamProcessor):
    """Stream processor for the ``px_stream`` projections.

    Metadata updates for projections are collected in a queue. ``_run_forever``
    takes one projection from that queue, repeatedly loads its data from redis,
    processes it and publishes the result until metadata for another projection
    is queued.
    """

    def __init__(self, connector: RedisConnector, config: dict) -> None:
        """Set up the processor state and start listening for projection metadata.

        Args:
            connector (RedisConnector): Connector forwarded to the parent class.
            config (dict): Processor configuration forwarded to the parent class.
                ``process`` reads ``config["parameters"]`` ("contrast", "qranges",
                "aziangles") from ``self.config``.

        Returns:
            None
        """
        super().__init__(connector, config)
        self.metadata_consumer = None
        # Metadata per projection number, filled by _update_queue.
        self.metadata = {}
        # Number of data messages already fetched for the current projection.
        self.num_received_msgs = 0
        # Queue of (proj_nr, metadata) tuples consumed by _run_forever.
        self.queue = Queue()
        self._init_metadata(endpoint="px_stream/proj_nr")
        self.start_metadata_consumer(endpoint="px_stream/projection_*/metadata")

    def _init_metadata(self, endpoint: str) -> None:
        """Initialize the metadata.

        Reads the current projection number from ``endpoint`` and, if the
        metadata for that projection is available, queues it for processing.
        Nothing is queued if either message is missing.

        Args:
            endpoint (str): Endpoint for redis topic.

        Returns:
            None
        """
        msg = self.producer.get(topic=endpoint)
        if msg is None:
            return None
        msg_raw = BECMessage.DeviceMessage.loads(msg)
        proj_nr = msg_raw.content["signals"]["proj_nr"]
        # TODO hardcoded endpoint, possible to use more general solution?
        msg = self.producer.get(topic=f"px_stream/projection_{proj_nr}/metadata")
        if msg is None:
            # The projection number is known but its metadata is not there yet;
            # the metadata consumer will queue it once it arrives.
            return None
        msg_raw = BECMessage.DeviceMessage.loads(msg)
        self._update_queue(msg_raw.content["signals"], proj_nr)

    def _update_queue(self, metadata: dict, proj_nr: int) -> None:
        """Update the process queue.

        Stores the metadata under ``proj_nr`` in ``self.metadata`` and puts
        ``(proj_nr, metadata)`` on ``self.queue``.

        Args:
            metadata (dict): Metadata for the projection.
            proj_nr (int): Projection number.

        Returns:
            None
        """
        self.metadata.update({proj_nr: metadata})
        self.queue.put((proj_nr, metadata))

    def start_metadata_consumer(self, endpoint: str) -> None:
        """Start the metadata consumer.

        A consumer that is still alive is shut down first. The new consumer is
        started with a callback function that updates the metadata.

        Args:
            endpoint (str): Endpoint (pattern) for redis topic.

        Returns:
            None
        """
        if self.metadata_consumer and self.metadata_consumer.is_alive():
            self.metadata_consumer.shutdown()
        self.metadata_consumer = self._connector.consumer(
            pattern=endpoint, cb=self._update_metadata_cb, parent=self
        )
        self.metadata_consumer.start()

    @staticmethod
    def _update_metadata_cb(msg: MessageObject, parent: StreamProcessorPx) -> None:
        """Callback function for the metadata consumer.

        Args:
            msg (MessageObject): Message object.
            parent (StreamProcessorPx): Parent class.

        Returns:
            None
        """
        msg_raw = BECMessage.DeviceMessage.loads(msg.value)
        parent._metadata_msg_handler(msg_raw, msg.topic.decode())

    def _metadata_msg_handler(self, msg: BECMessage.DeviceMessage, topic: str) -> None:
        """Handle the metadata message.

        If self.metadata holds more than 10 entries, the first inserted entry
        is removed before the new one is stored.

        Args:
            msg (BECMessage.DeviceMessage): Message object.
            topic (str): Topic for the message; the projection number is parsed
                from ``px_stream/projection_<proj_nr>/...``.

        Returns:
            None
        """
        if len(self.metadata) > 10:
            first_key = next(iter(self.metadata))
            self.metadata.pop(first_key)
        proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0])
        self._update_queue(msg.content["signals"], proj_nr)

    def _init_data_output(self) -> None:
        """Initialize the data output.

        Not yet used. Should be used to initialize the output for the processed data.
        """
        self.data = None

    def start_data_consumer(self) -> None:
        """function from the parent class that we don't want to use here"""
        pass

    def _run_forever(self) -> None:
        """Loop that runs forever when the processor is started.

        Takes the next projection from the queue (blocking until one is
        available), then loads and processes its data. This processing continues
        as long as the queue is empty, and the method returns once metadata for
        another projection has been queued.

        Returns:
            None
        """
        proj_nr, metadata = self.queue.get()
        self.num_received_msgs = 0
        # TODO initiate output, such that self.process only runs on new data
        # NOTE(review): the accumulated data is reprocessed and republished on
        # every iteration, also when no new messages arrived.
        data = []
        while self.queue.empty():
            new_msgs = self._get_data(proj_nr)
            data.extend(m.content["signals"]["data"] for m in new_msgs if m is not None)
            # TODO debug code for timing
            start = time.time()
            result = self.process(data, metadata)
            print(f"Processing took {time.time() - start}")
            if not result:
                continue
            stream_output, out_metadata = result
            print(f"Length of data is {stream_output[0]['z'].shape}")
            msg = BECMessage.ProcessedDataMessage(
                data=stream_output[0], metadata=out_metadata
            ).dumps()
            print("Publishing result")
            self._publish_result(msg)

    def _get_data(self, proj_nr: int) -> list:
        """Get data for given proj_nr from redis.

        Only messages that were not fetched before are requested;
        ``self.num_received_msgs`` is advanced by the number of messages read.

        Args:
            proj_nr (int): Projection number.

        Returns:
            list: List of messages with azimuthal integrated data (empty if
                there is nothing new).
        """
        msgs = self.producer.lrange(
            f"px_stream/projection_{proj_nr}/data", self.num_received_msgs, -1
        )
        if not msgs:
            return []
        self.num_received_msgs += len(msgs)
        return [BECMessage.DeviceMessage.loads(msg) for msg in msgs]

    def process(self, data: list, metadata: dict) -> Optional[Tuple[dict, dict]]:
        """Process the scanning SAXS data

        Args:
            data (list): List of azimuthal integrated data.
            metadata (dict): Metadata for the projection. Must provide "norm_sum"
                and "q". The key "grid_scan" is written into this dict in place.

        Returns:
            Optional[Tuple[dict, dict]]: Processed data (``{0: {"z": array}}``)
                and metadata, or None if ``data`` is empty.

        Raises:
            ValueError: If the configured contrast is not 0, 1 or 2.
        """
        if not data:
            return None

        # TODO debug code for timing; this measures only the array conversion.
        start = time.time()
        azint_data = np.asarray(data)
        print(f"Processing took {time.time() - start}")

        parameters = self.config["parameters"]
        contrast = parameters["contrast"]

        f1amp, f2amp, f2phase = self._colorfulplot(
            qranges=parameters["qranges"],
            q=metadata["q"],
            norm_sum=metadata["norm_sum"],
            data=azint_data,
            aziangles=parameters["aziangles"],
            percentile_value=96,
        )

        if contrast == 0:
            out = f1amp
        elif contrast == 1:
            out = f2amp
        elif contrast == 2:
            out = f2phase
        else:
            # Previously an unknown contrast surfaced as an AttributeError on a
            # list further down; fail with an explicit message instead.
            raise ValueError(
                f"Unknown contrast {contrast!r}; expected 0 (f1amp), 1 (f2amp) or 2 (f2phase)"
            )

        # x/y positions and the input configuration are not part of the output yet.
        stream_output = {0: {"z": np.asarray(out)}}
        metadata["grid_scan"] = out.shape

        return (stream_output, metadata)

    def _colorfulplot(
        self,
        qranges: list,
        q: np.ndarray,
        norm_sum: np.ndarray,
        data: np.ndarray,
        aziangles: Optional[list],
        percentile_value: int = 96,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Compute data for sSAXS colorful 2D plot.

        The data is binned in q, symmetrized and Fourier transformed along the
        azimuthal direction. The first two Fourier coefficients give the
        returned amplitudes and phase.

        Pending: hsv_to_rgb conversion for colorful output, i.e.
        h = f2phase, s = f2amp and v = f1amp, the latter two divided by their
        ``percentile_value`` percentile and clipped to 1.

        Args:
            qranges (list): List with q edges (indices into q) for binning.
            q (np.ndarray): q values.
            norm_sum (np.ndarray): Normalization sum.
            data (np.ndarray): Data to be binned. The reshapes below assume a
                4-D array, presumably (rows, columns, directions, q) — TODO confirm.
            aziangles (list, optional): List of azimuthal angles in degrees; the
                first entry shifts f2phase. None means no shift.
            percentile_value (int, optional): Percentile value for removing outliers
                above threshold. Defaults to 96, range 0...100. Currently unused
                because the HSV output is pending.

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: f1amp, f2amp, f2phase
        """
        output, output_norm = self._bin_qrange(qranges=qranges, q=q, norm_sum=norm_sum, data=data)
        output_sym = self._sym_data(data=output, norm_sum=output_norm)

        shape = output_sym.shape[0:2]
        n_azimuthal = output_sym.shape[2]

        # One row per scan point, Fourier transform along the azimuthal axis.
        fft_data = np.fft.rfft(output_sym.reshape((-1, output_sym.shape[-2])), axis=1)

        azi_angle = 0 if aziangles is None else aziangles[0]

        f1amp = np.abs(fft_data[:, 0]) / n_azimuthal
        f2amp = 2 * np.abs(fft_data[:, 1]) / n_azimuthal
        f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle)
        # Map the angle onto a phase; values pushed above 1 by the shift wrap around.
        f2phase = (f2angle + np.pi) / (2 * np.pi)
        f2phase[f2phase > 1] = f2phase[f2phase > 1] - 1

        f1amp = f1amp.reshape(shape)
        f2amp = f2amp.reshape(shape)
        f2phase = f2phase.reshape(shape)

        return f1amp, f2amp, f2phase

    def _bin_qrange(self, qranges, q, norm_sum, data) -> Tuple[np.ndarray, np.ndarray]:
        """Reintegrate data for given q range.

        Weighted sum for data using norm_sum as weights. Entries whose summed
        weight is zero are set to 0.

        NOTE: only the first bin, ``qranges[0]`` to ``qranges[1]``, is filled;
        any further bins stay zero.

        Args:
            qranges (list): List with q edges (indices into q) for binning.
            q (np.ndarray): q values.
            norm_sum (np.ndarray): Normalization sum.
            data (np.ndarray): Data to be binned.

        Returns:
            np.ndarray: Binned data.
            np.ndarray: Binned normalization sum.
        """
        # Metadata may arrive as plain sequences; asarray is a no-op for arrays.
        q = np.asarray(q)
        norm_sum = np.asarray(norm_sum)

        n_bins = len(qranges) - 1
        output = np.zeros((*data.shape[:-1], n_bins))
        output_norm = np.zeros((data.shape[-2], n_bins))

        with np.errstate(divide="ignore", invalid="ignore"):
            q_mask = np.logical_and(q >= q[qranges[0]], q <= q[qranges[1]])
            output_norm[..., 0] = np.nansum(norm_sum[..., q_mask], axis=-1)
            output[..., 0] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1)
            # Without ``where`` the ufunc overwrites all of ``out`` and the
            # zero pre-fill is lost (0 / 0 -> NaN).
            output[..., 0] = np.divide(
                output[..., 0],
                output_norm[..., 0],
                out=np.zeros_like(output[..., 0]),
                where=output_norm[..., 0] != 0,
            )

        return output, output_norm

    def _sym_data(self, data, norm_sum) -> np.ndarray:
        """Symmetrize data by averaging over the two opposing directions.

        Helpful to remove detector gaps for x-ray detectors. The average is
        weighted with norm_sum; where both directions have zero weight the
        result is 0.

        Args:
            data (np.ndarray): Data to be symmetrized.
            norm_sum (np.ndarray): Normalization sum. The first axis is split in
                two halves; assumes an even number of directions.

        Returns:
            np.ndarray: Symmetrized data.
        """
        n_directions = norm_sum.shape[0] // 2
        weights_a = norm_sum[:n_directions, :]
        weights_b = norm_sum[n_directions:, :]
        weights_total = weights_a + weights_b

        half_a = data[..., :n_directions, :]
        half_b = data[..., n_directions:, :]

        # ``where`` keeps the zero pre-fill of ``out`` for zero total weight.
        output = np.divide(
            half_a * weights_a + half_b * weights_b,
            weights_total,
            out=np.zeros_like(half_a),
            where=weights_total != 0,
        )
        return output


if __name__ == "__main__":
    config = {
        "output": "px_dap_worker",
        "parameters": {
            # TODO these three inputs could be made available for change from the GUI
            "qranges": [20, 50],
            "contrast": 0,
            "aziangles": None,
        },
    }
    dap_process = StreamProcessorPx.run(config=config, connector_host=["localhost:6379"])