#!/usr/bin/env python
"""
@package pmsco.reports.swarm
graphics rendering module for swarm dynamics.

the module can be used in several different ways:

1. via the command line on a pmsco database or .dat results file.
   this is the simplest but least flexible way.
2. via python functions on given population arrays or database queries.
   this is the most flexible way but requires understanding of the required data formats.
3. as a listener on calculation events. (to be implemented)
   this will be configurable in the run file.

@author Matthias Muntwiler, matthias.muntwiler@psi.ch

@copyright (c) 2021 by Paul Scherrer Institut @n
Licensed under the Apache License, Version 2.0 (the "License"); @n
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
"""

import argparse
import copy
import itertools
import logging
import numpy as np
from pathlib import Path
import sys

if __name__ == "__main__":
    # when run as a script, make the package root importable
    # before the pmsco imports below.
    pmsco_root = Path(__file__).resolve().parent.parent.parent
    if str(pmsco_root) not in sys.path:
        sys.path.insert(0, str(pmsco_root))

import pmsco.reports.results as rp_results
import pmsco.database.util as db_util
import pmsco.database.query as db_query
from pmsco.reports.base import ProjectReport
from pmsco.reports.population import GenerationTracker

logger = logging.getLogger(__name__)

try:
    from matplotlib.figure import Figure
    from matplotlib.ticker import MaxNLocator
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    # from matplotlib.backends.backend_pdf import FigureCanvasPdf
    # from matplotlib.backends.backend_svg import FigureCanvasSVG
except ImportError:
    Figure = None
    FigureCanvas = None
    MaxNLocator = None
    logger.warning("error importing matplotlib. graphics rendering disabled.")


def plot_swarm(filename, pos, vel, rfac, params, title=None, cmap=None, canvas=None):
    """
    plot a two-dimensional particle swarm population.

    the plot consists of three elements:
    - a pseudo-color scatter plot of R-factors in the background,
    - a scatter plot of particle positions, optionally colorized by R-factor,
    - a quiver plot indicating the velocities of the particles.

    @note the color of the particles is mapped according to the _rfac column of the positions array,
    while the background plot uses R-factors from the _rfac column of the rfac array.

    this is a low-level function containing just the plotting commands from numpy arrays.

    the graphics file format can be changed by providing a specific canvas.
    default is PNG.

    this function requires the matplotlib module.
    if it is not available, no file is generated and the function returns None.

    @param filename: path and base name of the output file without extension.
        the file extension according to the file format of the canvas is appended.

    @param pos: structured ndarray containing the positions of the particles.
        if the array contains an _rfac column, the dot is colorized.

    @param vel: structured ndarray containing the velocities of the particles.
        None if velocities should not be plotted.

    @param rfac: structured ndarray containing positions and R-factor values.
        this array is independent of pos and vel.
        the background scatter plot is drawn only if the array contains
        both plotted parameters and an _rfac column.
        it can also be set to None if results should be suppressed.

    @param params: dictionary of two parameters to be plotted.
        the keys correspond to columns of the pos, vel and rfac arrays.
        the values are lists [minimum, maximum] that define the axis range.
        the dictionary must contain at least two parameters and should not contain more than two.
        only the first two are plotted.

    @param title: (str) title of the chart.
        default: derived from parameter names.

    @param cmap: (str) name of colour map supported by matplotlib.
        default is 'plasma'.
        other good-looking options are 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.

    @param canvas: a FigureCanvas class reference from a matplotlib backend.
        if None, the default FigureCanvasAgg is used which produces a bitmap file in PNG format.
        some other options are:
        matplotlib.backends.backend_pdf.FigureCanvasPdf or
        matplotlib.backends.backend_svg.FigureCanvasSVG.

    @return (str) path and name of the generated graphics file.
        None if no file was generated, either because matplotlib is not available
        or because the file could not be saved.
    """
    pnames = list(params.keys())
    if canvas is None:
        canvas = FigureCanvas
    if canvas is None or Figure is None:
        # matplotlib is not available
        return None

    if cmap is None:
        cmap = 'plasma'
    if title is None:
        title = f'{pnames[0]} - {pnames[1]} swarm map'

    fig = Figure()
    canvas(fig)
    ax = fig.add_subplot(111)

    # s holds the scatter artist which the colorbar is based on (if any).
    s = None
    if rfac is not None:
        # background R-factor map.
        # check the fields explicitly rather than catching ValueError from the plot call,
        # which would also hide unrelated plotting errors.
        rfac_fields = rfac.dtype.names or ()
        if all(field in rfac_fields for field in (pnames[0], pnames[1], '_rfac')):
            s = ax.scatter(rfac[pnames[0]], rfac[pnames[1]], s=5, c=rfac['_rfac'],
                           cmap=cmap, vmin=0, vmax=1)
        else:
            logger.debug("R-factor data lacks plotted parameters or _rfac column. background plot skipped.")

    # particle positions
    ax.plot(pos[pnames[0]], pos[pnames[1]], 'co')
    if '_rfac' in (pos.dtype.names or ()):
        # colorize the particles by R-factor.
        # this artist takes precedence over the background for the colorbar.
        s = ax.scatter(pos[pnames[0]], pos[pnames[1]], s=5, c=pos['_rfac'], cmap=cmap, vmin=0, vmax=1)

    if vel is not None:
        # arrows in data units: each arrow points from the position to position + velocity.
        ax.quiver(pos[pnames[0]], pos[pnames[1]], vel[pnames[0]], vel[pnames[1]],
                  angles='xy', units='xy', scale_units='xy', scale=1, color='c')

    if s is not None:
        cb = ax.figure.colorbar(s, ax=ax)
        cb.ax.set_ylabel("R-factor", rotation=-90, va="bottom")

    ax.set_xlim(params[pnames[0]])
    ax.set_ylim(params[pnames[1]])
    ax.set_xlabel(pnames[0])
    ax.set_ylabel(pnames[1])
    ax.set_title(title)

    out_filename = f"{filename}.{canvas.get_default_filetype()}"
    try:
        fig.savefig(out_filename)
    except OSError:
        logger.exception(f"exception while saving figure {out_filename}")
        out_filename = None

    return out_filename


