public release 3.0.0 - see README and CHANGES for details
@@ -4,16 +4,13 @@ calculation dispatcher.
 
 @author Matthias Muntwiler
 
-@copyright (c) 2015 by Paul Scherrer Institut @n
+@copyright (c) 2015-21 by Paul Scherrer Institut @n
 Licensed under the Apache License, Version 2.0 (the "License"); @n
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
 http://www.apache.org/licenses/LICENSE-2.0
 """
 
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
 import os
 import os.path
 import datetime
@@ -21,10 +18,20 @@ import signal
 import collections
 import copy
 import logging
 import math
 
 from attrdict import AttrDict
-from mpi4py import MPI
+
+try:
+    from mpi4py import MPI
+    mpi_comm = MPI.COMM_WORLD
+    mpi_size = mpi_comm.Get_size()
+    mpi_rank = mpi_comm.Get_rank()
+except ImportError:
+    MPI = None
+    mpi_comm = None
+    mpi_size = 1
+    mpi_rank = 0
 
 from pmsco.helpers import BraceMessage as BMsg
 
 logger = logging.getLogger(__name__)
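Note on this hunk: mpi4py becomes an optional dependency. The communicator, size and rank are resolved once at import time, and serial fallback values (mpi_comm = None, mpi_size = 1, mpi_rank = 0) are used when mpi4py is not installed. A minimal sketch of how code can build on these module-level values instead of a communicator argument; the process_all helper below is hypothetical and only illustrates the pattern:

    def process_all(items, handler):
        # with mpi4py available and more than one rank, split the work across ranks;
        # otherwise fall back to a plain serial loop in the current process.
        if mpi_comm is not None and mpi_size > 1:
            my_items = items[mpi_rank::mpi_size]        # simple round-robin split by rank
        else:
            my_items = items
        results = [handler(item) for item in my_items]
        if mpi_comm is not None and mpi_size > 1:
            results = mpi_comm.gather(results, root=0)  # collect partial results on rank 0
        return results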
@@ -521,8 +528,7 @@ class MscoProcess(object):
     #
     # the default is 2 days after start.
 
-    def __init__(self, comm):
-        self._comm = comm
+    def __init__(self):
         self._project = None
         self._atomic_scattering = None
         self._multiple_scattering = None
@@ -829,12 +835,12 @@ class MscoMaster(MscoProcess):
     # the values are handlers.TaskHandler objects.
     # the objects can be accessed in attribute or dictionary notation.
 
-    def __init__(self, comm):
-        super(MscoMaster, self).__init__(comm)
+    def __init__(self):
+        super().__init__()
         self._pending_tasks = collections.OrderedDict()
         self._running_tasks = collections.OrderedDict()
         self._complete_tasks = collections.OrderedDict()
-        self._slaves = self._comm.Get_size() - 1
+        self._slaves = mpi_size - 1
         self._idle_ranks = []
         self.max_calculations = 1000000
         self._calculations = 0
@@ -879,8 +885,8 @@ class MscoMaster(MscoProcess):
         self._idle_ranks = list(range(1, self._running_slaves + 1))
 
         self._root_task = CalculationTask()
-        self._root_task.file_root = project.output_file
-        self._root_task.model = project.create_model_space().start
+        self._root_task.file_root = str(project.output_file)
+        self._root_task.model = project.model_space.start
 
         for level in self.task_levels:
             self.task_handlers[level] = project.handler_classes[level]()
@@ -1033,7 +1039,7 @@ class MscoMaster(MscoProcess):
             else:
                 logger.debug("assigning task %s to rank %u", str(task.id), rank)
                 self._running_tasks[task.id] = task
-                self._comm.send(task.get_mpi_message(), dest=rank, tag=TAG_NEW_TASK)
+                mpi_comm.send(task.get_mpi_message(), dest=rank, tag=TAG_NEW_TASK)
                 self._calculations += 1
         else:
             if not self._finishing:
@@ -1055,7 +1061,7 @@ class MscoMaster(MscoProcess):
         while self._idle_ranks:
             rank = self._idle_ranks.pop()
             logger.debug("send finish tag to rank %u", rank)
-            self._comm.send(None, dest=rank, tag=TAG_FINISH)
+            mpi_comm.send(None, dest=rank, tag=TAG_FINISH)
             self._running_slaves -= 1
 
     def _receive_result(self):
@@ -1065,7 +1071,7 @@ class MscoMaster(MscoProcess):
         if self._running_slaves > 0:
             logger.debug("waiting for calculation result")
             s = MPI.Status()
-            data = self._comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=s)
+            data = mpi_comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=s)
 
             if s.tag == TAG_NEW_RESULT:
                 task_id = self._accept_task_done(data)
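For orientation, the master/slave exchange touched by these hunks is a plain tag-dispatched send/recv loop over the module-level communicator. The sketch below reproduces such a round trip in isolation; the tag values and the toy doubling task are placeholders, not the constants or tasks defined in this module, and it assumes exactly two ranks (e.g. mpiexec -n 2 python sketch.py):

    from mpi4py import MPI

    TAG_NEW_TASK, TAG_NEW_RESULT, TAG_FINISH = 1, 2, 3   # placeholder tag values
    comm = MPI.COMM_WORLD
    s = MPI.Status()

    if comm.Get_rank() == 0:
        # master: assign one task, wait for the result, then release the slave
        comm.send({"task": 42}, dest=1, tag=TAG_NEW_TASK)
        data = comm.recv(source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=s)
        if s.tag == TAG_NEW_RESULT:
            print("result from rank", s.source, data)
        comm.send(None, dest=1, tag=TAG_FINISH)
    else:
        # slave: answer tasks until the master sends the finish tag
        while True:
            data = comm.recv(source=0, tag=MPI.ANY_TAG, status=s)
            if s.tag == TAG_NEW_TASK:
                comm.send({"result": data["task"] * 2}, dest=0, tag=TAG_NEW_RESULT)
            elif s.tag == TAG_FINISH:
                break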
@@ -1185,8 +1191,8 @@ class MscoSlave(MscoProcess):
     #
     # typically, a task is aborted when an exception is encountered.
 
-    def __init__(self, comm):
-        super(MscoSlave, self).__init__(comm)
+    def __init__(self):
+        super().__init__()
        self._errors = 0
        self._max_errors = 5
 
@@ -1199,7 +1205,7 @@ class MscoSlave(MscoProcess):
         self._running = True
         while self._running:
             logger.debug("waiting for message")
-            data = self._comm.recv(source=0, tag=MPI.ANY_TAG, status=s)
+            data = mpi_comm.recv(source=0, tag=MPI.ANY_TAG, status=s)
             if s.tag == TAG_NEW_TASK:
                 logger.debug("received new task")
                 self.accept_task(data)
@@ -1229,17 +1235,17 @@ class MscoSlave(MscoProcess):
             logger.exception(BMsg("unhandled exception in calculation task {0}", task.id))
             self._errors += 1
             if self._errors <= self._max_errors:
-                self._comm.send(data, dest=0, tag=TAG_INVALID_RESULT)
+                mpi_comm.send(data, dest=0, tag=TAG_INVALID_RESULT)
             else:
                 logger.error("too many exceptions, aborting")
                 self._running = False
-                self._comm.send(data, dest=0, tag=TAG_ERROR_ABORTING)
+                mpi_comm.send(data, dest=0, tag=TAG_ERROR_ABORTING)
         else:
             logger.debug(BMsg("sending result of task {0} to master", result.id))
-            self._comm.send(result.get_mpi_message(), dest=0, tag=TAG_NEW_RESULT)
+            mpi_comm.send(result.get_mpi_message(), dest=0, tag=TAG_NEW_RESULT)
 
 
-def run_master(mpi_comm, project):
+def run_master(project):
     """
     initialize and run the master calculation loop.
 
@@ -1251,25 +1257,25 @@ def run_master(mpi_comm, project):
     if an unhandled exception occurs, this function aborts the MPI communicator, killing all MPI processes.
     the caller will not have a chance to handle the exception.
 
-    @param mpi_comm: MPI communicator (mpi4py.MPI.COMM_WORLD).
-
     @param project: project instance (sub-class of project.Project).
     """
     try:
-        master = MscoMaster(mpi_comm)
+        master = MscoMaster()
         master.setup(project)
         master.run()
         master.cleanup()
     except (SystemExit, KeyboardInterrupt):
-        mpi_comm.Abort()
+        if mpi_comm:
+            mpi_comm.Abort()
         raise
     except Exception:
         logger.exception("unhandled exception in master calculation loop.")
-        mpi_comm.Abort()
+        if mpi_comm:
+            mpi_comm.Abort()
         raise
 
 
-def run_slave(mpi_comm, project):
+def run_slave(project):
     """
     initialize and run the slave calculation loop.
 
@@ -1282,12 +1288,10 @@ def run_slave(mpi_comm, project):
     unless it is a SystemExit or KeyboardInterrupt (where we expect that the master also receives the signal),
     the MPI communicator is aborted, killing all MPI processes.
 
-    @param mpi_comm: MPI communicator (mpi4py.MPI.COMM_WORLD).
-
     @param project: project instance (sub-class of project.Project).
     """
     try:
-        slave = MscoSlave(mpi_comm)
+        slave = MscoSlave()
         slave.setup(project)
         slave.run()
         slave.cleanup()
@@ -1295,7 +1299,8 @@ def run_slave(mpi_comm, project):
         raise
     except Exception:
         logger.exception("unhandled exception in slave calculation loop.")
-        mpi_comm.Abort()
+        if mpi_comm:
+            mpi_comm.Abort()
         raise
 
 
@@ -1307,12 +1312,9 @@ def run_calculations(project):
 
     @param project: project instance (sub-class of project.Project).
     """
-    mpi_comm = MPI.COMM_WORLD
-    mpi_rank = mpi_comm.Get_rank()
-
     if mpi_rank == 0:
         logger.debug("MPI rank %u setting up master loop", mpi_rank)
-        run_master(mpi_comm, project)
+        run_master(project)
     else:
         logger.debug("MPI rank %u setting up slave loop", mpi_rank)
-        run_slave(mpi_comm, project)
+        run_slave(project)
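With the communicator handled at module level, run_calculations() remains the single entry point for serial and MPI runs alike: rank 0 runs the master loop, every other rank runs a slave loop, and without mpi4py the fallback rank of 0 sends the lone process into the master branch. A possible call site, assuming this file is pmsco/dispatch.py; the main() wrapper and the already-configured project object are illustrative, not part of the commit:

    from pmsco.dispatch import run_calculations   # assumed module path for this file

    def main(project):
        # 'project' is a configured sub-class of project.Project, built elsewhere.
        # run_calculations() checks the MPI rank itself, so the same call works for
        # a serial run and for every rank started by mpiexec.
        run_calculations(project)

    # serial run:    python my_script.py
    # parallel run:  mpiexec -n 8 python my_script.py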