public release 4.2.0 - see README.md and CHANGES.md for details
This commit is contained in:
313
pmsco/reports/convergence.py
Executable file
313
pmsco/reports/convergence.py
Executable file
@@ -0,0 +1,313 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
@package pmsco.reports.convergence
|
||||
graphics rendering module to show convergence of a population.
|
||||
|
||||
the module can be used in several different ways:
|
||||
|
||||
1. via the command line on a pmsco database or .dat results file.
|
||||
this is the most simple but least flexible way.
|
||||
2. via python functions on given population arrays or database queries.
|
||||
this is the most flexible way but requires understanding of the required data formats.
|
||||
3. as a listener on calculation events. (to be implemented)
|
||||
this will be configurable in the run file.
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2021 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
if __name__ == "__main__":
|
||||
pmsco_root = Path(__file__).resolve().parent.parent.parent
|
||||
if str(pmsco_root) not in sys.path:
|
||||
sys.path.insert(0, str(pmsco_root))
|
||||
|
||||
import pmsco.reports.results as rp_results
|
||||
import pmsco.database.util as db_util
|
||||
import pmsco.database.query as db_query
|
||||
from pmsco.reports.base import ProjectReport
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
|
||||
# from matplotlib.backends.backend_pdf import FigureCanvasPdf
|
||||
# from matplotlib.backends.backend_svg import FigureCanvasSVG
|
||||
except ImportError:
|
||||
Figure = None
|
||||
FigureCanvas = None
|
||||
MaxNLocator = None
|
||||
logger.warning("error importing matplotlib. graphics rendering disabled.")
|
||||
|
||||
|
||||
def plot_convergence(filename, data, title=None, canvas=None, num_gen=10):
|
||||
"""
|
||||
violin plot showing the convergence of a population by generation.
|
||||
|
||||
the plot is a violin plot where each violin represents one generation.
|
||||
the minimum, maximum and mean values are marked,
|
||||
and the distribution is indicated by the body.
|
||||
|
||||
if no generation index is available, the function can divide the data into a selected number of segments.
|
||||
|
||||
this is a low-level function containing just the plotting commands from numpy arrays.
|
||||
|
||||
the graphics file format can be changed by providing a specific canvas. default is PNG.
|
||||
|
||||
this function requires the matplotlib module.
|
||||
if it is not available, the function raises an error.
|
||||
|
||||
@param filename: path and base name of the output file without extension.
|
||||
a generation index and the file extension according to the file format are appended.
|
||||
@param data: structured ndarray containing generation numbers and R-factor values.
|
||||
the '_rfac' column is required and must contain R-factor values.
|
||||
the '_gen' column is optional. if present it must contain the generation index.
|
||||
if the '_gen' column is missing or if the array is a simple 1D array,
|
||||
the array is divided into num_gen segments.
|
||||
other columns may be present and are ignored.
|
||||
@param num_gen: number of generations if the '_gen' column is missing in data.
|
||||
@param title: (str) title of the chart.
|
||||
default: derived from parameter names.
|
||||
@param canvas: a FigureCanvas class reference from a matplotlib backend.
|
||||
if None, the default FigureCanvasAgg is used which produces a bitmap file in PNG format.
|
||||
some other options are:
|
||||
matplotlib.backends.backend_pdf.FigureCanvasPdf or
|
||||
matplotlib.backends.backend_svg.FigureCanvasSVG.
|
||||
|
||||
@return (str) path and name of the generated graphics file.
|
||||
None if no file was generated due to an error.
|
||||
"""
|
||||
if canvas is None:
|
||||
canvas = FigureCanvas
|
||||
if canvas is None or Figure is None:
|
||||
return None
|
||||
if title is None:
|
||||
title = 'convergence'
|
||||
|
||||
fig = Figure()
|
||||
canvas(fig)
|
||||
ax = fig.add_subplot(111)
|
||||
|
||||
rfactors = []
|
||||
generations = []
|
||||
try:
|
||||
generations = np.unique(data['_gen'])
|
||||
for gen in generations:
|
||||
idx = np.where(data['_gen'] == gen)
|
||||
rfactors.append(data['_rfac'][idx])
|
||||
except IndexError:
|
||||
for gen, arr in enumerate(np.array_split(data, num_gen)):
|
||||
generations.append(gen)
|
||||
rfactors.append(arr)
|
||||
except ValueError:
|
||||
for gen, arr in enumerate(np.array_split(data['_rfac'], num_gen)):
|
||||
generations.append(gen)
|
||||
rfactors.append(arr)
|
||||
|
||||
# the following may raise a VisibleDeprecationWarning.
|
||||
# this is a bug in matplotlib and has been resolved as of matplotlib 3.3.0.
|
||||
# BUG: VisibleDeprecationWarning in boxplot #16353
|
||||
# https://github.com/matplotlib/matplotlib/issues/16353
|
||||
ax.violinplot(rfactors, generations, showmeans=True, showextrema=True, showmedians=False, widths=0.8)
|
||||
ax.set_ylim([0., 1.])
|
||||
ax.set_xlabel('generation')
|
||||
ax.set_ylabel('R-factor')
|
||||
ax.set_title(title)
|
||||
|
||||
out_filename = "{base}.{ext}".format(base=filename, ext=canvas.get_default_filetype())
|
||||
try:
|
||||
fig.savefig(out_filename)
|
||||
except OSError:
|
||||
logger.exception(f"exception while saving figure {out_filename}")
|
||||
out_filename = None
|
||||
|
||||
return out_filename
|
||||
|
||||
|
||||
class ConvergencePlot(ProjectReport):
|
||||
"""
|
||||
violin plot showing the convergence of a population by generation.
|
||||
|
||||
this class collects and validates all parameters and data for generating a convergence plot.
|
||||
the convergence plot can be used to monitor the overall progress of an optimization job.
|
||||
|
||||
the convergence plot is a violin plot which shows the evolution of R-factors in a population over generations.
|
||||
each violin represents one generation.
|
||||
the minimum, maximum and mean values are marked,
|
||||
and the distribution is indicated by the body.
|
||||
|
||||
the graphics file format can be changed by providing a specific canvas. default is PNG.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._modes = ['genetic', 'swarm']
|
||||
self.result_data = rp_results.ResultData()
|
||||
self.filename_format = "${base}-convergence"
|
||||
self.title_format = "convergence"
|
||||
|
||||
def select_data(self, jobs=-1, calcs=None):
|
||||
"""
|
||||
query data from the database
|
||||
|
||||
this method must be implemented by the sub-class.
|
||||
|
||||
@param jobs: filter by job.
|
||||
the argument can be a singleton or sequence of orm.Job objects or numeric id.
|
||||
if None, results from all jobs are loaded.
|
||||
if -1 (default), results from the most recent job (by datetime field) are loaded.
|
||||
|
||||
@param calcs: the calcs argument is ignored.
|
||||
|
||||
@return: None
|
||||
"""
|
||||
|
||||
with self.get_session() as session:
|
||||
if jobs == -1:
|
||||
jobs = db_query.query_newest_job(session)
|
||||
self.result_data.reset_filters()
|
||||
self.result_data.levels = {'scan': -1}
|
||||
self.result_data.load_from_db(session, jobs=jobs, include_params=False)
|
||||
|
||||
def create_report(self):
|
||||
# check that result data is compatible with convergence plots
|
||||
if self.result_data.generations is None or len(self.result_data.generations) < 1:
|
||||
logger.warning("result data must specify at least 1 generation")
|
||||
return []
|
||||
if self.result_data.particles is None or len(self.result_data.particles) < 5:
|
||||
logger.warning("result data must specify at least 3 particles")
|
||||
return []
|
||||
|
||||
kwargs = {}
|
||||
if self.canvas is not None:
|
||||
kwargs['canvas'] = self.canvas
|
||||
|
||||
files = []
|
||||
fdict = {'base': self.base_filename}
|
||||
filename = Path(self.report_dir, self.filename_format)
|
||||
filename = Path(self.resolve_template(filename, fdict))
|
||||
kwargs['title'] = self.resolve_template(self.title_format, fdict)
|
||||
of = plot_convergence(filename, self.result_data.values, **kwargs)
|
||||
if of:
|
||||
files.append(of)
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def render_convergence(output_file, values, generations=None, title=None, canvas=None):
|
||||
"""
|
||||
produce a convergence plot of a population-based optimization job.
|
||||
|
||||
see ConvergencePlot and plot_convergence for details.
|
||||
|
||||
the function requires input in one of the following forms:
|
||||
- a result (.dat) file or numpy structured array.
|
||||
the array must contain the _gen and _rfac columns.
|
||||
other columns are ignored.
|
||||
the function generates one plot.
|
||||
- a pmsco.optimizers.population.Population object with valid data.
|
||||
the generation is taken from the respective attribute and overrides the function argument.
|
||||
- an open pmsco database session. the most recent job results are loaded.
|
||||
|
||||
the graphics file format can be changed by providing a specific canvas. default is PNG.
|
||||
|
||||
this function requires the matplotlib module.
|
||||
if it is not available, the function raises an error.
|
||||
|
||||
@param output_file: path and base name of the output file without extension.
|
||||
@param values: a numpy structured ndarray of a population or result list from an optimization run.
|
||||
alternatively, the file path of a result file (.dat) or population file (.pop) can be given.
|
||||
file can be any object that numpy.genfromtxt() can handle.
|
||||
array or file must be wrapped in a sequence.
|
||||
@param generations: (sequence of int) list of generation indices to filter.
|
||||
if the input data does not contain the generation, no filtering is applied.
|
||||
by default, no filtering is applied, and all generations are included in the plot.
|
||||
@param title: (str) title of the chart.
|
||||
the title is a {}-style format string, where {base} is the output file name.
|
||||
default: derived from file name.
|
||||
@param canvas: a FigureCanvas class reference from a matplotlib backend.
|
||||
if None, the default FigureCanvasAgg is used which produces a bitmap file in PNG format.
|
||||
some other options are:
|
||||
matplotlib.backends.backend_pdf.FigureCanvasPdf or
|
||||
matplotlib.backends.backend_svg.FigureCanvasSVG.
|
||||
|
||||
@return (list of str) paths of the generated graphics files.
|
||||
empty if an error occurred.
|
||||
"""
|
||||
|
||||
data = rp_results.ResultData()
|
||||
if isinstance(generations, int):
|
||||
generations = (generations,)
|
||||
data.generations = generations
|
||||
data.levels = {'scan': -1}
|
||||
data.load_any(values)
|
||||
|
||||
plot = ConvergencePlot()
|
||||
plot.canvas = canvas
|
||||
if title:
|
||||
plot.title_format = title
|
||||
else:
|
||||
plot.title_format = ""
|
||||
plot.report_dir = Path(output_file).parent
|
||||
plot.filename_format = Path(output_file).name
|
||||
plot.validate(None)
|
||||
plot.result_data = data
|
||||
files = plot.create_report()
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
description="""
|
||||
population convergence plot for multiple-scattering optimization results
|
||||
|
||||
this module operates on results or database files and produces one graphics file per generation.
|
||||
|
||||
database files contain the complete information for all plot types.
|
||||
data from the most recent job stored in the database is used.
|
||||
|
||||
.dat results files contain all necessary data.
|
||||
.tasks.dat files lack the generation and particle identification and should not be used.
|
||||
""")
|
||||
parser.add_argument('results_file',
|
||||
help="path to results file (.dat) or sqlite3 database file.")
|
||||
parser.add_argument('output_file',
|
||||
help="base name of output file. generation and extension will be appended.")
|
||||
parser.add_argument('-t', '--title', default=None,
|
||||
help='graph title. may contain {gen} as a placeholder for the generation number.')
|
||||
|
||||
args, unknown_args = parser.parse_known_args()
|
||||
|
||||
kwargs = {}
|
||||
if args.title is not None:
|
||||
kwargs['title'] = args.title
|
||||
|
||||
render_func = render_convergence
|
||||
|
||||
if db_util.is_sqlite3_file(args.results_file):
|
||||
import pmsco.database.access as db_access
|
||||
db = db_access.DatabaseAccess()
|
||||
db.connect(args.results_file)
|
||||
with db.session() as session:
|
||||
render_func(args.output_file, session, **kwargs)
|
||||
else:
|
||||
render_func(args.output_file, args.results_file, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
sys.exit(0)
|
||||
Reference in New Issue
Block a user