Files
pmsco-public/pmsco/reports/swarm.py

437 lines
18 KiB
Python
Executable File

#!/usr/bin/env python
"""
@package pmsco.reports.swarm
graphics rendering module for swarm dynamics.
the module can be used in several different ways:
1. via the command line on a pmsco database or .dat results file.
this is the simplest but least flexible way.
2. via python functions on given population arrays or database queries.
this is the most flexible way but requires familiarity with the expected data formats.
3. as a listener on calculation events. (to be implemented)
this will be configurable in the run file.
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
@copyright (c) 2021 by Paul Scherrer Institut @n
Licensed under the Apache License, Version 2.0 (the "License"); @n
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
"""
import argparse
import copy
import itertools
import logging
import numpy as np
from pathlib import Path
import sys
if __name__ == "__main__":
    # when this module is executed as a script (pmsco/reports/swarm.py),
    # the pmsco package may not be importable yet.
    # prepend the directory three levels above this file, i.e. the one containing the pmsco package,
    # to the module search path so that the pmsco imports below succeed.
    pmsco_root = Path(__file__).resolve().parent.parent.parent
    if str(pmsco_root) not in sys.path:
        sys.path.insert(0, str(pmsco_root))
import pmsco.reports.results as rp_results
import pmsco.database.util as db_util
import pmsco.database.query as db_query
from pmsco.reports.base import ProjectReport
from pmsco.reports.population import GenerationTracker
logger = logging.getLogger(__name__)
try:
from matplotlib.figure import Figure
from matplotlib.ticker import MaxNLocator
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
# from matplotlib.backends.backend_pdf import FigureCanvasPdf
# from matplotlib.backends.backend_svg import FigureCanvasSVG
except ImportError:
Figure = None
FigureCanvas = None
MaxNLocator = None
logger.warning("error importing matplotlib. graphics rendering disabled.")
def plot_swarm(filename, pos, vel, rfac, params, title=None, cmap=None, canvas=None):
    """
    plot a two-dimensional particle swarm population.

    the plot consists of three elements:
    - a pseudo-color scatter plot of R-factors in the background,
    - a scatter plot of particle positions, optionally colorized by R-factor.
    - a quiver plot indicating the velocities of the particles.

    @note the color of the particles is mapped according to the _rfac column of the positions array,
    while the background plot uses R-factors from the _rfac column of the rfac array.

    this is a low-level function containing just the plotting commands from numpy arrays.
    the graphics file format can be changed by providing a specific canvas. default is PNG.

    this function requires the matplotlib module.
    if it is not available, the function returns None and no file is produced.

    @param filename: path and base name of the output file without extension.
        the file extension according to the file format of the canvas is appended.
    @param pos: structured ndarray containing the positions of the particles.
        if the array contains an _rfac column, the dots are colorized.
    @param vel: structured ndarray containing the velocities of the particles.
        if None, no velocity arrows are drawn.
    @param rfac: structured ndarray containing positions and R-factor values.
        this array is independent of pos and vel.
        it can also be set to None if results should be suppressed.
        if it lacks the _rfac column or one of the two plotted parameter columns,
        the background plot is omitted.
    @param params: dictionary of two parameters to be plotted.
        the keys correspond to columns of the pos, vel and rfac arrays.
        the values are lists [minimum, maximum] that define the axis range.
        the dictionary must contain at least two parameters and should not contain more than two.
        only the first two are plotted.
    @param title: (str) title of the chart.
        default: derived from parameter names.
    @param cmap: (str) name of colour map supported by matplotlib.
        default is 'plasma'.
        other good-looking options are 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.
    @param canvas: a FigureCanvas class reference from a matplotlib backend.
        if None, the default FigureCanvasAgg is used which produces a bitmap file in PNG format.
        some other options are:
        matplotlib.backends.backend_pdf.FigureCanvasPdf or
        matplotlib.backends.backend_svg.FigureCanvasSVG.

    @return (str) path and name of the generated graphics file.
        None if matplotlib is not available or the file could not be saved.

    @raise ValueError if params contains fewer than two parameters.
    """
    pnames = list(params.keys())
    if canvas is None:
        canvas = FigureCanvas
    if canvas is None or Figure is None:
        # matplotlib is not available (see the import guard at module level)
        return None
    if len(pnames) < 2:
        raise ValueError("params must contain at least two parameters")
    xname, yname = pnames[0], pnames[1]
    if cmap is None:
        cmap = 'plasma'
    if title is None:
        title = f'{xname} - {yname} swarm map'

    fig = Figure()
    # attach the canvas to the figure. the canvas class also determines the file format.
    canvas(fig)
    ax = fig.add_subplot(111)

    # s is the mappable for the colorbar.
    # the particle scatter plot (if colorized) takes precedence over the background map.
    s = None
    if rfac is not None:
        rfac_cols = rfac.dtype.names or ()
        if all(col in rfac_cols for col in (xname, yname, '_rfac')):
            s = ax.scatter(rfac[xname], rfac[yname], s=5, c=rfac['_rfac'], cmap=cmap, vmin=0, vmax=1)
        else:
            logger.debug("background R-factor map omitted: required columns missing in result data")

    ax.plot(pos[xname], pos[yname], 'co')
    if '_rfac' in (pos.dtype.names or ()):
        s = ax.scatter(pos[xname], pos[yname], s=5, c=pos['_rfac'], cmap=cmap, vmin=0, vmax=1)
    if vel is not None:
        ax.quiver(pos[xname], pos[yname], vel[xname], vel[yname], angles='xy', units='xy',
                  scale_units='xy', scale=1, color='c')
    if s is not None:
        cb = ax.figure.colorbar(s, ax=ax)
        cb.ax.set_ylabel("R-factor", rotation=-90, va="bottom")

    ax.set_xlim(params[xname])
    ax.set_ylim(params[yname])
    ax.set_xlabel(xname)
    ax.set_ylabel(yname)
    ax.set_title(title)

    out_filename = "{base}.{ext}".format(base=filename, ext=canvas.get_default_filetype())
    try:
        fig.savefig(out_filename)
    except OSError:
        logger.exception("exception while saving figure %s", out_filename)
        out_filename = None
    return out_filename
