82 lines
2.7 KiB
Python
82 lines
2.7 KiB
Python
"""
|
|
@package pmsco.reports.population
|
|
common code for plotting population dynamics.
|
|
|
|
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
|
|
|
@copyright (c) 2021 by Paul Scherrer Institut @n
|
|
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
"""
|
|
|
|
import logging
|
|
import numpy as np
|
|
from sqlalchemy import func as sql_func
|
|
import pmsco.database.util as db_util
|
|
import pmsco.database.orm as db_orm
|
|
import pmsco.database.query as db_query
|
|
|
|
|
|
# module-level logger, named after this module so that its output can be filtered per module
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class GenerationTracker(object):
    """
    mixin for generation-based reports.

    the class remembers, per generation, the particle count seen at the last query
    and reports which generations have gained results since then.
    """

    ## @var gens_parts
    # particle counts per generation as seen at the last query.
    #
    # dictionary mapping generation number -> number of particles counted in that generation
    # by the most recent call of changed_generations().
    # it is empty before the first query and after reset().

    def __init__(self):
        # cooperate with the other base classes of the mixin hierarchy
        super().__init__()
        self.gens_parts = dict()

    def reset(self):
        """
        forget all tracking data.

        after a reset, the next call of changed_generations() reports
        every generation that the query returns.

        @return: None
        """
        self.gens_parts = dict()

    def changed_generations(self, session, jobs):
        """
        determine which generations have new results.

        the function queries the database for the generations of the given jobs and,
        per generation, the count of Model.particle over models joined to a Result with scan == -1.
        a generation is reported as changed if it did not appear in the previous call,
        or if its count has grown since then.
        the counts of this call replace the stored tracking data,
        so generations missing from the new query are dropped from tracking.

        @param session: database session.

        @param jobs: job objects or ids to include in the query.
            if jobs is empty or None, no job filter is applied.
            although a sequence of jobs is accepted,
            normally this is a single job id, and the function should always be called
            with the same id so that successive counts are comparable.

        @return: set of generation numbers that have new results since the last call.
        """
        model_cls = db_orm.Model
        result_cls = db_orm.Result

        query = session.query(model_cls.gen, sql_func.count(model_cls.particle))
        if jobs:
            query = db_query.filter_objects(query, db_orm.Job, jobs)
        # NOTE(review): scan == -1 presumably selects the combined result of a model
        # rather than a per-scan result - confirm against the ORM schema.
        rows = (query
                .join(result_cls)
                .filter(result_cls.scan == -1)
                .group_by(model_cls.gen)
                .all())

        previous = self.gens_parts
        current = {}
        changed = set()
        for gen, count in rows:
            current[gen] = count
            # new generation, or more particles than at the last call
            if gen not in previous or previous[gen] < count:
                changed.add(gen)

        logger.debug(f"changed generations: {changed}")

        self.gens_parts = current
        return changed
|