Files
pmsco-public/pmsco/reports/convergence.py

314 lines
12 KiB
Python
Executable File

#!/usr/bin/env python
"""
@package pmsco.reports.convergence
graphics rendering module to show convergence of a population.
the module can be used in several different ways:
1. via the command line on a pmsco database or .dat results file.
this is the simplest but least flexible way.
2. via python functions on given population arrays or database queries.
this is the most flexible way but requires understanding of the required data formats.
3. as a listener on calculation events. (to be implemented)
this will be configurable in the run file.
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
@copyright (c) 2021 by Paul Scherrer Institut @n
Licensed under the Apache License, Version 2.0 (the "License"); @n
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
"""
import argparse
import logging
import numpy as np
from pathlib import Path
import sys
if __name__ == "__main__":
    # when this module is executed as a script, put the directory that contains
    # the pmsco package (three levels above this file) at the front of sys.path,
    # unless it is already there, so that the absolute `pmsco.*` imports below resolve.
    pmsco_root = Path(__file__).resolve().parent.parent.parent
    if str(pmsco_root) not in sys.path:
        sys.path.insert(0, str(pmsco_root))
import pmsco.reports.results as rp_results
import pmsco.database.util as db_util
import pmsco.database.query as db_query
from pmsco.reports.base import ProjectReport
logger = logging.getLogger(__name__)

# matplotlib is an optional dependency.
# if any of the required names cannot be imported, all of them are set to None
# and the plotting code skips file generation.
try:
    from matplotlib.figure import Figure
    from matplotlib.ticker import MaxNLocator
    from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
    # other canvases can be passed explicitly to select a different file format, e.g.
    # matplotlib.backends.backend_pdf.FigureCanvasPdf or
    # matplotlib.backends.backend_svg.FigureCanvasSVG.
except ImportError:
    Figure = FigureCanvas = MaxNLocator = None
    logger.warning("error importing matplotlib. graphics rendering disabled.")
def _split_generations(data, num_gen):
    """
    split the R-factor values of a population into one array per generation.

    if data contains a '_gen' column, the R-factors are grouped by generation index.
    otherwise, the R-factor values are divided into num_gen consecutive segments
    (numpy.array_split), and the segment index serves as generation index.

    @param data: structured ndarray with an '_rfac' and an optional '_gen' column,
        or a simple 1D array of R-factor values.
    @param num_gen: number of segments if data does not contain a '_gen' column.
    @return: tuple (generations, rfactors).
        generations is a list of generation indices,
        rfactors is a list of 1D arrays of R-factor values, one per generation.
        segments can be empty if data contains fewer values than num_gen.
    """
    try:
        gen_col = data['_gen']
    except IndexError:
        # simple (non-structured) array: indexing by field name raises IndexError.
        # the values are the R-factors themselves.
        rfac_col = data
    except ValueError:
        # structured array without a '_gen' field.
        rfac_col = data['_rfac']
    else:
        generations = list(np.unique(gen_col))
        rfactors = [data['_rfac'][gen_col == gen] for gen in generations]
        return generations, rfactors

    segments = np.array_split(rfac_col, num_gen)
    return list(range(len(segments))), segments


