"""
@package pmsco.dispatch
calculation dispatcher.
@author Matthias Muntwiler
@copyright (c) 2015 by Paul Scherrer Institut @n
Licensed under the Apache License, Version 2.0 (the "License"); @n
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
"""
from __future__ import division
import os
import os.path
import datetime
import signal
import collections
import copy
import logging
from mpi4py import MPI
from helpers import BraceMessage as BMsg
logger = logging.getLogger(__name__)
# messages sent from master to slaves
## master sends new assignment
## the message is a dictionary of model parameters
TAG_NEW_TASK = 1
## master calls end of calculation
## the message is empty
TAG_FINISH = 2
# messages sent from slaves to master
## slave reports new result
## the message is a dictionary of model parameters and results
TAG_NEW_RESULT = 1
## slave confirms end of calculation
## currently not used
TAG_FINISHED = 2
## slave has encountered an error, result is invalid
## the message contains the original task message
TAG_INVALID_RESULT = 3
## slave has encountered an error and is aborting
## the message is empty
TAG_ERROR_ABORTING = 4
CalcID = collections.namedtuple('CalcID', ['model', 'scan', 'sym', 'emit', 'region'])
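# example (hypothetical values): a parent task covering all symmetries, emitters
# and regions of scan 0 in model 1 would carry the wildcard ID
# CalcID(model=1, scan=0, sym=-1, emit=-1, region=-1).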
class CalculationTask(object):
"""
identifies a calculation task by index and model parameters.
given an object of this class, the project must be able to:
* produce calculation parameters,
* produce a cluster,
* gather results.
a calculation task is identified by:
@arg @c id.model structure number or iteration (handled by the mode module)
@arg @c id.scan scan number (handled by the project)
@arg @c id.sym symmetry number (handled by the project)
@arg @c id.emit emitter number (handled by the project)
@arg @c id.region region number (handled by the region handler)
specified members must be greater than or equal to zero.
-1 is the wildcard, used in parent tasks where, e.g., no specific symmetry is chosen.
the root task has the ID (-1, -1, -1, -1, -1).
"""
## @var id (CalcID)
# named tuple CalcID containing the 5-part calculation task identifier.
## @var parent_id (CalcID)
# named tuple CalcID containing the task identifier of the parent task.
## @var model (dict)
# dictionary containing the model parameters of the task.
#
# this is typically initialized to the parameters of the parent task,
# and varied at the level where the task ID was produced.
## @var file_root (string)
# file name without extension and index.
## @var file_ext (string)
# file name extension including dot.
#
# the extension is set by the scattering code interface.
# it must be passed back up the hierarchy.
## @var result_filename (string)
# name of the ETPI or ETPAI file that contains the result (intensity) data.
#
# this member is filled at the end of the calculation by MscoProcess.calc().
# the filename can be constructed given the base name, task ID, and extension.
# since this may be tedious, the filename must be returned here.
## @var modf_filename (string)
# name of the ETPI or ETPAI file that contains the resulting modulation function.
## @var time (timedelta)
# execution time of the task.
#
# execution time is measured as wall time of a single calculation.
# in parent tasks, execution time is the sum of the children's execution time.
#
# this information may be used to plan the end of the program run or for statistics.
## @var files (dict)
# files generated by the task and their category
#
# dictionary key is the file name,
# value is the file category, e.g. 'cluster', 'phase', etc.
#
# this information is used to automatically clean up unnecessary data files.
## @var region (dict)
# scan positions to substitute the ones from the original scan.
#
# this is used to distribute scans over multiple calculator processes,
# cf. e.g. @ref EnergyRegionHandler.
#
# dictionary key must be the scan dimension 'e', 't', 'p', 'a'.
# the value is a numpy.ndarray containing the scan positions.
#
# the dictionary can be empty if the original scan shall be calculated at once.
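#
# example (hypothetical values): restrict the calculation to three energies of the original scan:
# @code{.py}
# task.region = {'e': numpy.array([10.0, 20.0, 30.0])}
# @endcode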
def __init__(self):
"""
create a new calculation task instance initialized as the root task (all ID parts -1).
"""
self.id = CalcID(-1, -1, -1, -1, -1)
self.parent_id = self.id
self.model = {}
self.file_root = ""
self.file_ext = ""
self.result_filename = ""
self.modf_filename = ""
self.result_valid = False
self.time = datetime.timedelta()
self.files = {}
self.region = {}
def __eq__(self, other):
"""
consider two tasks equal if they have the same ID.
EXPERIMENTAL: it is not clear whether this is a good idea.
we want this equality because the calculation may modify a task to return results,
yet the modified object should still count as the same task,
e.g. so that it can be found in the original task list.
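a minimal sketch of the intended use (hypothetical values):
@code{.py}
returned = task.copy()
returned.result_valid = True
assert returned == task   # same ID, although other attributes differ
@endcode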
"""
return isinstance(other, self.__class__) and self.id == other.id
def __hash__(self):
"""
the hash depends on the ID only.
"""
return hash(self.id)
def get_mpi_message(self):
"""
convert the task data to a format suitable for an MPI message.
mpi4py does not reliably pickle custom objects,
so we convert the task data to a dictionary of basic types.
@return: (dict)
"""
# copy the instance dict so that the conversions below do not modify the original task
msg = dict(vars(self))
msg['id'] = self.id._asdict()
msg['parent_id'] = self.parent_id._asdict()
return msg
def set_mpi_message(self, msg):
"""
set object attributes from MPI message.
@param msg: message created by get_mpi_message()
@return: None
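round-trip sketch (task is a hypothetical CalculationTask object):
@code{.py}
msg = task.get_mpi_message()       # a plain dict, safe to send via mpi4py
received = CalculationTask()
received.set_mpi_message(msg)
assert received.id == task.id
@endcode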
"""
if isinstance(msg['id'], dict):
msg['id'] = CalcID(**msg['id'])
if isinstance(msg['parent_id'], dict):
msg['parent_id'] = CalcID(**msg['parent_id'])
for k, v in msg.items():
self.__setattr__(k, v)
def format_filename(self, **overrides):
"""
format input or output file name including calculation index.
@param overrides optional keyword arguments override object fields.
the following keywords are handled: @c root, @c model, @c scan, @c sym, @c emit, @c region, @c ext.
@return a string consisting of the concatenation of the base name, the ID, and the extension.
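a minimal usage sketch (hypothetical values):
@code{.py}
task = CalculationTask()
task.file_root = "demo"
task.file_ext = ".etpi"
task.change_id(model=3, scan=0, sym=0, emit=0, region=-1)
task.format_filename()            # -> "demo_3_0_0_0_-1.etpi"
task.format_filename(region=7)    # -> "demo_3_0_0_0_7.etpi"
@endcode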
"""
parts = {}
parts['root'] = self.file_root
parts['model'] = self.id.model
parts['scan'] = self.id.scan
parts['sym'] = self.id.sym
parts['emit'] = self.id.emit
parts['region'] = self.id.region
parts['ext'] = self.file_ext
parts.update(overrides)
filename = "{root}_{model}_{scan}_{sym}_{emit}_{region}{ext}".format(**parts)
return filename
def copy(self):
"""
create a copy of the task.
@return: new independent CalculationTask with the same attributes as the original one.
"""
return copy.deepcopy(self)
def change_id(self, **kwargs):
"""
change the ID of the task.
@param kwargs: keyword arguments to change specific parts of the ID.
@note instead of changing all parts of the ID, you may simply assign a new CalcID to the id member.
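sketch (hypothetical values):
@code{.py}
task.change_id(region=2)            # change only the region part
task.id = CalcID(1, 0, 0, 0, -1)    # or assign a complete new ID
@endcode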
"""
self.id = self.id._replace(**kwargs)
def add_task_file(self, name, category):
"""
register a file that was generated by the calculation task.
this information is used to automatically clean up unnecessary data files.
@param name: file name (optionally including a path).
@param category: file category, e.g. 'cluster', 'phase', etc.
@return: None
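bookkeeping sketch (hypothetical file names):
@code{.py}
task.add_task_file("demo_1_0_0_0_0.clu", "cluster")
task.rename_task_file("demo_1_0_0_0_0.clu", "demo_1.clu")
task.remove_task_file("demo_1.clu")   # forget the file, but do not delete it
@endcode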
"""
self.files[name] = category
def rename_task_file(self, old_filename, new_filename):
"""
rename a file.
update the file list after a file was renamed.
if old_filename is not listed, the method logs a warning and leaves the list unchanged.
@param old_filename: old file name
@param new_filename: new file name
@return: None
"""
try:
self.files[new_filename] = self.files[old_filename]
del self.files[old_filename]
except KeyError:
logger.warning("CalculationTask.rename_task_file: could not rename file {0} to {1}".format(old_filename,
new_filename))
def remove_task_file(self, filename):
"""
remove a file from the list of generated data files.
the method removes the file from the internal list only;
it does not delete the file on disk.
if filename is not listed, the method logs a warning.
@param filename: file name
@return: None
"""
try:
del self.files[filename]
except KeyError:
logger.warning("CalculationTask.remove_task_file: could not remove file {0}".format(filename))
class MscoProcess(object):
"""
code shared by MscoMaster and MscoSlave.
this mainly covers passing project parameters, handling OS signals,
and calling an MSC calculation.
"""
## @var _finishing
# if True, the task loop should not accept new tasks.
#
# the loop still waits for the results of running calculations.
## @var _running
# while True, the task loop keeps running.
#
# if False, the loop will exit just before the next iteration.
# pending tasks and running calculations will not be waited for.
#
# @attention make sure that all calculations are finished before resetting this flag.
# higher ranked processes may not exit if they do not receive the finish message.
## @var datetime_limit (datetime.datetime)
# date and time when the calculations should finish (regardless of result)
# because the process may get killed by the scheduler after this time.
#
# the default is 2 days after start.
def __init__(self, comm):
self._comm = comm
self._project = None
self._calculator = None
self._running = False
self._finishing = False
self.stop_signal = False
self.datetime_limit = datetime.datetime.now() + datetime.timedelta(days=2)
def setup(self, project):
self._project = project
self._calculator = project.calculator_class()
self._running = False
self._finishing = False
self.stop_signal = False
try:
# signal handlers
signal.signal(signal.SIGTERM, self.receive_signal)
signal.signal(signal.SIGUSR1, self.receive_signal)
signal.signal(signal.SIGUSR2, self.receive_signal)
except AttributeError:
pass
except ValueError:
pass
if project.timedelta_limit:
self.datetime_limit = datetime.datetime.now() + project.timedelta_limit
# noinspection PyUnusedLocal
def receive_signal(self, signum, stack):
"""
sets the self.stop_signal flag,
which will terminate the optimization process
as soon as all slaves have finished their calculation.
"""
self.stop_signal = True
def run(self):
pass
def cleanup(self):
"""
clean up after all calculations.
this method calls the clean up function of the project.
@return: None
"""
self._project.cleanup()
def calc(self, task):
"""
execute a single calculation.
* create the cluster and parameter objects.
* export the cluster for reference.
* choose the scan file.
* specify the output file name.
* call the calculation program.
* set task.result_filename, task.file_ext, task.time.
the function checks for some obvious errors, and skips the calculation if an error is detected, such as:
* missing atoms or emitters in the cluster.
@param task (CalculationTask) calculation task and identifier.
"""
s_model = str(task.model)
s_id = str(task.id)
logger.info("calling calculation %s", s_id)
logger.info("model %s", s_model)
start_time = datetime.datetime.now()
# create parameter and cluster structures
clu = self._project.cluster_generator.create_cluster(task.model, task.id)
par = self._project.create_params(task.model, task.id)
# generate file names
output_file = task.format_filename(ext="")
# determine scan range
scan = self._project.scans[task.id.scan]
if task.region:
scan = scan.copy()
try:
scan.energies = task.region['e']
logger.debug(BMsg("substitute energy region"))
except KeyError:
pass
try:
scan.thetas = task.region['t']
logger.debug(BMsg("substitute theta region"))
except KeyError:
pass
try:
scan.phis = task.region['p']
logger.debug(BMsg("substitute phi region"))
except KeyError:
pass
try:
scan.alphas = task.region['a']
logger.debug(BMsg("substitute alpha region"))
except KeyError:
pass
# check parameters and call the msc program
if clu.get_atom_count() < 2:
logger.error("empty cluster in calculation %s", s_id)
task.result_valid = False
elif clu.get_emitter_count() < 1:
logger.error("no emitters in cluster of calculation %s.", s_id)
task.result_valid = False
else:
files = self._calculator.check_cluster(clu, output_file)
task.files.update(files)
task.result_filename, files = self._calculator.run(par, clu, scan, output_file)
(root, ext) = os.path.splitext(task.result_filename)
task.file_ext = ext
task.result_valid = True
task.files.update(files)
task.time = datetime.datetime.now() - start_time
return task
class MscoMaster(MscoProcess):
"""
MscoMaster process for MSC calculations.
This class implements the main loop of the master (rank 0) process.
It sends calculation commands to the slaves, and dispatches the results
to the appropriate post-processing modules.
if there is only one process, the MscoMaster executes the calculations sequentially.
"""
## @var _pending_tasks (OrderedDict)
# CalculationTask objects of pending calculations.
# the dictionary keys are the task IDs.
## @var _running_tasks
# CalculationTask objects of currently running calculations.
# the dictionary keys are the task IDs.
## @var _complete_tasks
# CalculationTask objects of complete calculations.
#
# calculations are removed from the list when they are passed to the result handlers.
# the dictionary keys are the task IDs.
## @var _slaves
# total number of MPI slave ranks = number of calculator slots
## @var _idle_ranks
# list of ranks which are waiting to receive a task.
#
# list of int, default = []
## @var max_calculations
# maximum number of calculations
#
# if this limit is exceeded, the optimization will stop.
# the limit is meant to catch irregular situations such as run-time calculation errors or infinite loops.
## @var _calculations
# number of dispatched calculations
#
# if this number exceeds the @ref max_calculations, the optimization will stop.
## @var _running_slaves
# number of running slave ranks
#
# keeps track of active (idle or busy) slave ranks.
# it is used to make sure (if possible) that all slave tasks have finished before the master quits.
# the number is decremented when a slave quits due to an error or when the master sends a finish message.
## @var _min_queue_len
# if the queue length drops below this number, the dispatcher asks for the next round of tasks.
## @var _model_done
# (bool) True if the model handler has returned an empty list of new tasks.
## @var _root_task
# (CalculationTask) root calculation task
#
# this is the root of the calculation tasks tree.
# it defines the initial model and the output file name.
# it is passed to the model handler during the main loop.
## @var _model_handler
# (ModelHandler) model handler instance
## @var _scan_handler
# (ScanHandler) scan handler instance
## @var _symmetry_handler
# (SymmetryHandler) symmetry handler instance
## @var _emitter_handler
# (EmitterHandler) emitter handler instance
## @var _region_handler
# (RegionHandler) region handler instance
def __init__(self, comm):
super(MscoMaster, self).__init__(comm)
self._pending_tasks = collections.OrderedDict()
self._running_tasks = collections.OrderedDict()
self._complete_tasks = collections.OrderedDict()
self._slaves = self._comm.Get_size() - 1
self._idle_ranks = []
self.max_calculations = 1000000
self._calculations = 0
self._running_slaves = 0
self._model_done = False
self._min_queue_len = self._slaves + 1
self._root_task = None
self._model_handler = None
self._scan_handler = None
self._symmetry_handler = None
self._emitter_handler = None
self._region_handler = None
def setup(self, project):
"""
initialize the process, handlers, root task, slave counting.
this method initializes the run-time attributes of the master process,
particularly the attributes that depend on the project.
it creates the root calculation task with the initial model defined by the project.
it creates and initializes the task handler objects according to the handler classes defined by the project.
the method notifies the handlers of the number of available slave processes (slots).
some of the tasks handlers adjust their branching according to the number of slots.
this mechanism may be used to balance the load between the task levels.
however, the current implementation is very coarse in this respect.
it advertises all slots to the model handler but a reduced number to the remaining handlers
depending on the operation mode.
the region handler receives a maximum of 4 slots except in single calculation mode.
in single calculation mode, all slots can be used by all handlers.
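for example (a sketch of the arithmetic below): with 10 slave processes in an optimization mode,
the model handler is set up with 10 slots,
the scan, symmetry and emitter handlers with 5 (= max(10 // 2, 1)),
and the region handler with 4 (= min(5, 4)).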
"""
super(MscoMaster, self).setup(project)
logger.debug("master entering setup")
self._running_slaves = self._slaves
self._idle_ranks = list(range(1, self._running_slaves + 1))
self._root_task = CalculationTask()
self._root_task.file_root = project.output_file
self._root_task.model = project.create_domain().start
self._model_handler = project.handler_classes['model']()
self._scan_handler = project.handler_classes['scan']()
self._symmetry_handler = project.handler_classes['symmetry']()
self._emitter_handler = project.handler_classes['emitter']()
self._region_handler = project.handler_classes['region']()
self._model_handler.datetime_limit = self.datetime_limit
slaves_adj = max(self._slaves, 1)
self._model_handler.setup(project, slaves_adj)
if project.mode != "single":
# integer division: true division would produce a float because of the __future__ import
slaves_adj = max(slaves_adj // 2, 1)
self._scan_handler.setup(project, slaves_adj)
self._symmetry_handler.setup(project, slaves_adj)
self._emitter_handler.setup(project, slaves_adj)
if project.mode != "single":
slaves_adj = min(slaves_adj, 4)
self._region_handler.setup(project, slaves_adj)
def run(self):
"""
main loop.
dispatches tasks to the slaves, receives their results, and passes them to the result handlers.
setup() must be called before, cleanup() after.
"""
self._running = True
self._calculations = 0
logger.debug("master entering main loop")
# main task loop
while self._running:
logger.debug("new iteration of master main loop")
self._create_tasks()
self._dispatch_results()
if self._finishing:
self._dispatch_finish()
else:
self._dispatch_tasks()
self._receive_result()
self._check_finish()
logger.debug("master exiting main loop")
self._running = False
def cleanup(self):
logger.debug("master entering cleanup")
self._region_handler.cleanup()
self._emitter_handler.cleanup()
self._symmetry_handler.cleanup()
self._scan_handler.cleanup()
self._model_handler.cleanup()
super(MscoMaster, self).cleanup()
def _dispatch_results(self):
"""
pass results through the post-processing modules.
"""
logger.debug("dispatching results of %u tasks", len(self._complete_tasks))
while self._complete_tasks:
__, task = self._complete_tasks.popitem(last=False)
logger.debug("passing task %s to region handler", str(task.id))
task = self._region_handler.add_result(task)
if task:
logger.debug("passing task %s to emitter handler", str(task.id))
task = self._emitter_handler.add_result(task)
if task:
logger.debug("passing task %s to symmetry handler", str(task.id))
task = self._symmetry_handler.add_result(task)
if task:
logger.debug("passing task %s to scan handler", str(task.id))
task = self._scan_handler.add_result(task)
if task:
logger.debug("passing task %s to model handler", str(task.id))
task = self._model_handler.add_result(task)
if task:
logger.debug("root task %s complete", str(task.id))
self._finishing = True
def _create_tasks(self):
"""
have the model handler generate the next round of top-level calculation tasks.
the method calls the model handler repeatedly
until the pending tasks queue is filled up
to more than the minimum queue length.
@return: None
"""
logger.debug("creating new tasks from root")
while len(self._pending_tasks) < self._min_queue_len:
tasks = self._model_handler.create_tasks(self._root_task)
logger.debug("model handler returned %u new tasks", len(tasks))
if not tasks:
self._model_done = True
break
for task in tasks:
self.add_model_task(task)
def _dispatch_tasks(self):
"""
send pending tasks to available slaves or master.
if there is only one process, the master executes one task, and returns.
"""
logger.debug("dispatching tasks to calculators")
if self._slaves > 0:
while not self._finishing:
try:
rank = self._idle_ranks.pop(0)
except IndexError:
break
try:
__, task = self._pending_tasks.popitem(last=False)
except KeyError:
self._idle_ranks.append(rank)
break
else:
logger.debug("assigning task %s to rank %u", str(task.id), rank)
self._running_tasks[task.id] = task
self._comm.send(task.get_mpi_message(), dest=rank, tag=TAG_NEW_TASK)
self._calculations += 1
else:
if not self._finishing:
try:
__, task = self._pending_tasks.popitem(last=False)
except KeyError:
pass
else:
logger.debug("executing task %s in master process", str(task.id))
self.calc(task)
self._calculations += 1
self._complete_tasks[task.id] = task
def _dispatch_finish(self):
"""
send all slave ranks a finish message.
"""
logger.debug("dispatch finish message to %u slaves", len(self._idle_ranks))
while self._idle_ranks:
rank = self._idle_ranks.pop()
logger.debug("send finish tag to rank %u", rank)
self._comm.send(None, dest=rank, tag=TAG_FINISH)
self._running_slaves -= 1
def _receive_result(self):
"""
wait for a message from another rank and process it.
"""
if self._running_slaves > 0:
logger.debug("waiting for calculation result")
s = MPI.Status()
data = self._comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=s)
if s.tag == TAG_NEW_RESULT:
task_id = self._accept_task_done(data)
self._idle_ranks.append(s.source)
logger.debug(BMsg("received result of task {0} from rank {1}", task_id, s.source))
elif s.tag == TAG_INVALID_RESULT:
task_id = self._accept_task_done(data)
self._idle_ranks.append(s.source)
logger.error(BMsg("received invalid result of task {0} from rank {1}", task_id, s.source))
elif s.tag == TAG_ERROR_ABORTING:
self._finishing = True
self._running_slaves -= 1
task_id = self._accept_task_done(data)
logger.error(BMsg("received abort signal from rank {1}", task_id, s.source))
def _accept_task_done(self, data):
"""
check the return message from a slave process and mark the task done.
if the message contains complete data of a running task, the corresponding CalculationTask object is returned.
@param data: a dictionary that can be imported into a CalculationTask object by the set_mpi_message() method.
@return: task ID (CalcID type) if the message contains the complete identification of a pending task,
None if the ID cannot be determined or is not in the list of running tasks.
"""
try:
task = CalculationTask()
task.set_mpi_message(data)
del self._running_tasks[task.id]
self._complete_tasks[task.id] = task
task_id = task.id
except (TypeError, IndexError, KeyError):
task_id = None
return task_id
def _check_finish(self):
"""
check whether the task loop is finished.
the task loop is finished on any of the following conditions:
* there are no pending or running tasks,
* a file named "finish_pmsco" exists in the working directory,
* a SIGUSR1, SIGUSR2, or SIGTERM signal was received,
* self.datetime_limit is exceeded, or
* self.max_calculations is exceeded.
self._finishing is set if any of these conditions is fulfilled.
self._running is reset if self._finishing is set and no calculation tasks are running.
@return: self._finishing
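a run can thus be stopped gracefully by creating the file in the working directory, e.g. (sketch):
@code{.py}
open("finish_pmsco", "a").close()   # same effect as `touch finish_pmsco` in a shell
@endcode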
"""
if not self._finishing and (self._model_done and not self._pending_tasks and not self._running_tasks):
logger.info("finish: model handler is done")
self._finishing = True
if not self._finishing and (self._calculations >= self.max_calculations):
logger.warning("finish: max. calculations (%u) exeeded", self.max_calculations)
self._finishing = True
if not self._finishing and self.stop_signal:
logger.info("finish: stop signal received")
self._finishing = True
if not self._finishing and (datetime.datetime.now() > self.datetime_limit):
logger.warning("finish: time limit exceeded")
self._finishing = True
if not self._finishing and os.path.isfile("finish_pmsco"):
logger.info("finish: finish_pmsco file detected")
self._finishing = True
if self._finishing and not self._running_slaves and not self._running_tasks:
logger.info("finish: all calculations finished")
self._running = False
return self._finishing
def add_model_task(self, task):
"""
add a new model task including all of its children to the task queue.
@param task (CalculationTask) task identifier and model parameters.
"""
scan_tasks = self._scan_handler.create_tasks(task)
for scan_task in scan_tasks:
sym_tasks = self._symmetry_handler.create_tasks(scan_task)
for sym_task in sym_tasks:
emitter_tasks = self._emitter_handler.create_tasks(sym_task)
for emitter_task in emitter_tasks:
region_tasks = self._region_handler.create_tasks(emitter_task)
for region_task in region_tasks:
self._pending_tasks[region_task.id] = region_task
class MscoSlave(MscoProcess):
"""
MscoSlave process for MSC calculations.
This class implements the main loop of a slave (rank > 0) process.
It waits for assignments from the master process,
and runs one calculation after the other.
"""
## @var _errors
# number of errors (exceptions) encountered in calculation tasks.
#
# typically, a task is aborted when an exception is encountered.
def __init__(self, comm):
super(MscoSlave, self).__init__(comm)
self._errors = 0
self._max_errors = 5
def run(self):
"""
wait for messages from the master and dispatch the tasks.
"""
logger.debug("slave entering main loop")
s = MPI.Status()
self._running = True
while self._running:
logger.debug("waiting for message")
data = self._comm.recv(source=0, tag=MPI.ANY_TAG, status=s)
if s.tag == TAG_NEW_TASK:
logger.debug("received new task")
self.accept_task(data)
elif s.tag == TAG_FINISH:
logger.debug("received finish message")
self._running = False
logger.debug("slave exiting main loop")
def accept_task(self, data):
"""
execute a calculation task and return the result to the master.
if a recoverable exception occurs (value, arithmetic or lookup errors),
the method catches it and sends a failure message to the master.
if exceptions occur repeatedly, the slave aborts and sends an abort message to the master.
@param data: task message received from MPI.
"""
task = CalculationTask()
task.set_mpi_message(data)
logger.debug(BMsg("executing task {0} in slave process", task.id))
try:
result = self.calc(task)
self._errors = 0
except (ValueError, ArithmeticError, LookupError):
logger.exception(BMsg("unhandled exception in calculation task {0}", task.id))
self._errors += 1
if self._errors <= self._max_errors:
self._comm.send(data, dest=0, tag=TAG_INVALID_RESULT)
else:
logger.error("too many exceptions, aborting")
self._running = False
self._comm.send(data, dest=0, tag=TAG_ERROR_ABORTING)
else:
logger.debug(BMsg("sending result of task {0} to master", result.id))
self._comm.send(result.get_mpi_message(), dest=0, tag=TAG_NEW_RESULT)
def run_master(mpi_comm, project):
"""
initialize and run the master calculation loop.
an MscoMaster object is created.
the MscoMaster executes the calculation loop and dispatches the tasks.
this function must be called in the MPI rank 0 process only.
if an unhandled exception occurs, this function aborts the MPI communicator, killing all MPI processes.
the caller will not have a chance to handle the exception.
@param mpi_comm: MPI communicator (mpi4py.MPI.COMM_WORLD).
@param project: project instance (sub-class of project.Project).
"""
try:
master = MscoMaster(mpi_comm)
master.setup(project)
master.run()
master.cleanup()
except (SystemExit, KeyboardInterrupt):
mpi_comm.Abort()
raise
except Exception:
logger.exception("unhandled exception in master calculation loop.")
mpi_comm.Abort()
raise
def run_slave(mpi_comm, project):
"""
initialize and run the slave calculation loop.
an MscoSlave object is created.
the MscoSlave accepts tasks from rank 0 and runs the calculations.
this function must be called in MPI rank > 0 processes.
if an unhandled exception occurs, the slave process terminates.
unless it is a SystemExit or KeyboardInterrupt (where we expect that the master also receives the signal),
the MPI communicator is aborted, killing all MPI processes.
@param mpi_comm: MPI communicator (mpi4py.MPI.COMM_WORLD).
@param project: project instance (sub-class of project.Project).
"""
try:
slave = MscoSlave(mpi_comm)
slave.setup(project)
slave.run()
slave.cleanup()
except (SystemExit, KeyboardInterrupt):
raise
except Exception:
logger.exception("unhandled exception in slave calculation loop.")
mpi_comm.Abort()
raise
def run_calculations(project):
"""
initialize and run the main calculation loop.
depending on the MPI rank, the function branches into run_master() (rank 0) or run_slave() (rank > 0).
@param project: project instance (sub-class of project.Project).
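typical usage sketch (MyProject is a hypothetical project sub-class):
@code{.py}
project = MyProject()
run_calculations(project)   # run under MPI, e.g. with mpiexec -np 4
@endcode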
"""
mpi_comm = MPI.COMM_WORLD
mpi_rank = mpi_comm.Get_rank()
if mpi_rank == 0:
logger.debug("MPI rank %u setting up master loop", mpi_rank)
run_master(mpi_comm, project)
else:
logger.debug("MPI rank %u setting up slave loop", mpi_rank)
run_slave(mpi_comm, project)