class SwarmPlot(ProjectReport, GenerationTracker):
    """
    produce two-dimensional particle swarm population maps

    this class collects and validates all parameters and data for generating a series of swarm plots.
    it iterates over generations and parameter pairs and calls plot_swarm() for each combination.

    the plots consist of two sub-plots on the same axes:
    1. a scatter plot in the background maps the R-factor across the model space.
       this can contain an arbitrary number of results from as many generations as requested.
    2. a quiver plot shows the positions and velocities of the particles in one generation.

    the two plots each have their independent data source.
    it is up to the owner to select corresponding and meaningful data.

    usage:
    1. assign public attributes as necessary, leave defaults at None.
    2. load data into the result_data and swarm_data objects,
       e.g. by calling select_data() or the load methods of the data objects.
    3. call validate()
    4. call create_report() which generates the graphics files.
    """

    ## @var params
    # parameter that should be plotted
    #
    # this should be a list of pairs (sequence or tuple of two strings) of parameter names.
    # for each pair, a plot is generated.
    # by default, all non-degenerate parameters are plotted in all combinations.

    ## @var result_data
    # R-factor data for scatter plot (in the background)
    #
    # pmsco.reports.results.ResultData object holding the data or filter criteria
    # for the scatter plot in the background.
    # the scatter plot shows a map of the model space in terms of R-factor.

    ## @var swarm_data
    # particle vectors for swarm (quiver) plot
    #
    # pmsco.reports.results.ResultData object holding the data or filter criteria
    # for the quiver plot in the foreground.
    # the quiver plot shows the positions and velocities of the particles.

    def __init__(self):
        super().__init__()
        self._modes = ['swarm']
        self.result_data = rp_results.ResultData()
        self.swarm_data = rp_results.ResultData()
        # templates resolved by resolve_template() with the keys base, param0, param1 and gen.
        self.filename_format = "${base}-swarm-${param0}-${param1}-${gen}"
        self.title_format = "${param0} ${param1} gen ${gen}"
        self.cmap = None
        self.params = []

    def select_data(self, jobs=-1, calcs=None):
        """
        query data from the database

        loads the results of the selected job(s) into result_data (background plot)
        and a copy restricted to the changed generations into swarm_data (foreground plot).
        if a project is attached, its model space is assigned to both data objects.

        @param jobs: filter by job.
            the argument can be a singleton or sequence of orm.Job objects or numeric id.
            if None, results from all jobs are loaded.
            if -1 (default), results from the most recent job (by datetime field) are loaded.

        @param calcs: the calcs argument is ignored.

        @return: None
        """
        with self.get_session() as session:
            if jobs == -1:
                jobs = db_query.query_newest_job(session)
            # generations to be plotted, as reported by GenerationTracker.changed_generations().
            changed_gens = self.changed_generations(session, jobs)

            # background data: results of the selected job(s) with filters reset
            self.result_data.reset_filters()
            self.result_data.levels = {'scan': -1}
            self.result_data.load_from_db(session, jobs=jobs)

            # foreground data: a (shallow) copy of the background data,
            # filtered to the changed generations.
            # NOTE(review): shallow copy - confirm that apply_filters() does not modify shared arrays.
            self.swarm_data = copy.copy(self.result_data)
            self.swarm_data.reset_filters()
            self.swarm_data.generations = changed_gens
            self.swarm_data.apply_filters()
            self.swarm_data.update_collections()

            if self._project:
                self.result_data.set_model_space(self._project.model_space)
                self.swarm_data.set_model_space(self._project.model_space)

    def create_report(self):
        """
        generate the plots based on the stored attributes.

        this method essentially loops over generations and parameter combinations,
        and compiles the input for plot_swarm.

        combinations of parameter names are either set by the user in self.params
        or constructed from all non-degenerate (not constant) parameters.
        user-specified pairs that are not pairs of non-degenerate parameters are skipped with a warning.

        @return: list of created files
        """
        # check that result data is compatible with swarm plots
        if self.swarm_data.params is None or len(self.swarm_data.params) < 2:
            logger.warning("result data must contain at least 2 parameters")
            return []
        if self.swarm_data.generations is None or len(self.swarm_data.generations) < 1:
            logger.warning("result data must specify at least 1 generation")
            return []
        if self.swarm_data.particles is None or len(self.swarm_data.particles) < 5:
            logger.warning("result data must specify at least 5 particles")
            return []

        vmin = self.swarm_data.model_space.min
        vmax = self.swarm_data.model_space.max

        nd_params = self.swarm_data.non_degenerate_params()
        if self.params:
            ppairs = []
            for p in self.params:
                if len(p) == 2 and p[0] in nd_params and p[1] in nd_params:
                    ppairs.append(p)
                else:
                    logger.warning(f"skipping parameter pair {p}: expecting two non-degenerate parameters")
        else:
            pnames = sorted(nd_params, key=str.lower)
            ppairs = list(itertools.combinations(pnames, 2))

        if not ppairs:
            logger.warning("no parameter pairs to plot")
            return []

        kwargs = {}
        if self.cmap is not None:
            kwargs['cmap'] = self.cmap
        if self.canvas is not None:
            kwargs['canvas'] = self.canvas

        files = []
        fdict = {'base': self.base_filename}
        for rd in self.swarm_data.iterate_generations():
            fdict['gen'] = int(rd.generations[0])
            for ppair in ppairs:
                fdict['param0'] = ppair[0]
                fdict['param1'] = ppair[1]
                filename = Path(self.report_dir, self.filename_format)
                filename = Path(self.resolve_template(filename, fdict))
                kwargs['title'] = self.resolve_template(self.title_format, fdict)
                # axis ranges from the model space boundaries
                params = {ppair[0]: [vmin[ppair[0]], vmax[ppair[0]]],
                          ppair[1]: [vmin[ppair[1]], vmax[ppair[1]]]}
                of = plot_swarm(filename, rd.values, rd.deltas, self.result_data.values, params, **kwargs)
                if of:
                    files.append(of)

        return files