class SwarmPlot(ProjectReport, GenerationTracker):
    """
    produce two-dimensional particle swarm population maps

    this class collects and validates all parameters and data for generating a series of swarm plots.
    it iterates over generations and parameter pairs and calls plot_swarm() for each combination.

    the plots consist of two sub-plots on the same axes:
    1. a scatter plot in the background maps the R-factor across the model space.
       this can contain an arbitrary number of results from as many generations as requested.
    2. a quiver plot shows the positions and velocities of the particles in one generation.

    the two plots each have their independent data source.
    it is up to the owner to select corresponding and meaningful data.

    usage:
    1. assign public attributes as necessary, leave defaults at None.
    2. load data into result_data and swarm_data,
       e.g. by calling select_data() for database input.
    3. call validate()
    4. call create_report() to generate the graphics files.
    """

    ## @var params
    # parameter that should be plotted
    #
    # this should be a list of pairs (sequence or tuple of two strings) of parameter names.
    # for each pair, a plot is generated.
    # pairs that contain degenerate parameters or the same parameter twice are skipped.
    # by default, all non-degenerate parameters are plotted in all combinations.

    ## @var result_data
    # R-factor data for scatter plot (in the background)
    #
    # pmsco.reports.results.ResultData object holding the data or filter criteria
    # for the scatter plot in the background.
    # the scatter plot shows a map of the model space in terms of R-factor.

    ## @var swarm_data
    # particle vectors for swarm (quiver) plot
    #
    # pmsco.reports.results.ResultData object holding the data or filter criteria
    # for the quiver plot in the foreground.
    # the quiver plot shows the positions and velocities of the particles.

    ## @var filename_format
    # template of the output file name (without extension), relative to report_dir.
    #
    # placeholders: ${base}, ${param0}, ${param1}, ${gen}.

    ## @var title_format
    # template of the chart title. placeholders as in filename_format.

    ## @var cmap
    # name of the matplotlib colour map passed to plot_swarm(). None selects the default.

    ## @var min_particles
    # minimum number of particles that swarm_data must specify for plots to be produced.
    min_particles = 5

    def __init__(self):
        super().__init__()
        self._modes = ['swarm']
        self.result_data = rp_results.ResultData()
        self.swarm_data = rp_results.ResultData()
        self.filename_format = "${base}-swarm-${param0}-${param1}-${gen}"
        self.title_format = "${param0} ${param1} gen ${gen}"
        self.cmap = None
        self.params = []

    def select_data(self, jobs=-1, calcs=None):
        """
        query data from the database

        the results of the selected jobs are loaded into result_data (background map).
        swarm_data is a shallow copy of result_data that is filtered
        to the generations returned by changed_generations()
        (presumably the generations that changed since the last report - see GenerationTracker).
        if a project is attached, its model space is assigned to both data objects.

        @param jobs: filter by job.
            the argument can be a singleton or sequence of orm.Job objects or numeric id.
            if None, results from all jobs are loaded.
            if -1 (default), results from the most recent job (by datetime field) are loaded.

        @param calcs: the calcs argument is ignored.

        @return: None
        """
        with self.get_session() as session:
            if jobs == -1:
                jobs = db_query.query_newest_job(session)
            changed_gens = self.changed_generations(session, jobs)

            self.result_data.reset_filters()
            self.result_data.levels = {'scan': -1}
            self.result_data.load_from_db(session, jobs=jobs)

        self.swarm_data = copy.copy(self.result_data)
        self.swarm_data.reset_filters()
        self.swarm_data.generations = changed_gens
        self.swarm_data.apply_filters()
        self.swarm_data.update_collections()

        if self._project:
            self.result_data.set_model_space(self._project.model_space)
            self.swarm_data.set_model_space(self._project.model_space)

    def create_report(self):
        """
        generate the plots based on the stored attributes.

        this method essentially loops over generations and parameter combinations,
        and compiles the input for plot_swarm.
        combinations of parameter names are either set by the user in self.params
        or constructed from all non-degenerate (not constant) parameters.

        @return: list of created files.
            empty if the data is not suitable for swarm plots or no file could be produced.
        """
        sd = self.swarm_data

        # check that result data is compatible with swarm plots
        if sd.params is None or len(sd.params) < 2:
            logger.warning("result data must contain at least 2 parameters")
            return []
        if sd.generations is None or len(sd.generations) < 1:
            logger.warning("result data must specify at least 1 generation")
            return []
        if sd.particles is None or len(sd.particles) < self.min_particles:
            logger.warning("result data must specify at least %d particles", self.min_particles)
            return []

        vmin = sd.model_space.min
        vmax = sd.model_space.max
        nd_params = sd.non_degenerate_params()
        if self.params:
            # keep only pairs of two distinct, non-degenerate parameters
            ppairs = [p for p in self.params
                      if len(p) == 2 and p[0] != p[1] and p[0] in nd_params and p[1] in nd_params]
        else:
            ppairs = list(itertools.combinations(sorted(nd_params, key=str.lower), 2))
        if not ppairs:
            logger.warning("no pair of non-degenerate parameters to plot")
            return []

        kwargs = {}
        if self.cmap is not None:
            kwargs['cmap'] = self.cmap
        if self.canvas is not None:
            kwargs['canvas'] = self.canvas

        files = []
        fdict = {'base': self.base_filename}
        # the path template is loop-invariant; placeholders are resolved per generation and pair
        path_template = Path(self.report_dir, self.filename_format)
        for rd in sd.iterate_generations():
            fdict['gen'] = int(rd.generations[0])
            for p0, p1 in ppairs:
                fdict['param0'] = p0
                fdict['param1'] = p1
                filename = Path(self.resolve_template(path_template, fdict))
                kwargs['title'] = self.resolve_template(self.title_format, fdict)
                params = {p0: [vmin[p0], vmax[p0]],
                          p1: [vmin[p1], vmax[p1]]}
                of = plot_swarm(filename, rd.values, rd.deltas, self.result_data.values, params, **kwargs)
                if of:
                    files.append(of)

        return files
