Files
pmsco-public/pmsco/reports/population.py

82 lines
2.7 KiB
Python

"""
@package pmsco.reports.population
common code for plotting population dynamics.
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
@copyright (c) 2021 by Paul Scherrer Institut @n
Licensed under the Apache License, Version 2.0 (the "License"); @n
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
"""
import logging
import numpy as np
from sqlalchemy import func as sql_func
import pmsco.database.util as db_util
import pmsco.database.orm as db_orm
import pmsco.database.query as db_query
logger = logging.getLogger(__name__)
class GenerationTracker(object):
    """
    mixin that tracks per-generation particle counts for generation-based reports

    the tracker remembers how many particles each generation contained at the time of the last query
    and, on the next query, reports the generations whose count has grown or that are new.
    """

    ## @var gens_parts
    # particle count per generation as of the most recent query
    #
    # dictionary that maps generation number -> number of particles counted in that generation.
    # it is empty after construction and after reset(),
    # and it is replaced as a whole by each call to changed_generations().

    def __init__(self):
        super().__init__()
        self.gens_parts = {}

    def reset(self):
        """
        discard the tracking data

        with empty tracking data, the next call to changed_generations()
        reports every generation returned by the query.

        @return: None
        """
        self.gens_parts = {}

    def changed_generations(self, session, jobs):
        """
        find the generations that gained results since the previous call

        the database is queried for the particle count of each generation,
        restricted to the given jobs if any are specified.
        only result rows with scan == -1 are counted
        (presumably the results combined over all scans - to be confirmed against the database schema).
        a generation is reported if it was not seen in the previous call
        or if its particle count is larger than in the previous call.
        the counts from this query then replace the stored tracking data.

        @param session: database session

        @param jobs: job objects or ids to restrict the query to.
            if empty or None, no job filter is applied.
            a sequence of jobs is accepted,
            but normally this should be a single job id that is the same in every call.

        @return: set of generation numbers that have new results since the last call.
        """
        counts_query = session.query(db_orm.Model.gen, sql_func.count(db_orm.Model.particle))
        if jobs:
            counts_query = db_query.filter_objects(counts_query, db_orm.Job, jobs)
        counts_query = (
            counts_query
            .join(db_orm.Result)
            .filter(db_orm.Result.scan == -1)
            .group_by(db_orm.Model.gen)
        )
        rows = counts_query.all()

        # single pass: collect the fresh counts and compare each one with the stored count.
        previous = self.gens_parts
        current = {}
        changed = set()
        for gen, count in rows:
            current[gen] = count
            if gen not in previous or previous[gen] < count:
                changed.add(gen)

        logger.debug(f"changed generations: {changed}")

        # the stored counts are replaced only after the comparison is complete.
        self.gens_parts = current
        return changed