def render_swarm(output_file, values, deltas=None, model_space=None, generations=None, title=None,
                 cmap=None, canvas=None):
    """
    render a two-dimensional particle swarm population.

    this function generates a schematic rendering of a particle swarm in two dimensions.
    particles are represented by their position and velocity, indicated by an arrow.
    the model space is projected on the first two (or selected two) variable parameters.
    in the background, a scatter plot of results (dots with pseudocolor representing the R-factor) can be plotted.
    the chart type is designed for the particle swarm optimization algorithm.

    the function requires input in one of the following forms:
    - an open pmsco database session. the most recent job results are loaded.
    - a result (.dat) file or numpy structured array.
      the array must contain regular parameters, as well as the _particle and _gen columns,
      and optionally, the _rfac column.
      the function generates one chart per generation unless the generation argument is specified.
      velocities are not plotted.
    - position (.pos) and velocity (.vel) files or the respective numpy structured arrays.
      the arrays must contain regular parameters, as well as the `_particle` column.
      the result file must also contain an `_rfac` column.
      files are loaded by numpy.genfromtxt.
    - a pmsco.optimizers.population.Population object with valid data.
      the generation is taken from the respective attribute and overrides the function argument.

    one chart is generated for each combination of generation and pair of non-degenerate parameters.

    the graphics file format can be changed by providing a specific canvas. default is PNG.

    this function requires the matplotlib module.
    if it is not available, no graphics files are generated.

    @note velocities are not stored in result files and can't be plotted if result files are used as input.
    to plot velocities, provide the database, .pos/.vel files or Population object as input.

    @param output_file: path and base name of the output file without extension.
        the names of the plotted parameters, the generation index
        and the file extension according to the file format are appended.

    @param values: position data to plot.
        the input can be specified in several ways - see the main description for details.

    @param deltas: velocity data to plot (optional).
        the input can be specified in several ways - see the main description for details.

    @param model_space: model space can be a pmsco.project.ModelSpace object,
        any object that contains the same min and max attributes as pmsco.project.ModelSpace,
        or a dictionary with two keys 'min' and 'max' that provides the corresponding ModelSpace dictionaries.
        by default, the model space boundaries are derived from the input data.
        if a model_space is specified, only the parameters listed in it are plotted.

    @param generations: (int or sequence) generation index or list of indices.
        this index is used in the output file name and for filtering input data by generation.
        if the input data does not contain the generation, no filtering is applied.
        by default, no filtering is applied, and one graph for each generation is produced.

    @param title: (str) title template of the chart.
        the template is resolved by SwarmPlot.resolve_template and may contain the placeholders
        ${gen} (generation), ${param0} and ${param1} (plotted parameters),
        in the same syntax as the default templates of SwarmPlot.
        default: the generation number.

    @param cmap: (str) name of colour map supported by matplotlib.
        default is 'plasma'.
        other good-looking options are 'viridis', 'plasma', 'inferno', 'magma', 'cividis'.

    @param canvas: a FigureCanvas class reference from a matplotlib backend.
        if None, the default FigureCanvasAgg is used which produces a bitmap file in PNG format.
        some other options are:
        matplotlib.backends.backend_pdf.FigureCanvasPdf or
        matplotlib.backends.backend_svg.FigureCanvasSVG.

    @return (list of str) paths of the generated graphics files.
        empty if an error occurred.
    """
    data = rp_results.ResultData()
    # accept a single generation index (python or numpy integer) as well as a sequence
    if isinstance(generations, (int, np.integer)):
        generations = (generations,)
    data.generations = generations
    data.levels = {'scan': -1}
    data.load_any(values, deltas)
    if model_space is not None:
        data.set_model_space(model_space)

    plot = SwarmPlot()
    plot.canvas = canvas
    plot.cmap = cmap
    if title:
        plot.title_format = title
    else:
        plot.title_format = "${gen}"
    plot.report_dir = Path(output_file).parent
    plot.filename_format = Path(output_file).name + "-${param0}-${param1}-${gen}"
    plot.validate(None)
    # the same data serves as background and foreground
    plot.result_data = data
    plot.swarm_data = data
    files = plot.create_report()

    return files