def plot_convergence(filename, data, title=None, canvas=None, num_gen=10):
    """
    violin plot showing the convergence of a population by generation.

    the plot is a violin plot where each violin represents one generation.
    the minimum, maximum and mean values are marked,
    and the distribution is indicated by the body.
    if no generation index is available, the function can divide the data into a selected number of segments.
    empty generations (segments) are not plotted.

    this is a low-level function containing just the plotting commands from numpy arrays.

    the graphics file format can be changed by providing a specific canvas. default is PNG.

    this function requires the matplotlib module.
    if it is not available, no file is generated and the function returns None.

    @param filename: path and base name of the output file without extension.
        the file extension according to the file format is appended.
    @param data: structured ndarray containing generation numbers and R-factor values.
        the '_rfac' column is required and must contain R-factor values.
        the '_gen' column is optional. if present it must contain the generation index.
        if the '_gen' column is missing or if the array is a simple 1D array,
        the array is divided into num_gen segments.
        other columns may be present and are ignored.
    @param title: (str) title of the chart.
        default: 'convergence'.
    @param canvas: a FigureCanvas class reference from a matplotlib backend.
        if None, the default FigureCanvasAgg is used which produces a bitmap file in PNG format.
        some other options are:
        matplotlib.backends.backend_pdf.FigureCanvasPdf or
        matplotlib.backends.backend_svg.FigureCanvasSVG.
    @param num_gen: number of generations if the '_gen' column is missing in data.
    @return (str) path and name of the generated graphics file.
        None if no file was generated because matplotlib is not available,
        there is no data to plot, or the file could not be saved.
    """
    if canvas is None:
        canvas = FigureCanvas
    if canvas is None or Figure is None:
        return None
    if title is None:
        title = 'convergence'

    generations, rfactors = _split_generations(data, num_gen)
    # violinplot cannot compute the statistics of an empty dataset.
    # empty segments occur e.g. when array_split gets fewer values than num_gen.
    non_empty = [(gen, rfac) for gen, rfac in zip(generations, rfactors) if len(rfac) > 0]
    if not non_empty:
        logger.warning(f"no R-factor data to plot for {filename}")
        return None
    generations = [gen for gen, _ in non_empty]
    rfactors = [rfac for _, rfac in non_empty]

    fig = Figure()
    # keep the canvas instance: some backends (e.g. pdf, svg) implement
    # get_default_filetype as an instance method, which fails when called on the class.
    fig_canvas = canvas(fig)
    ax = fig.add_subplot(111)

    # the following may raise a VisibleDeprecationWarning.
    # this is a bug in matplotlib and has been resolved as of matplotlib 3.3.0.
    # BUG: VisibleDeprecationWarning in boxplot #16353
    # https://github.com/matplotlib/matplotlib/issues/16353
    ax.violinplot(rfactors, generations, showmeans=True, showextrema=True, showmedians=False, widths=0.8)
    ax.set_ylim([0., 1.])
    ax.set_xlabel('generation')
    ax.set_ylabel('R-factor')
    ax.set_title(title)

    out_filename = "{base}.{ext}".format(base=filename, ext=fig_canvas.get_default_filetype())
    try:
        fig.savefig(out_filename)
    except OSError:
        logger.exception(f"exception while saving figure {out_filename}")
        out_filename = None

    return out_filename
class ConvergencePlot(ProjectReport):
    """
    violin plot showing the convergence of a population by generation.

    this class collects and validates all parameters and data for generating a convergence plot.
    the convergence plot can be used to monitor the overall progress of an optimization job.
    the convergence plot is a violin plot which shows the evolution of R-factors in a population over generations.
    each violin represents one generation.
    the minimum, maximum and mean values are marked,
    and the distribution is indicated by the body.

    the graphics file format can be changed by providing a specific canvas. default is PNG.

    the output file name (filename_format, relative to report_dir) and the chart title (title_format)
    are templates that are resolved by resolve_template with the substitution key 'base'.
    """

    def __init__(self):
        super().__init__()
        # presumably the optimizer modes this report applies to (interpreted by ProjectReport) - verify
        self._modes = ['genetic', 'swarm']
        self.result_data = rp_results.ResultData()
        # output file name template without extension, relative to report_dir
        self.filename_format = "${base}-convergence"
        # chart title template
        self.title_format = "convergence"

    def select_data(self, jobs=-1, calcs=None):
        """
        query data from the database and load it into self.result_data.

        the filters of result_data are reset, the level filter is set to {'scan': -1},
        and the results are loaded without parameter values (include_params=False).

        @param jobs: filter by job.
            the argument can be a singleton or sequence of orm.Job objects or numeric id.
            if None, results from all jobs are loaded.
            if -1 (default), results from the most recent job (by datetime field) are loaded.
        @param calcs: the calcs argument is ignored.
        @return: None
        """
        with self.get_session() as session:
            if jobs == -1:
                jobs = db_query.query_newest_job(session)
            self.result_data.reset_filters()
            self.result_data.levels = {'scan': -1}
            self.result_data.load_from_db(session, jobs=jobs, include_params=False)

    def create_report(self):
        """
        render the convergence plot from self.result_data.

        the plot is generated only if the result data specifies at least 1 generation
        and at least 5 particles. otherwise, a warning is logged and no file is produced.

        @return: (list) paths of the generated graphics files.
            empty if no file was generated.
        """
        # check that result data is compatible with convergence plots
        if self.result_data.generations is None or len(self.result_data.generations) < 1:
            logger.warning("result data must specify at least 1 generation")
            return []
        if self.result_data.particles is None or len(self.result_data.particles) < 5:
            logger.warning("result data must specify at least 5 particles")
            return []

        kwargs = {}
        if self.canvas is not None:
            kwargs['canvas'] = self.canvas

        fdict = {'base': self.base_filename}
        filename = Path(self.report_dir, self.filename_format)
        filename = Path(self.resolve_template(filename, fdict))
        kwargs['title'] = self.resolve_template(self.title_format, fdict)

        files = []
        of = plot_convergence(filename, self.result_data.values, **kwargs)
        if of:
            files.append(of)
        return files
