diff --git a/dap/accumulator.py b/dap/accumulator.py index bd5a0fd..b30760d 100644 --- a/dap/accumulator.py +++ b/dap/accumulator.py @@ -1,7 +1,7 @@ import argparse from utils import FileHandler -from zmqsocks import ZMQSocketsAccumulator +from zmqsocks import ZMQSocketsAccumulator, make_address OUTPUT_DIR = "/gpfs/photonics/swissfel/buffer/dap/data" @@ -10,17 +10,19 @@ OUTPUT_DIR = "/gpfs/photonics/swissfel/buffer/dap/data" def main(): parser = argparse.ArgumentParser() - parser.add_argument("--accumulator_host", default="localhost") + parser.add_argument("--accumulator_host", default="*") parser.add_argument("--accumulator_port", type=int, default=13002) clargs = parser.parse_args() - accumulate(clargs.accumulator_host, clargs.accumulator_port) + accumulator_addr = make_address(clargs.accumulator_host, clargs.accumulator_port) + + accumulate(accumulator_addr) -def accumulate(accumulator_host, accumulator_port): - zmq_socks = ZMQSocketsAccumulator(accumulator_host, accumulator_port) +def accumulate(accumulator_addr): + zmq_socks = ZMQSocketsAccumulator(accumulator_addr) output = FileHandler() diff --git a/dap/worker.py b/dap/worker.py index aa61af0..b4342c9 100644 --- a/dap/worker.py +++ b/dap/worker.py @@ -7,7 +7,7 @@ from algos import ( calc_radial_integration, calc_roi, calc_spi_analysis, calc_streakfinder_analysis, JFData ) from utils import Aggregator, BufferedJSON, randskip, read_bit -from zmqsocks import ZMQSocketsWorker +from zmqsocks import ZMQSocketsWorker, make_address def main(): @@ -24,25 +24,26 @@ def main(): clargs = parser.parse_args() + backend_addr = make_address(clargs.backend_host, clargs.backend_port) + accumulator_addr = make_address(clargs.accumulator_host, clargs.accumulator_port) + visualisation_addr = make_address(clargs.visualisation_host, clargs.visualisation_port) + work( - clargs.backend_host, - clargs.backend_port, - clargs.accumulator_host, - clargs.accumulator_port, - clargs.visualisation_host, - clargs.visualisation_port, + backend_addr, + accumulator_addr, + visualisation_addr, clargs.peakfinder_parameters, clargs.skip_frames_rate ) -def work(backend_host, backend_port, accumulator_host, accumulator_port, visualisation_host, visualisation_port, fn_peakfinder_parameters, skip_frames_rate): +def work(backend_addr, accumulator_addr, visualisation_addr, fn_peakfinder_parameters, skip_frames_rate): bj_peakfinder_parameters = BufferedJSON(fn_peakfinder_parameters) jfdata = JFData() - zmq_socks = ZMQSocketsWorker(backend_host, backend_port, accumulator_host, accumulator_port, visualisation_host, visualisation_port) + zmq_socks = ZMQSocketsWorker(backend_addr, accumulator_addr, visualisation_addr) aggregator = Aggregator() diff --git a/dap/zmqsocks.py b/dap/zmqsocks.py index cd0d345..600fd98 100644 --- a/dap/zmqsocks.py +++ b/dap/zmqsocks.py @@ -7,13 +7,13 @@ FLAGS = 0 class ZMQSocketsAccumulator: - def __init__(self, _accumulator_host, accumulator_port): #TODO: accumulator_host is not used + def __init__(self, accumulator_addr): zmq_context = zmq.Context(io_threads=4) self.poller = poller = zmq.Poller() # receive from workers: self.accumulator_socket = accumulator_socket = zmq_context.socket(zmq.PULL) - accumulator_socket.bind(f"tcp://*:{accumulator_port}") + accumulator_socket.bind(accumulator_addr) poller.register(accumulator_socket, zmq.POLLIN) @@ -29,21 +29,21 @@ class ZMQSocketsAccumulator: class ZMQSocketsWorker: - def __init__(self, backend_host, backend_port, accumulator_host, accumulator_port, visualisation_host, visualisation_port): + def __init__(self, backend_addr, accumulator_addr, visualisation_addr): zmq_context = zmq.Context(io_threads=4) self.poller = poller = zmq.Poller() # receive from backend: self.backend_socket = backend_socket = zmq_context.socket(zmq.PULL) - backend_socket.connect(f"tcp://{backend_host}:{backend_port}") + backend_socket.connect(backend_addr) poller.register(backend_socket, zmq.POLLIN) self.accumulator_socket = accumulator_socket = zmq_context.socket(zmq.PUSH) - accumulator_socket.connect(f"tcp://{accumulator_host}:{accumulator_port}") + accumulator_socket.connect(accumulator_addr) self.visualisation_socket = visualisation_socket = zmq_context.socket(zmq.PUB) - visualisation_socket.connect(f"tcp://{visualisation_host}:{visualisation_port}") + visualisation_socket.connect(visualisation_addr) # in case of problem with communication to visualisation, keep in 0mq buffer only few messages visualisation_socket.set_hwm(10) @@ -69,3 +69,8 @@ class ZMQSocketsWorker: +def make_address(host, port) + return f"tcp://{host}:{port}" + + +