diff --git a/bec_plugins/data_processing/px_example.py b/bec_plugins/data_processing/px_example.py index c42876f..88fac4d 100644 --- a/bec_plugins/data_processing/px_example.py +++ b/bec_plugins/data_processing/px_example.py @@ -13,71 +13,159 @@ from typing import Optional, Tuple class StreamProcessorPx(StreamProcessor): def __init__(self, connector: RedisConnector, config: dict) -> None: - """ - Initialize the LmfitProcessor class. - - Args: - connector (RedisConnector): Redis connector. - config (dict): Configuration for the processor. - """ + """""" super().__init__(connector, config) self.metadata_consumer = None self.metadata = {} + # self._init_data_output() self.num_received_msgs = 0 self.queue = Queue() - self.start_metadata_consumer() + self._init_metadata(endpoint="px_stream/proj_nr") + self.start_metadata_consumer(endpoint="px_stream/projection_*/metadata") - def start_data_consumer(self): - pass + def _init_metadata(self, endpoint: str) -> None: + """Initialize the metadata. + + Args: + endpoint (str): Endpoint for redis topic. + + Returns: + None + + """ + + msg = self.producer.get(topic=endpoint) + if msg is None: + return None + msg_raw = BECMessage.DeviceMessage.loads(msg) + proj_nr = msg_raw.content["signals"]["proj_nr"] + # TODO hardcoded endpoint, possibe to use more general solution? + msg = self.producer.get(topic=f"px_stream/projection_{proj_nr}/metadata") + msg_raw = BECMessage.DeviceMessage.loads(msg) + self._update_queue(msg_raw.content["signals"], proj_nr) + + def _update_queue(self, metadata: dict, proj_nr: int) -> None: + """Update the process queue. + + Args: + metadata (dict): Metadata for the projection. + proj_nr (int): Projection number. + + Returns: + None + + """ + + self.metadata.update({proj_nr: metadata}) + self.queue.put((proj_nr, metadata)) + + def start_metadata_consumer(self, endpoint: str) -> None: + """Start the metadata consumer. + Consumer is started with a callback function that updates the metadata. + + Args: + endpoint (str): Endpoint for redis topic. + + Returns: + None + + """ - def start_metadata_consumer(self): if self.metadata_consumer and self.metadata_consumer.is_alive(): self.metadata_consumer.shutdown() self.metadata_consumer = self._connector.consumer( - pattern="px_stream/projection_*/metadata", cb=self._update_metadata, parent=self + pattern=endpoint, cb=self._update_metadata_cb, parent=self ) self.metadata_consumer.start() @staticmethod - def _update_metadata(msg: MessageObject, parent: StreamProcessorPx) -> None: + def _update_metadata_cb(msg: MessageObject, parent: StreamProcessorPx) -> None: + """Callback function for the metadata consumer. + + Args: + msg (MessageObject): Message object. + parent (StreamProcessorPx): Parent class. + + Returns: + None + + """ + msg_raw = BECMessage.DeviceMessage.loads(msg.value) parent._metadata_msg_handler(msg_raw, msg.topic.decode()) - def _metadata_msg_handler(self, msg, topic): + def _metadata_msg_handler(self, msg: BECMessage, topic) -> None: + """Handle the metadata message. + If self.metadata is larger than 10, the oldest entry is removed. + + Args: + msg (BECMessage): Message object. + topic (str): Topic for the message. + + Returns: + None + + """ + if len(self.metadata) > 10: first_key = next(iter(self.metadata)) self.metadata.pop(first_key) proj_nr = int(topic.split("px_stream/projection_")[1].split("/")[0]) - self.metadata.update({proj_nr: msg.content["signals"]}) - self.queue.put((proj_nr, msg.content["signals"])) + self._update_queue(msg.content["signals"], proj_nr) + + def _init_data_output(self) -> None: + """Initialize the data output. + Not yet used. Should be used to initialize the output for the processed data. + """ + self.data = None + + def start_data_consumer(self) -> None: + """function from the parent class that we don't want to use here""" + pass + + def _run_forever(self) -> None: + """Loop that runs forever when the processor is started. + Upon update of the queue, the data is loaded and processed. + This processing continues as long as the queue is empty, + and proceeds to the next projection when the queue is updated. + + Returns: + None + + """ - def _run_forever(self): - """""" - # TODO: Check if should skip entries in queue at beginning proj_nr, metadata = self.queue.get() self.num_received_msgs = 0 + # TODO initiate output, such that self.process only runs on new data + # self._init_data_output() data = [] while self.queue.empty(): - start = time.time() + # TODO debug code for timing + # data_msgs = self._get_data(proj_nr) data.extend([msg.content["signals"]["data"] for msg in data_msgs if msg is not None]) - #if len(data) > 80: - # out = np.asarray(data) - # result = ({0: {"z": np.sum(out, axis=(-1, -2))}}, {1: {}}) - #else: - # continue - print(f"Loading took {time.time() - start}") + # print(f"Loading took {time.time() - start}") + start = time.time() result = self.process(data, metadata) print(f"Processing took {time.time() - start}") if not result: continue print(f"Length of data is {result[0][0]['z'].shape}") - msg = BECMessage.ProcessedDataMessage( - data=result[0][0], metadata=result[1]).dumps() + msg = BECMessage.ProcessedDataMessage(data=result[0][0], metadata=result[1]).dumps() print("Publishing result") self._publish_result(msg) def _get_data(self, proj_nr: int) -> list: + """Get data for given proj_nr from redis. + + Args: + proj_nr (int): Projection number. + + Returns: + list: List of azimuthal integrated data. + + """ + msgs = self.producer.lrange( f"px_stream/projection_{proj_nr}/data", self.num_received_msgs, -1 ) @@ -87,27 +175,36 @@ class StreamProcessorPx(StreamProcessor): return [BECMessage.DeviceMessage.loads(msg) for msg in msgs] def process(self, data: list, metadata: dict) -> Optional[Tuple[dict, dict]]: + """Process the scanning SAXS data + + Args: + data (list): List of azimuthal integrated data. + metadata (dict): Metadata for the projection. + + Returns: + Optional[Tuple[dict, dict]]: Processed data and metadata. + + """ + if not data: return None - # get the event data, hard coded + start = time.time() azint_data = np.asarray(data) + print(f"Processing took {time.time() - start}") norm_sum = metadata["norm_sum"] q = metadata["q"] out = [] - ##################################### - # Pick contrast 0:f1amp, 1:f2amp, 2:f2phase contrast = self.config["parameters"]["contrast"] - # user input/LinearRegionROI, maybe move to metadata qranges = self.config["parameters"]["qranges"] - ##################################### + aziangles = self.config["parameters"]["aziangles"] f1amp, f2amp, f2phase = self._colorfulplot( qranges=qranges, q=q, norm_sum=norm_sum, data=azint_data, - aziangles=None, + aziangles=aziangles, percentile_value=96, ) if contrast == 0: @@ -126,46 +223,31 @@ class StreamProcessorPx(StreamProcessor): return (stream_output, metadata) - def _bin_qrange(self, qranges, q, norm_sum, data): - """ + def _colorfulplot( + self, + qranges: list, + q: np.ndarray, + norm_sum: np.ndarray, + data: np.ndarray, + aziangles: list, + percentile_value: int = 96, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute data for sSAXS colorful 2D plot. + Pending: hsv_to_rgb conversion for colorful output + Args: - q ranges: list with indices of q edges, data is binned between neighbors. - q: all q - norm_sum: weights for q - data: full data + qranges (list): List with q edges for binning. + q (np.ndarray): q values. + norm_sum (np.ndarray): Normalization sum. + data (np.ndarray): Data to be binned. + aziangles (list, optional): List of azimuthal angles to shift f2phase. Defaults to None. + percentile_value (int, optional): Percentile value for removing outliers above threshold. Defaults to 96, range 0...100. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray]: f1amp, f2amp, f2phase + """ - output = np.zeros((*data.shape[:-1], len(qranges) - 1)) - output_norm = np.zeros((data.shape[-2], len(qranges) - 1)) - with np.errstate(divide="ignore", invalid="ignore"): - for ii, qval in enumerate(qranges[:-1]): - q_mask = np.logical_and(q >= q[qranges[ii]], q < q[qranges[ii + 1]]) - output_norm[..., ii] = np.nansum(norm_sum[..., q_mask], axis=-1) - output[..., ii] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1) - output[..., ii] = np.divide( - output[..., ii], output_norm[..., ii], out=np.zeros_like(output[..., ii]) - ) - - return output, output_norm - - def _sym_data(self, data, norm_sum): - n_directions = norm_sum.shape[0] // 2 - output = np.divide( - data[..., :n_directions, :] * norm_sum[:n_directions, :] - + data[..., n_directions:, :] * norm_sum[n_directions:, :], - norm_sum[:n_directions, :] + norm_sum[n_directions:, :], - out=np.zeros_like(data[..., :n_directions, :]), - ) - return output - - def _colorfulplot(self, qranges, q, norm_sum, data, aziangles=None, percentile_value=96): - """ - Args: - q ranges: list with 2 indices for q edges - q: all q - norm_sum: weights for q - data: full data - """ output, output_norm = self._bin_qrange(qranges=qranges, q=q, norm_sum=norm_sum, data=data) output_sym = self._sym_data(data=output, norm_sum=output_norm) output_sym = output_sym @@ -179,10 +261,8 @@ class StreamProcessorPx(StreamProcessor): f1amp = np.abs(fft_data[:, 0]) / output_sym.shape[2] f2amp = 2 * np.abs(fft_data[:, 1]) / output_sym.shape[2] - # This still slightly confused me and to get mapping to colorwheel correct it needs to match f2angle = np.angle(fft_data[:, 1]) + np.deg2rad(azi_angle) - # Unwrap phaseand normalize between 0...1, output of rfft is in between -pi and pi, it can be larger then 1 because of addition of zero angle! f2phase = (f2angle + np.pi) / (2 * np.pi) f2phase[f2phase > 1] = f2phase[f2phase > 1] - 1 @@ -191,7 +271,6 @@ class StreamProcessorPx(StreamProcessor): f2angle = f2angle.reshape(shape) f2phase = f2phase.reshape(shape) - # hsv output h = f2phase max_scale = np.percentile(f2amp, percentile_value) @@ -203,22 +282,70 @@ class StreamProcessorPx(StreamProcessor): v = v / max_scale v[v > 1] = 1 - hsv = np.stack((h, s, v), axis=2) + # hsv = np.stack((h, s, v), axis=2) # comb_all = colors.hsv_to_rgb(hsv) return f1amp, f2amp, f2phase # , comb_all + def _bin_qrange(self, qranges, q, norm_sum, data) -> Tuple[np.ndarray, np.ndarray]: + """Reintegrate data for given q range. + Weighted sum for data using norm_sum as weights + + Args: + qranges (list): List with q edges for binning. + q (np.ndarray): q values. + norm_sum (np.ndarray): Normalization sum. + data (np.ndarray): Data to be binned. + + Returns: + np.ndarray: Binned data. + np.ndarray: Binned normalization sum. + """ + + output = np.zeros((*data.shape[:-1], len(qranges) - 1)) + output_norm = np.zeros((data.shape[-2], len(qranges) - 1)) + + with np.errstate(divide="ignore", invalid="ignore"): + q_mask = np.logical_and(q >= q[qranges[0]], q <= q[qranges[1]]) + output_norm[..., 0] = np.nansum(norm_sum[..., q_mask], axis=-1) + output[..., 0] = np.nansum(data[..., q_mask] * norm_sum[..., q_mask], axis=-1) + output[..., 0] = np.divide( + output[..., 0], output_norm[..., 0], out=np.zeros_like(output[..., 0]) + ) + + return output, output_norm + + def _sym_data(self, data, norm_sum) -> np.ndarray: + """Symmetrize data by averaging over the two opposing directions. + Helpful to remove detector gaps for x-ray detectors + + Args: + data (np.ndarray): Data to be symmetrized. + norm_sum (np.ndarray): Normalization sum. + + Returns: + np.ndarray: Symmetrized data. + + """ + + n_directions = norm_sum.shape[0] // 2 + output = np.divide( + data[..., :n_directions, :] * norm_sum[:n_directions, :] + + data[..., n_directions:, :] * norm_sum[n_directions:, :], + norm_sum[:n_directions, :] + norm_sum[n_directions:, :], + out=np.zeros_like(data[..., :n_directions, :]), + ) + return output + if __name__ == "__main__": config = { - "template": Template("px_stream/projection_$proj/$channel"), - "channels": ["data", "metadata", "q", "norm_sum"], "output": "px_dap_worker", "parameters": { - "qranges": [20, 50], # TODO this will be signal from ROI selector - "contrast": 0, # "contrast_stream" : 'px_contrast_stream', + # TODO these three inputs could be made available for change from the GUI + "qranges": [20, 50], + "contrast": 0, + "aziangles": None, }, } dap_process = StreamProcessorPx.run(config=config, connector_host=["localhost:6379"]) - # dap_process = StreamProcessorPx(config=config, connector_host=["localhost:6379"]) - # dap_process.start_met