def _convert_title_format(title):
    """
    convert {}-style placeholders in a chart title to the ${}-style used by SwarmPlot templates.

    SwarmPlot resolves title_format with ${name} placeholders (cf. its default title_format),
    whereas the command line help advertises {gen}.
    {gen}, {param0} and {param1} are converted unless they are already preceded by '$'.
    all other text, including existing ${}-style placeholders, is returned unchanged.

    @param title: (str) title template.
    @return: (str) title template with ${}-style placeholders.
    """
    import re
    return re.sub(r"(?<!\$)\{(gen|param0|param1)\}", r"${\1}", title)


def render_swarm(output_file, values, deltas=None, model_space=None, generations=None, title=None,
                 cmap=None, canvas=None):
    """
    render a two-dimensional particle swarm population.

    this function generates a schematic rendering of a particle swarm in two dimensions.
    particles are represented by their position and velocity, indicated by an arrow.
    the model space is projected on the first two (or selected two) variable parameters.
    in the background, a scatter plot of results (dots with pseudocolor representing the R-factor) can be plotted.

    the chart type is designed for the particle swarm optimization algorithm.

    the function requires input in one of the following forms:
    - an open pmsco database session. the most recent job results are loaded.
    - a result (.dat) file or numpy structured array.
      the array must contain regular parameters, as well as the _particle and _gen columns,
      and optionally, the _rfac column.
      the function generates one chart per generation unless the generation argument is specified.
      velocities are not plotted.
    - position (.pos) and velocity (.vel) files or the respective numpy structured arrays.
      the arrays must contain regular parameters, as well as the `_particle` column.
      the result file must also contain an `_rfac` column.
      files are loaded by numpy.genfromtxt.
    - a pmsco.optimizers.population.Population object with valid data.
      the generation is taken from the respective attribute and overrides the function argument.

    the graphics file format can be changed by providing a specific canvas. default is PNG.

    this function requires the matplotlib module.
    if it is not available, no files are produced and the returned list is empty.

    @note velocities are not stored in result files and can't be plotted if result files are used as input.
    to plot velocities, provide the database, .pos/.vel files or Population object as input.

    @param output_file: path and base name of the output file without extension.
        the parameter names, the generation index and the file extension according to the file format
        are appended.

    @param values: position data to plot.
        the input can be specified in several ways - see the main description for details.

    @param deltas: velocity data to plot (optional).
        the input can be specified in several ways - see the main description for details.

    @param model_space: model space can be a pmsco.project.ModelSpace object,
        any object that contains the same min and max attributes as pmsco.project.ModelSpace,
        or a dictionary with two keys 'min' and 'max' that provides the corresponding ModelSpace dictionaries.
        by default, the model space boundaries are derived from the input data.
        if a model_space is specified, only the parameters listed in it are plotted.

    @param title: (str) title of the chart.
        the title is a template where ${gen} is replaced by the generation,
        and ${param0} and ${param1} by the names of the plotted parameters.
        the {}-style forms {gen}, {param0} and {param1} are accepted as well.
        default: the generation number.

    @param generations: (int or sequence) generation index or list of indices.
        a single index can also be a numpy integer.
        this index is used in the output file name and for filtering input data by generation.
        if the input data does not contain the generation, no filtering is applied.
        by default, no filtering is applied, and one graph for each generation is produced.

    @param cmap: (str) name of colour map supported by matplotlib.
        default is 'plasma'.
        other good-looking options are 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.

    @param canvas: a FigureCanvas class reference from a matplotlib backend.
        if None, the default FigureCanvasAgg is used which produces a bitmap file in PNG format.
        some other options are:
        matplotlib.backends.backend_pdf.FigureCanvasPdf or
        matplotlib.backends.backend_svg.FigureCanvasSVG.

    @return (list of str) paths of the generated graphics files.
        empty if no file could be produced, e.g. because matplotlib is not available
        or the data is not suitable for swarm plots.
    """
    data = rp_results.ResultData()
    # a single generation index (python or numpy integer) is wrapped into a sequence
    if isinstance(generations, (int, np.integer)):
        generations = (generations,)
    data.generations = generations
    data.levels = {'scan': -1}
    data.load_any(values, deltas)
    if model_space is not None:
        data.set_model_space(model_space)

    out_path = Path(output_file)
    plot = SwarmPlot()
    plot.canvas = canvas
    plot.cmap = cmap
    plot.title_format = _convert_title_format(title) if title else "${gen}"
    plot.report_dir = out_path.parent
    plot.filename_format = out_path.name + "-${param0}-${param1}-${gen}"
    plot.validate(None)
    # the same data serves both as background R-factor map and as swarm source
    plot.result_data = data
    plot.swarm_data = data
    files = plot.create_report()
    return files
