diff --git a/dap/worker.py b/dap/worker.py index 9c7311e..eec10fa 100644 --- a/dap/worker.py +++ b/dap/worker.py @@ -7,10 +7,10 @@ from time import sleep import jungfrau_utils as ju import numpy as np -import zmq from peakfinder8_extension import peakfinder_8 from algos import calc_radial_integration, calc_apply_additional_mask +from zmqsocks import ZMQSockets FLAGS = 0 @@ -57,27 +57,11 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host ju_stream_adapter = ju.StreamAdapter() - zmq_context = zmq.Context(io_threads=4) - poller = zmq.Poller() + zmq_socks = ZMQSockets(backend_address, accumulator_host, accumulator_port, visualisation_host, visualisation_port) # all the normal workers worker = 1 -# receive from backend: - backend_socket = zmq_context.socket(zmq.PULL) - backend_socket.connect(backend_address) - - poller.register(backend_socket, zmq.POLLIN) - - accumulator_socket = zmq_context.socket(zmq.PUSH) - accumulator_socket.connect(f"tcp://{accumulator_host}:{accumulator_port}") - - visualisation_socket = zmq_context.socket(zmq.PUB) - visualisation_socket.connect(f"tcp://{visualisation_host}:{visualisation_port}") - -# in case of problem with communication to visualisation, keep in 0mq buffer only few messages - visualisation_socket.set_hwm(10) - keep_pixels = None r_radial_integration = None center_radial_integration = None @@ -113,13 +97,10 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host except Exception as e: print(f"({pulseid}) problem ({e}) to read peakfinder parameters file, worker : {worker}", flush=True) - events = dict(poller.poll(2000)) # check every 2 seconds in each worker - if backend_socket not in events: + if not zmq_socks.has_data(): continue - metadata = backend_socket.recv_json(FLAGS) - image = backend_socket.recv(FLAGS, copy=False, track=False) - image = np.frombuffer(image, dtype=metadata["type"]).reshape(metadata["shape"]) + image, metadata = zmq_socks.get_data() results = copy(metadata) if results["shape"][0] == 2 and results["shape"][1] == 2: @@ -360,7 +341,7 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host results["shape"] = data.shape - accumulator_socket.send_json(results, FLAGS) + zmq_socks.send_accumulator(results) send_empty_cond1 = (apply_aggregation and "aggregation_max" in results and not forceSendVisualisation) @@ -371,8 +352,7 @@ def work(backend_address, accumulator_host, accumulator_port, visualisation_host results["type"] = str(data.dtype) results["shape"] = data.shape - visualisation_socket.send_json(results, FLAGS | zmq.SNDMORE) - visualisation_socket.send(data, FLAGS, copy=True, track=True) + zmq_socks.send_visualisation(results, data) diff --git a/dap/zmqsocks.py b/dap/zmqsocks.py new file mode 100644 index 0000000..be25cc8 --- /dev/null +++ b/dap/zmqsocks.py @@ -0,0 +1,49 @@ +import numpy as np +import zmq + + +FLAGS = 0 + + +class ZMQSockets: + + def __init__(self, backend_address, accumulator_host, accumulator_port, visualisation_host, visualisation_port): + zmq_context = zmq.Context(io_threads=4) + self.poller = poller = zmq.Poller() + + # receive from backend: + self.backend_socket = backend_socket = zmq_context.socket(zmq.PULL) + backend_socket.connect(backend_address) + + poller.register(backend_socket, zmq.POLLIN) + + self.accumulator_socket = accumulator_socket = zmq_context.socket(zmq.PUSH) + accumulator_socket.connect(f"tcp://{accumulator_host}:{accumulator_port}") + + self.visualisation_socket = visualisation_socket = zmq_context.socket(zmq.PUB) + visualisation_socket.connect(f"tcp://{visualisation_host}:{visualisation_port}") + + # in case of problem with communication to visualisation, keep in 0mq buffer only few messages + visualisation_socket.set_hwm(10) + + + def has_data(self): + events = dict(self.poller.poll(2000)) # check every 2 seconds in each worker + return (self.backend_socket in events) + + def get_data(self): + metadata = self.backend_socket.recv_json(FLAGS) + image = self.backend_socket.recv(FLAGS, copy=False, track=False) + image = np.frombuffer(image, dtype=metadata["type"]).reshape(metadata["shape"]) + return image, metadata + + + def send_accumulator(self, results): + self.accumulator_socket.send_json(results, FLAGS) + + def send_visualisation(self, results, data): + self.visualisation_socket.send_json(results, FLAGS | zmq.SNDMORE) + self.visualisation_socket.send(data, FLAGS, copy=True, track=True) + + +