diff --git a/dap/algos/forcesend.py b/dap/algos/forcesend.py index 73d19c2..1c243b4 100644 --- a/dap/algos/forcesend.py +++ b/dap/algos/forcesend.py @@ -4,30 +4,29 @@ from .mask import calc_mask_pixels from .thresh import threshold -def calc_force_send(results, data, pixel_mask_pf, image, data_summed, n_aggregated_images): +def calc_force_send(results, data, pixel_mask_pf, image, aggregator): force_send_visualisation = False if data.dtype == np.uint16: - return data, force_send_visualisation, data_summed, n_aggregated_images + return data, force_send_visualisation, aggregator apply_aggregation = results.get("apply_aggregation", False) apply_threshold = results.get("apply_threshold", False) if not apply_aggregation: - data_summed = None - n_aggregated_images = 0 + aggregator.reset() if not apply_aggregation and not apply_threshold: data = image - return data, force_send_visualisation, data_summed, n_aggregated_images + return data, force_send_visualisation, aggregator calc_apply_threshold(results, data) # changes data in place - data, force_send_visualisation, data_summed, n_aggregated_images = calc_apply_aggregation(results, data, data_summed, n_aggregated_images) + data, force_send_visualisation, aggregator = calc_apply_aggregation(results, data, aggregator) calc_mask_pixels(data, pixel_mask_pf) # changes data in place - return data, force_send_visualisation, data_summed, n_aggregated_images + return data, force_send_visualisation, aggregator @@ -48,34 +47,29 @@ def calc_apply_threshold(results, data): -def calc_apply_aggregation(results, data, data_summed, n_aggregated_images): +def calc_apply_aggregation(results, data, aggregator): force_send_visualisation = False apply_aggregation = results.get("apply_aggregation", False) if not apply_aggregation: - return data, force_send_visualisation, data_summed, n_aggregated_images + return data, force_send_visualisation, aggregator if "aggregation_max" not in results: - return data, force_send_visualisation, data_summed, n_aggregated_images + return data, force_send_visualisation, aggregator - if data_summed is None: - data_summed = data.copy() - n_aggregated_images = 1 - else: - data_summed += data - n_aggregated_images += 1 + aggregator += data - data = data_summed + data = aggregator.data + n_aggregated_images = aggregator.counter results["aggregated_images"] = n_aggregated_images results["worker"] = 1 #TODO: keep this for backwards compatibility? if n_aggregated_images >= results["aggregation_max"]: force_send_visualisation = True - data_summed = None - n_aggregated_images = 0 + aggregator.reset() - return data, force_send_visualisation, data_summed, n_aggregated_images + return data, force_send_visualisation, aggregator diff --git a/dap/utils/__init__.py b/dap/utils/__init__.py index 4474c56..45f8cbe 100644 --- a/dap/utils/__init__.py +++ b/dap/utils/__init__.py @@ -1,4 +1,5 @@ +from .aggregator import Aggregator from .bits import read_bit from .bufjson import BufferedJSON diff --git a/dap/utils/aggregator.py b/dap/utils/aggregator.py new file mode 100644 index 0000000..bd6d315 --- /dev/null +++ b/dap/utils/aggregator.py @@ -0,0 +1,26 @@ + +class Aggregator: + + def __init__(self): + self.reset() + + def reset(self): + self.data = None + self.counter = 0 + + def add(self, item): + if self.data is None: + self.data = item.copy() + self.counter = 1 + else: + self.data += item + self.counter += 1 + return self + + __iadd__ = add + + def __repr__(self): + return f"{self.data!r} / {self.counter}" + + + diff --git a/dap/worker.py b/dap/worker.py index 3919c33..9f8c9f1 100644 --- a/dap/worker.py +++ b/dap/worker.py @@ -4,7 +4,7 @@ from random import randint import numpy as np from algos import calc_apply_threshold, calc_force_send, calc_mask_pixels, calc_peakfinder_analysis, calc_radial_integration, calc_roi, calc_spi_analysis, JFData -from utils import BufferedJSON, read_bit +from utils import Aggregator, BufferedJSON, read_bit from zmqsocks import ZMQSockets @@ -43,9 +43,7 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host zmq_socks = ZMQSockets(backend_address, accumulator_host, accumulator_port, visualisation_host, visualisation_port) - - data_summed = None - n_aggregated_images = 0 + aggregator = Aggregator() while True: @@ -118,7 +116,7 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host calc_peakfinder_analysis(results, pfdata, pixel_mask_pf) # ??? - data, force_send_visualisation, data_summed, n_aggregated_images = calc_force_send(results, data, pixel_mask_pf, image, data_summed, n_aggregated_images) + data, force_send_visualisation, aggregator = calc_force_send(results, data, pixel_mask_pf, image, aggregator) results["type"] = str(data.dtype) results["shape"] = data.shape