def main():
    """
    command line entry point: render swarm plots from a results or database file.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""
        swarm dynamics plot for multiple-scattering optimization results

        this module operates on results or database files and produces one graphics file per generation.
        database files contain the complete information for all plot types.
        data from the most recent job stored in the database is used.

        while .dat results files contain all data shown in genetic plots, they lack the velocity information for swarm plots.
        only particle positions are plotted in this case.

        .tasks.dat files lack the generation and particle identification and should not be used.

        note that the plot type is independent of the optimization mode.
        it's possible to generate genetic plots from a particle swarm optimization and vice versa.
        """)
    parser.add_argument('results_file',
                        help="path to results file (.dat) or sqlite3 database file.")
    parser.add_argument('output_file',
                        help="base name of output file. generation and extension will be appended.")
    parser.add_argument('-t', '--title', default=None,
                        help='graph title. may contain ${gen} as a placeholder for the generation number.')

    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        logger.warning(f"ignoring unknown arguments: {unknown_args}")

    kwargs = {}
    if args.title is not None:
        kwargs['title'] = args.title

    render_func = render_swarm
    if db_util.is_sqlite3_file(args.results_file):
        import pmsco.database.access as db_access
        db = db_access.DatabaseAccess()
        db.connect(args.results_file)
        with db.session() as session:
            render_func(args.output_file, session, **kwargs)
    else:
        render_func(args.output_file, args.results_file, **kwargs)


if __name__ == '__main__':
    main()
    sys.exit(0)