def render_convergence(output_file, values, generations=None, title=None, canvas=None):
    """
    produce a convergence plot of a population-based optimization job.

    see ConvergencePlot and plot_convergence for details.

    the function requires input in one of the following forms:
    - a result (.dat) file or numpy structured array.
      the array must contain the _gen and _rfac columns.
      other columns are ignored.
      the function generates one plot.
    - a pmsco.optimizers.population.Population object with valid data.
      the generation is taken from the respective attribute and overrides the function argument.
    - an open pmsco database session. the most recent job results are loaded.

    the input is passed to ResultData.load_any, which determines how it is interpreted.

    the graphics file format can be changed by providing a specific canvas. default is PNG.

    this function requires the matplotlib module.
    if it is not available, no file is generated and the returned list is empty.

    @param output_file: path and base name of the output file without extension.
        the file extension according to the file format is appended.
    @param values: a numpy structured ndarray of a population or result list from an optimization run.
        alternatively, the file path of a result file (.dat) or population file (.pop) can be given.
        file can be any object that numpy.genfromtxt() can handle.
        NOTE(review): the original documentation states that array or file must be wrapped in a sequence,
        while main() passes the file path unwrapped - check what ResultData.load_any accepts.
    @param generations: (int or sequence of int) generation indices to filter.
        a single int is treated as a one-element sequence.
        if the input data does not contain the generation, no filtering is applied.
        by default, no filtering is applied, and all generations are included in the plot.
    @param title: (str) title of the chart.
        the title is a template resolved by ConvergencePlot.resolve_template with the key 'base',
        like the default file name template "${base}-convergence".
        default: empty template, i.e. presumably no title.
    @param canvas: a FigureCanvas class reference from a matplotlib backend.
        if None, the default FigureCanvasAgg is used which produces a bitmap file in PNG format.
        some other options are:
        matplotlib.backends.backend_pdf.FigureCanvasPdf or
        matplotlib.backends.backend_svg.FigureCanvasSVG.
    @return (list of str) paths of the generated graphics files.
        empty if an error occurred.
    """
    data = rp_results.ResultData()
    if isinstance(generations, int):
        generations = (generations,)
    data.generations = generations
    data.levels = {'scan': -1}
    data.load_any(values)

    output_path = Path(output_file)
    plot = ConvergencePlot()
    plot.canvas = canvas
    # None and "" both select the empty title template
    plot.title_format = title or ""
    plot.report_dir = output_path.parent
    plot.filename_format = output_path.name
    plot.validate(None)
    plot.result_data = data
    return plot.create_report()
def main():
    """
    command line interface: render a convergence plot from a results file or a pmsco database.

    if the input file is an sqlite3 database, the plot is generated from a database session.
    otherwise, the file path is passed on to render_convergence.
    a warning is logged if arguments are not recognized or if no file was generated.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
        description="""
        population convergence plot for multiple-scattering optimization results

        this module operates on results or database files and produces one graphics file
        which shows the R-factor distribution of each generation.
        from database files, data of the most recent job stored in the database is used.
        .dat results files contain all necessary data.
        .tasks.dat files lack the generation and particle identification and should not be used.
        """)
    parser.add_argument('results_file',
                        help="path to results file (.dat) or sqlite3 database file.")
    parser.add_argument('output_file',
                        help="path and base name of output file. the file extension will be appended.")
    parser.add_argument('-t', '--title', default=None,
                        help="graph title.")
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        logger.warning(f"ignoring unrecognized arguments: {' '.join(unknown_args)}")

    kwargs = {}
    if args.title is not None:
        kwargs['title'] = args.title

    if db_util.is_sqlite3_file(args.results_file):
        import pmsco.database.access as db_access
        db = db_access.DatabaseAccess()
        db.connect(args.results_file)
        with db.session() as session:
            files = render_convergence(args.output_file, session, **kwargs)
    else:
        files = render_convergence(args.output_file, args.results_file, **kwargs)

    if not files:
        logger.warning("no convergence plot was generated")
if __name__ == '__main__':
    # script entry point: run the command line interface,
    # then exit with status 0 once main() has returned.
    main()
    sys.exit(0)