def main():
    """
    command line entry point.

    parses the command line and calls render_swarm() on the given results file.
    if the file is an sqlite3 database, an open session on it is passed to render_swarm(),
    otherwise the file path itself is passed on.
    unrecognized command line arguments are ignored.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""
swarm dynamics plot for multiple-scattering optimization results
this module operates on results or database files and produces one graphics file per generation.
database files contain the complete information for all plot types.
data from the most recent job stored in the database is used.
while .dat results files contain all data shown in genetic plots,
they lack the velocity information for swarm plots.
only particle positions are plotted in this case.
.tasks.dat files lack the generation and particle identification and should not be used.
note that the plot type is independent of the optimization mode.
it's possible to generate genetic plots from a particle swarm optimization and vice versa.
""")
    parser.add_argument('results_file',
                        help="path to results file (.dat) or sqlite3 database file.")
    parser.add_argument('output_file',
                        help="base name of output file. generation and extension will be appended.")
    parser.add_argument('-t', '--title', default=None,
                        help='graph title. may contain {gen} as a placeholder for the generation number.')
    args, _ = parser.parse_known_args()

    options = {} if args.title is None else {'title': args.title}

    if not db_util.is_sqlite3_file(args.results_file):
        render_swarm(args.output_file, args.results_file, **options)
        return

    # database input: the access module is only needed in this branch
    import pmsco.database.access as db_access
    database = db_access.DatabaseAccess()
    database.connect(args.results_file)
    with database.session() as session:
        render_swarm(args.output_file, session, **options)
if __name__ == '__main__':
    # script entry point: render the plots, then exit with status 0
    main()
    raise SystemExit(0)