public release 4.2.0 - see README.md and CHANGES.md for details
This commit is contained in:
@@ -1 +0,0 @@
|
||||
__author__ = 'muntwiler_m'
|
||||
232
tests/database/test_common.py
Normal file
232
tests/database/test_common.py
Normal file
@@ -0,0 +1,232 @@
|
||||
"""
|
||||
@package tests.database.test_common
|
||||
unit tests for pmsco.database.common
|
||||
|
||||
the purpose of these tests is to help debugging the code.
|
||||
|
||||
to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
|
||||
|
||||
@pre nose must be installed (python-nose package on Debian).
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2016 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import datetime
|
||||
import sqlalchemy.exc
|
||||
import pmsco.database.access as db
|
||||
import pmsco.database.common as db_common
|
||||
import pmsco.database.orm as orm
|
||||
import pmsco.dispatch as dispatch
|
||||
|
||||
|
||||
def setup_sample_database(session):
|
||||
p1 = orm.Project(name="oldproject", code="oldcode")
|
||||
p2 = orm.Project(name="unittest", code="testcode")
|
||||
j1 = orm.Job(project=p1, name="oldjob", mode="oldmode", machine="oldhost", datetime=datetime.datetime.now())
|
||||
j2 = orm.Job(project=p2, name="testjob", mode="testmode", machine="testhost", datetime=datetime.datetime.now())
|
||||
pk1 = orm.Param(key='parA')
|
||||
pk2 = orm.Param(key='parB')
|
||||
pk3 = orm.Param(key='parC')
|
||||
m1 = orm.Model(job=j1, model=91)
|
||||
m2 = orm.Model(job=j2, model=92)
|
||||
r1 = orm.Result(calc_id=dispatch.CalcID(91, -1, -1, -1, -1), rfac=0.534, secs=37.9)
|
||||
r1.model = m1
|
||||
pv1 = orm.ParamValue(model=m1, param=pk1, value=1.234, delta=0.1234)
|
||||
pv2 = orm.ParamValue(model=m1, param=pk2, value=5.678, delta=-0.5678)
|
||||
pv3 = orm.ParamValue(model=m2, param=pk3, value=6.785, delta=0.6785)
|
||||
objects = {'p1': p1, 'p2': p2, 'j1': j1, 'j2': j2, 'm1': m1, 'm2': m2, 'r1': r1,
|
||||
'pv1': pv1, 'pv2': pv2, 'pv3': pv3, 'pk1': pk1, 'pk2': pk2, 'pk3': pk3}
|
||||
session.add_all(objects.values())
|
||||
session.commit()
|
||||
return objects
|
||||
|
||||
|
||||
class TestDatabaseCommon(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.db = db.DatabaseAccess()
|
||||
self.db.connect(":memory:")
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
# before any methods in this class
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
# teardown_class() after any methods in this class
|
||||
pass
|
||||
|
||||
def test_setup_database(self):
|
||||
with self.db.session() as session:
|
||||
setup_sample_database(session)
|
||||
self.assertEqual(session.query(orm.Project).count(), 2)
|
||||
self.assertEqual(session.query(orm.Job).count(), 2)
|
||||
self.assertEqual(session.query(orm.Param).count(), 3)
|
||||
self.assertEqual(session.query(orm.Model).count(), 2)
|
||||
self.assertEqual(session.query(orm.Result).count(), 1)
|
||||
self.assertEqual(session.query(orm.ParamValue).count(), 3)
|
||||
|
||||
def test_get_project(self):
|
||||
with self.db.session() as session:
|
||||
p1 = orm.Project(name="p1")
|
||||
p2 = orm.Project(name="p2")
|
||||
p3 = orm.Project(name="p3")
|
||||
p4 = orm.Project(name="p4")
|
||||
session.add_all([p1, p2, p3])
|
||||
session.commit()
|
||||
q1 = db_common.get_project(session, p1)
|
||||
q2 = db_common.get_project(session, p2.id)
|
||||
q3 = db_common.get_project(session, p3.name)
|
||||
q4 = db_common.get_project(session, p4)
|
||||
self.assertIs(q1, p1, "by object")
|
||||
self.assertIs(q2, p2, "by id")
|
||||
self.assertIs(q3, p3, "by name")
|
||||
self.assertIs(q4, p4, "detached object by object")
|
||||
with self.assertRaises(sqlalchemy.exc.InvalidRequestError, msg="detached object by name"):
|
||||
db_common.get_project(session, p4.name)
|
||||
|
||||
def test_get_job(self):
|
||||
with self.db.session() as session:
|
||||
p1 = orm.Project(name="p1")
|
||||
p2 = orm.Project(name="p2")
|
||||
p3 = orm.Project(name="p3")
|
||||
p4 = orm.Project(name="p4")
|
||||
j1 = orm.Job(name="j1")
|
||||
j1.project = p1
|
||||
j2 = orm.Job(name="j2")
|
||||
j2.project = p2
|
||||
j3 = orm.Job(name="j1")
|
||||
j3.project = p3
|
||||
j4 = orm.Job(name="j4")
|
||||
j4.project = p4
|
||||
session.add_all([p1, p2, p3, j1, j2, j3])
|
||||
session.commit()
|
||||
|
||||
self.assertIsNot(j3, j1, "jobs with same name")
|
||||
q1 = db_common.get_job(session, p1, j1)
|
||||
q2 = db_common.get_job(session, p2, j2.id)
|
||||
q3 = db_common.get_job(session, p3, j3.name)
|
||||
q4 = db_common.get_job(session, p4, j4)
|
||||
self.assertIs(q1, j1, "by object")
|
||||
self.assertIs(q2, j2, "by id")
|
||||
self.assertIs(q3, j3, "by name")
|
||||
self.assertIs(q4, j4, "detached object by object")
|
||||
with self.assertRaises(sqlalchemy.exc.InvalidRequestError, msg="detached object by name"):
|
||||
db_common.get_job(session, p4, j4.name)
|
||||
q5 = db_common.get_job(session, p1, j4)
|
||||
self.assertIs(q5, j4)
|
||||
|
||||
def test_register_project(self):
|
||||
with self.db.session() as session:
|
||||
id1 = db_common.register_project(session, "unittest1", "Atest", allow_existing=True)
|
||||
self.assertIsInstance(id1, orm.Project)
|
||||
id2 = db_common.register_project(session, "unittest2", "Btest", allow_existing=True)
|
||||
self.assertIsInstance(id2, orm.Project)
|
||||
id3 = db_common.register_project(session, "unittest1", "Ctest", allow_existing=True)
|
||||
self.assertIsInstance(id3, orm.Project)
|
||||
self.assertNotEqual(id1, id2)
|
||||
self.assertEqual(id1, id3)
|
||||
session.commit()
|
||||
|
||||
c = session.execute("select count(*) from Projects")
|
||||
row = c.fetchone()
|
||||
self.assertEqual(row[0], 2)
|
||||
c = session.execute("select name, code from Projects where id=:id", {'id': id1.id})
|
||||
row = c.fetchone()
|
||||
self.assertIsNotNone(row)
|
||||
self.assertEqual(len(row), 2)
|
||||
self.assertEqual(row[0], "unittest1")
|
||||
self.assertEqual(row[1], "Atest")
|
||||
self.assertEqual(row['name'], "unittest1")
|
||||
self.assertEqual(row['code'], "Atest")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
db_common.register_project(session, "unittest1", "Ctest")
|
||||
|
||||
def test_register_job(self):
|
||||
with self.db.session() as session:
|
||||
pid1 = db_common.register_project(session, "unittest1", "Acode")
|
||||
pid2 = db_common.register_project(session, "unittest2", "Bcode")
|
||||
dt1 = datetime.datetime.now()
|
||||
|
||||
# insert new job
|
||||
id1 = db_common.register_job(session, pid1, "Ajob", mode="Amode", machine="local", git_hash="Ahash",
|
||||
datetime=dt1, description="Adesc")
|
||||
self.assertIsInstance(id1, orm.Job)
|
||||
# insert another job
|
||||
id2 = db_common.register_job(session, pid1.id, "Bjob", mode="Bmode", machine="local", git_hash="Ahash",
|
||||
datetime=dt1, description="Adesc")
|
||||
self.assertIsInstance(id2, orm.Job)
|
||||
# update first job
|
||||
id3 = db_common.register_job(session, "unittest1", "Ajob", mode="Cmode", machine="local", git_hash="Chash",
|
||||
datetime=dt1, description="Cdesc",
|
||||
allow_existing=True)
|
||||
self.assertIsInstance(id3, orm.Job)
|
||||
# insert another job with same name but in other project
|
||||
id4 = db_common.register_job(session, pid2, "Ajob", mode="Dmode", machine="local", git_hash="Dhash",
|
||||
datetime=dt1, description="Ddesc")
|
||||
self.assertIsInstance(id4, orm.Job)
|
||||
# existing job
|
||||
with self.assertRaises(ValueError):
|
||||
db_common.register_job(session, pid1, "Ajob", mode="Emode", machine="local", git_hash="Dhash",
|
||||
datetime=dt1, description="Ddesc")
|
||||
|
||||
self.assertIsNot(id1, id2)
|
||||
self.assertIs(id1, id3)
|
||||
self.assertIsNot(id1, id4)
|
||||
|
||||
c = session.execute("select count(*) from Jobs")
|
||||
row = c.fetchone()
|
||||
self.assertEqual(row[0], 3)
|
||||
c = session.execute("select name, mode, machine, git_hash, datetime, description from Jobs where id=:id",
|
||||
{'id': id1.id})
|
||||
row = c.fetchone()
|
||||
self.assertIsNotNone(row)
|
||||
self.assertEqual(len(row), 6)
|
||||
self.assertEqual(row[0], "Ajob")
|
||||
self.assertEqual(row[1], "Amode")
|
||||
self.assertEqual(row['machine'], "local")
|
||||
self.assertEqual(str(row['datetime']), str(dt1))
|
||||
self.assertEqual(row['git_hash'], "Ahash")
|
||||
self.assertEqual(row['description'], "Adesc")
|
||||
|
||||
def test_register_params(self):
|
||||
with self.db.session() as session:
|
||||
setup_sample_database(session)
|
||||
model5 = {'parA': 2.341, 'parC': 6.785, '_model': 92, '_rfac': 0.453}
|
||||
db_common.register_params(session, model5)
|
||||
expected = ['parA', 'parB', 'parC']
|
||||
session.commit()
|
||||
|
||||
c = session.execute("select * from Params order by key")
|
||||
results = c.fetchall()
|
||||
self.assertEqual(len(results), 3)
|
||||
result_params = [row['key'] for row in results]
|
||||
self.assertEqual(result_params, expected)
|
||||
|
||||
def test_query_params(self):
|
||||
with self.db.session() as session:
|
||||
objs = setup_sample_database(session)
|
||||
results = db_common.query_params(session, project=objs['p1'].id)
|
||||
expected = ['parA', 'parB']
|
||||
self.assertEqual(expected, sorted(list(results.keys())))
|
||||
self.assertIsInstance(results['parA'], orm.Param)
|
||||
self.assertIsInstance(results['parB'], orm.Param)
|
||||
results = db_common.query_params(session, project=objs['p2'].name)
|
||||
expected = ['parC']
|
||||
self.assertEqual(expected, sorted(list(results.keys())))
|
||||
self.assertIsInstance(results['parC'], orm.Param)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
211
tests/database/test_ingest.py
Normal file
211
tests/database/test_ingest.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
@package tests.database.test_ingest
|
||||
unit tests for pmsco.database
|
||||
|
||||
the purpose of these tests is to help debugging the code.
|
||||
|
||||
to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
|
||||
|
||||
@pre nose must be installed (python-nose package on Debian).
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2016 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import pmsco.database.access as db
|
||||
import pmsco.database.ingest as db_ingest
|
||||
import pmsco.database.orm as orm
|
||||
import pmsco.dispatch as dispatch
|
||||
from tests.database.test_common import setup_sample_database
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.db = db.DatabaseAccess()
|
||||
self.db.connect(":memory:")
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
# before any methods in this class
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
# teardown_class() after any methods in this class
|
||||
pass
|
||||
|
||||
def test_insert_result(self):
|
||||
with self.db.session() as session:
|
||||
objs = setup_sample_database(session)
|
||||
index = dispatch.CalcID(15, 16, 17, 18, -1)
|
||||
result_data = {'parA': 4.123, 'parB': 8.567, '_rfac': 0.654, '_gen': 3, '_particle': 21, '_secs': 27.8}
|
||||
result_delta = {'parA': 0.4123, 'parB': 0.8567}
|
||||
model_obj, result_obj = db_ingest.insert_result(session, objs['j1'], index, result_data, result_delta)
|
||||
session.commit()
|
||||
|
||||
# model
|
||||
q = session.query(orm.Model)
|
||||
q = q.filter(orm.Model.job_id == objs['j1'].id)
|
||||
q = q.filter(orm.Model.model == index.model)
|
||||
m = q.one()
|
||||
self.assertIsNot(model_obj, objs['m1'])
|
||||
self.assertIs(m, model_obj)
|
||||
self.assertEqual(m.id, model_obj.id)
|
||||
self.assertEqual(m.job_id, objs['j1'].id)
|
||||
self.assertEqual(m.model, index.model)
|
||||
self.assertEqual(m.gen, result_data['_gen'])
|
||||
self.assertEqual(m.particle, result_data['_particle'])
|
||||
|
||||
# result
|
||||
q = session.query(orm.Result)
|
||||
q = q.filter(orm.Result.model_id == model_obj.id)
|
||||
r = q.one()
|
||||
self.assertIsNot(r, objs['r1'])
|
||||
self.assertIs(r, result_obj)
|
||||
self.assertEqual(r.id, result_obj.id)
|
||||
self.assertIs(r.model, model_obj)
|
||||
self.assertEqual(r.scan, index.scan)
|
||||
self.assertEqual(r.domain, index.domain)
|
||||
self.assertEqual(r.emit, index.emit)
|
||||
self.assertEqual(r.region, index.region)
|
||||
self.assertEqual(r.rfac, result_data['_rfac'])
|
||||
self.assertEqual(r.secs, result_data['_secs'])
|
||||
|
||||
# param values
|
||||
q = session.query(orm.ParamValue)
|
||||
q = q.filter(orm.ParamValue.model_id == model_obj.id)
|
||||
pvs = q.all()
|
||||
values = {pv.param_key: pv.value for pv in pvs}
|
||||
deltas = {pv.param_key: pv.delta for pv in pvs}
|
||||
for k in result_data:
|
||||
if k[0] != '_':
|
||||
self.assertAlmostEqual(values[k], result_data[k])
|
||||
self.assertAlmostEqual(deltas[k], result_delta[k])
|
||||
self.assertAlmostEqual(m.values[k], result_data[k])
|
||||
self.assertAlmostEqual(m.deltas[k], result_delta[k])
|
||||
|
||||
def test_update_result(self):
|
||||
"""
|
||||
test update an existing model and result
|
||||
|
||||
update parameters parA and parB and rfac of result (91, -1, -1, -1, -1)
|
||||
|
||||
@return: None
|
||||
"""
|
||||
with self.db.session() as session:
|
||||
objs = setup_sample_database(session)
|
||||
index = dispatch.CalcID(91, -1, -1, -1, -1)
|
||||
result_data = {'parA': 4.123, 'parB': 8.567, '_rfac': 0.654, '_gen': 3, '_particle': 21, '_secs': 27.8}
|
||||
result_delta = {'parA': 0.4123, 'parB': 0.8567}
|
||||
model_obj, result_obj = db_ingest.insert_result(session, objs['j1'], index, result_data, result_delta)
|
||||
session.commit()
|
||||
|
||||
# model
|
||||
q = session.query(orm.Model)
|
||||
q = q.filter(orm.Model.job_id == objs['j1'].id)
|
||||
q = q.filter(orm.Model.model == index.model)
|
||||
m = q.one()
|
||||
self.assertIs(model_obj, objs['m1'])
|
||||
self.assertIs(m, objs['m1'])
|
||||
self.assertEqual(m.id, model_obj.id)
|
||||
self.assertEqual(m.job_id, objs['j1'].id)
|
||||
self.assertEqual(m.model, index.model)
|
||||
self.assertEqual(m.gen, result_data['_gen'])
|
||||
self.assertEqual(m.particle, result_data['_particle'])
|
||||
|
||||
# result
|
||||
q = session.query(orm.Result)
|
||||
q = q.filter(orm.Result.model_id == model_obj.id)
|
||||
r = q.one()
|
||||
self.assertIs(result_obj, objs['r1'])
|
||||
self.assertIs(r, objs['r1'])
|
||||
self.assertEqual(r.id, result_obj.id)
|
||||
self.assertIs(r.model, model_obj)
|
||||
self.assertEqual(r.scan, index.scan)
|
||||
self.assertEqual(r.domain, index.domain)
|
||||
self.assertEqual(r.emit, index.emit)
|
||||
self.assertEqual(r.region, index.region)
|
||||
self.assertEqual(r.rfac, result_data['_rfac'])
|
||||
self.assertEqual(r.secs, result_data['_secs'])
|
||||
|
||||
# param values
|
||||
q = session.query(orm.ParamValue)
|
||||
q = q.filter(orm.ParamValue.model_id == model_obj.id)
|
||||
pvs = q.all()
|
||||
values = {pv.param_key: pv.value for pv in pvs}
|
||||
deltas = {pv.param_key: pv.delta for pv in pvs}
|
||||
for k in result_data:
|
||||
if k[0] != '_':
|
||||
self.assertAlmostEqual(values[k], result_data[k])
|
||||
self.assertAlmostEqual(deltas[k], result_delta[k])
|
||||
self.assertAlmostEqual(m.values[k], result_data[k])
|
||||
self.assertAlmostEqual(m.deltas[k], result_delta[k])
|
||||
|
||||
def test_update_result_dict(self):
|
||||
"""
|
||||
test update an existing model and result with dictionary arguments
|
||||
|
||||
update parameters parA and parB and rfac of result (91, -1, -1, -1, -1)
|
||||
|
||||
@return: None
|
||||
"""
|
||||
with self.db.session() as session:
|
||||
objs = setup_sample_database(session)
|
||||
result_data = {'_model': 91, '_scan': -1, '_domain': -1, '_emit': -1, '_region': -1,
|
||||
'parA': 4.123, 'parB': 8.567, '_rfac': 0.654, '_gen': 3, '_particle': 21, '_secs': 27.8}
|
||||
result_delta = {'parA': 0.4123, 'parB': 0.8567}
|
||||
model_obj, result_obj = db_ingest.insert_result(session, objs['j1'], result_data, result_data, result_delta)
|
||||
session.commit()
|
||||
|
||||
# model
|
||||
q = session.query(orm.Model)
|
||||
q = q.filter(orm.Model.job_id == objs['j1'].id)
|
||||
q = q.filter(orm.Model.model == result_data['_model'])
|
||||
m = q.one()
|
||||
self.assertIs(model_obj, objs['m1'])
|
||||
self.assertIs(m, objs['m1'])
|
||||
self.assertEqual(m.id, model_obj.id)
|
||||
self.assertEqual(m.job_id, objs['j1'].id)
|
||||
self.assertEqual(m.model, result_data['_model'])
|
||||
self.assertEqual(m.gen, result_data['_gen'])
|
||||
self.assertEqual(m.particle, result_data['_particle'])
|
||||
|
||||
# result
|
||||
q = session.query(orm.Result)
|
||||
q = q.filter(orm.Result.model_id == model_obj.id)
|
||||
r = q.one()
|
||||
self.assertIs(result_obj, objs['r1'])
|
||||
self.assertIs(r, objs['r1'])
|
||||
self.assertEqual(r.id, result_obj.id)
|
||||
self.assertIs(r.model, model_obj)
|
||||
self.assertEqual(r.scan, result_data['_scan'])
|
||||
self.assertEqual(r.domain, result_data['_domain'])
|
||||
self.assertEqual(r.emit, result_data['_emit'])
|
||||
self.assertEqual(r.region, result_data['_region'])
|
||||
self.assertEqual(r.rfac, result_data['_rfac'])
|
||||
|
||||
# param values
|
||||
q = session.query(orm.ParamValue)
|
||||
q = q.filter(orm.ParamValue.model_id == model_obj.id)
|
||||
pvs = q.all()
|
||||
values = {pv.param_key: pv.value for pv in pvs}
|
||||
deltas = {pv.param_key: pv.delta for pv in pvs}
|
||||
for k in result_data:
|
||||
if k[0] != '_':
|
||||
self.assertAlmostEqual(values[k], result_data[k])
|
||||
self.assertAlmostEqual(deltas[k], result_delta[k])
|
||||
self.assertAlmostEqual(m.values[k], result_data[k])
|
||||
self.assertAlmostEqual(m.deltas[k], result_delta[k])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
279
tests/database/test_orm.py
Normal file
279
tests/database/test_orm.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""
|
||||
@package tests.database.test_orm
|
||||
unit tests for pmsco.database.orm
|
||||
|
||||
the purpose of these tests is to help debugging the code.
|
||||
|
||||
to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
|
||||
|
||||
@pre nose must be installed (python-nose package on Debian).
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2021 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import pmsco.database.access as db
|
||||
import pmsco.database.orm as orm
|
||||
import pmsco.database.util as util
|
||||
import pmsco.dispatch as dispatch
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.db = db.DatabaseAccess()
|
||||
self.db.connect(":memory:")
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
# before any methods in this class
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
# teardown_class() after any methods in this class
|
||||
pass
|
||||
|
||||
def test_orm_1(self):
|
||||
with self.db.session() as session:
|
||||
prj = orm.Project(name="test1", code=__file__)
|
||||
session.add(prj)
|
||||
job = orm.Job(name="test_database")
|
||||
job.project = prj
|
||||
session.add(job)
|
||||
tag1 = orm.Tag(key="phase")
|
||||
tag2 = orm.Tag(key="scatter")
|
||||
session.add_all([tag1, tag2])
|
||||
jt1 = orm.JobTag()
|
||||
jt1.tag = tag1
|
||||
jt1.job = job
|
||||
jt1.value = 'phagen'
|
||||
jt2 = orm.JobTag()
|
||||
jt2.tag = tag2
|
||||
jt2.job = job
|
||||
jt2.value = 'edac'
|
||||
session.commit()
|
||||
|
||||
qprj = session.query(orm.Project).filter_by(id=1).one()
|
||||
self.assertEqual(prj.name, qprj.name)
|
||||
qjob = session.query(orm.Job).filter_by(id=1).one()
|
||||
self.assertEqual(job.name, qjob.name)
|
||||
self.assertEqual(job.project.name, prj.name)
|
||||
self.assertEqual(len(qprj.jobs), 1)
|
||||
self.assertEqual(len(qjob.job_tags), 2)
|
||||
self.assertEqual(qjob.tags['phase'], 'phagen')
|
||||
self.assertEqual(qjob.tags['scatter'], 'edac')
|
||||
|
||||
def test_orm_2(self):
|
||||
with self.db.session() as session:
|
||||
prj = orm.Project(name="project 1", code=__file__)
|
||||
session.add(prj)
|
||||
|
||||
job = orm.Job(name="job 1")
|
||||
job.project = prj
|
||||
session.add(job)
|
||||
|
||||
jt1 = orm.JobTag('phase', 'phagen')
|
||||
session.add(jt1)
|
||||
job.job_tags[jt1.tag_key] = jt1
|
||||
job.tags['scatter'] = 'edac'
|
||||
|
||||
mod = orm.Model(model=1111, gen=111, particle=11)
|
||||
session.add(mod)
|
||||
|
||||
pv1 = orm.ParamValue(key='dAB', value=123.456, delta=7.543)
|
||||
session.add(pv1)
|
||||
mod.param_values[pv1.param_key] = pv1
|
||||
mod.values['dBC'] = 234.567
|
||||
|
||||
cid = dispatch.CalcID(1111, 2, 3, 4, 5)
|
||||
res = orm.Result(calc_id=cid, rfac=0.123)
|
||||
res.model = mod
|
||||
session.add(res)
|
||||
|
||||
session.commit()
|
||||
|
||||
qprj = session.query(orm.Project).filter_by(id=1).one()
|
||||
self.assertEqual(qprj.name, prj.name)
|
||||
self.assertEqual(len(qprj.jobs), 1)
|
||||
job_names = [k for k in qprj.jobs.keys()]
|
||||
self.assertEqual(job_names[0], job.name)
|
||||
self.assertEqual(qprj.jobs[job.name], job)
|
||||
|
||||
qjob = session.query(orm.Job).filter_by(id=1).one()
|
||||
self.assertEqual(qjob.name, job.name)
|
||||
self.assertEqual(qjob.project.name, prj.name)
|
||||
self.assertEqual(len(qjob.job_tags), 2)
|
||||
self.assertEqual(qjob.job_tags['phase'].value, 'phagen')
|
||||
self.assertEqual(qjob.job_tags['scatter'].value, 'edac')
|
||||
self.assertEqual(len(qjob.tags), 2)
|
||||
self.assertEqual(qjob.tags['phase'], 'phagen')
|
||||
self.assertEqual(qjob.tags['scatter'], 'edac')
|
||||
|
||||
qmod = session.query(orm.Model).filter_by(id=1).one()
|
||||
self.assertEqual(qmod.model, mod.model)
|
||||
self.assertEqual(len(qmod.param_values), 2)
|
||||
self.assertEqual(qmod.values['dAB'], 123.456)
|
||||
self.assertEqual(qmod.deltas['dAB'], 7.543)
|
||||
self.assertEqual(qmod.values['dBC'], 234.567)
|
||||
|
||||
self.assertEqual(len(qmod.results), 1)
|
||||
self.assertEqual(qmod.results[0].rfac, 0.123)
|
||||
|
||||
def test_job_tags(self):
|
||||
with self.db.session() as session:
|
||||
prj = orm.Project(name="project 1", code=__file__)
|
||||
session.add(prj)
|
||||
|
||||
job1 = orm.Job(name="job 1")
|
||||
job1.project = prj
|
||||
session.add(job1)
|
||||
job2 = orm.Job(name="job 2")
|
||||
job2.project = prj
|
||||
session.add(job2)
|
||||
|
||||
job1.tags['color'] = 'blue'
|
||||
job1.tags['shape'] = 'round'
|
||||
session.flush()
|
||||
job2.tags['color'] = 'red'
|
||||
job1.tags['color'] = 'green'
|
||||
|
||||
session.commit()
|
||||
|
||||
qjob1 = session.query(orm.Job).filter_by(name='job 1').one()
|
||||
self.assertEqual(qjob1.tags['color'], 'green')
|
||||
qjob2 = session.query(orm.Job).filter_by(name='job 2').one()
|
||||
self.assertEqual(qjob2.tags['color'], 'red')
|
||||
|
||||
def test_job_jobtags(self):
|
||||
with self.db.session() as session:
|
||||
prj = orm.Project(name="project 1", code=__file__)
|
||||
session.add(prj)
|
||||
|
||||
job1 = orm.Job(name="job 1")
|
||||
job1.project = prj
|
||||
session.add(job1)
|
||||
job2 = orm.Job(name="job 2")
|
||||
job2.project = prj
|
||||
session.add(job2)
|
||||
|
||||
jt1 = orm.JobTag('color', 'blue')
|
||||
job1.job_tags[jt1.tag_key] = jt1
|
||||
session.flush()
|
||||
jt2 = orm.JobTag('color', 'red')
|
||||
job2.job_tags[jt2.tag_key] = jt2
|
||||
|
||||
session.commit()
|
||||
|
||||
qjob1 = session.query(orm.Job).filter_by(name='job 1').one()
|
||||
self.assertIsInstance(qjob1.job_tags['color'], orm.JobTag)
|
||||
self.assertEqual(qjob1.job_tags['color'].value, 'blue')
|
||||
qjob2 = session.query(orm.Job).filter_by(name='job 2').one()
|
||||
self.assertIsInstance(qjob2.job_tags['color'], orm.JobTag)
|
||||
self.assertEqual(qjob2.job_tags['color'].value, 'red')
|
||||
|
||||
def test_param_values(self):
|
||||
with self.db.session() as session:
|
||||
prj = orm.Project(name="project 1", code=__file__)
|
||||
session.add(prj)
|
||||
job = orm.Job(name="job 1")
|
||||
job.project = prj
|
||||
session.add(job)
|
||||
|
||||
mod1 = orm.Model(model=1, gen=11, particle=111)
|
||||
session.add(mod1)
|
||||
mod2 = orm.Model(model=2, gen=22, particle=222)
|
||||
session.add(mod2)
|
||||
|
||||
mod1.values['dBC'] = 234.567
|
||||
# note: this flush is necessary before accessing the same param in another model
|
||||
session.flush()
|
||||
mod2.values['dBC'] = 345.678
|
||||
|
||||
session.commit()
|
||||
|
||||
qmod1 = session.query(orm.Model).filter_by(model=1).one()
|
||||
self.assertEqual(qmod1.values['dBC'], 234.567)
|
||||
qmod2 = session.query(orm.Model).filter_by(model=2).one()
|
||||
self.assertEqual(qmod2.values['dBC'], 345.678)
|
||||
|
||||
def test_filter_job(self):
|
||||
"""
|
||||
test sqlalchemy filter syntax
|
||||
|
||||
@return: None
|
||||
"""
|
||||
with self.db.session() as session:
|
||||
p1 = orm.Project(name="p1")
|
||||
p2 = orm.Project(name="p2")
|
||||
j11 = orm.Job(name="j1")
|
||||
j11.project = p1
|
||||
j12 = orm.Job(name="j2")
|
||||
j12.project = p1
|
||||
j21 = orm.Job(name="j1")
|
||||
j21.project = p2
|
||||
j22 = orm.Job(name="j2")
|
||||
j22.project = p2
|
||||
session.add_all([p1, p2, j11, j12, j21, j22])
|
||||
session.commit()
|
||||
|
||||
q1 = session.query(orm.Job).join(orm.Project)
|
||||
q1 = q1.filter(orm.Project.name == 'p1')
|
||||
q1 = q1.filter(orm.Job.name == 'j1')
|
||||
jobs1 = q1.all()
|
||||
|
||||
sql = """
|
||||
select Projects.name project_name, Jobs.name job_name
|
||||
from Projects join Jobs on Projects.id = Jobs.project_id
|
||||
where Jobs.name = 'j1' and Projects.name = 'p1'
|
||||
"""
|
||||
jobs2 = session.execute(sql)
|
||||
|
||||
n = 0
|
||||
for j in jobs2:
|
||||
self.assertEqual(j.project_name, 'p1')
|
||||
self.assertEqual(j.job_name, 'j1')
|
||||
n += 1
|
||||
self.assertEqual(n, 1)
|
||||
|
||||
for j in jobs1:
|
||||
self.assertEqual(j.project.name, 'p1')
|
||||
self.assertEqual(j.name, 'j1')
|
||||
self.assertEqual(len(jobs1), 1)
|
||||
|
||||
def test_filter_in(self):
|
||||
"""
|
||||
test sqlalchemy filter syntax: in_ operator
|
||||
|
||||
@return: None
|
||||
"""
|
||||
with self.db.session() as session:
|
||||
p1 = orm.Project(name="p1")
|
||||
p2 = orm.Project(name="p2")
|
||||
j11 = orm.Job(name="j1")
|
||||
j11.project = p1
|
||||
j12 = orm.Job(name="j2")
|
||||
j12.project = p1
|
||||
j21 = orm.Job(name="j1")
|
||||
j21.project = p2
|
||||
j22 = orm.Job(name="j2")
|
||||
j22.project = p2
|
||||
session.add_all([p1, p2, j11, j12, j21, j22])
|
||||
session.commit()
|
||||
|
||||
q1 = session.query(orm.Job)
|
||||
q1 = q1.filter(orm.Job.id.in_([2, 3, 7]))
|
||||
jobs1 = q1.all()
|
||||
self.assertEqual(len(jobs1), 2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
311
tests/database/test_query.py
Normal file
311
tests/database/test_query.py
Normal file
@@ -0,0 +1,311 @@
|
||||
"""
|
||||
@package tests.database.test_query
|
||||
unit tests for pmsco.database
|
||||
|
||||
the purpose of these tests is to help debugging the code.
|
||||
|
||||
to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
|
||||
|
||||
@pre nose must be installed (python-nose package on Debian).
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2016 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import pmsco.database.access as db
|
||||
import pmsco.database.common as db_common
|
||||
import pmsco.database.ingest as db_ingest
|
||||
import pmsco.database.orm as db_orm
|
||||
import pmsco.database.query as db_query
|
||||
import pmsco.database.util as db_util
|
||||
from tests.database.test_common import setup_sample_database
|
||||
|
||||
|
||||
def pop_query_hook(query, gen):
|
||||
return query.filter(db_orm.Model.gen == gen)
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.db = db.DatabaseAccess()
|
||||
self.db.connect(":memory:")
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
# before any methods in this class
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
# teardown_class() after any methods in this class
|
||||
pass
|
||||
|
||||
def test_query_model_results_array(self):
|
||||
with self.db.session() as session:
|
||||
objs = setup_sample_database(session)
|
||||
job = objs['j1']
|
||||
|
||||
index = {'_scan': -1, '_domain': -1, '_emit': -1, '_region': -1}
|
||||
model2 = {'parA': 4.123, 'parB': 8.567, '_model': 92, '_rfac': 0.654, '_gen': 1, '_particle': 1, '_secs': 0.1}
|
||||
model3 = {'parA': 3.412, 'parB': 7.856, '_model': 93, '_rfac': 0.345, '_gen': 2, '_particle': 2, '_secs': 0.2}
|
||||
model4 = {'parA': 4.123, 'parB': 8.567, '_model': 94, '_rfac': 0.354, '_gen': 2, '_particle': 3, '_secs': 0.3}
|
||||
model5 = {'parA': 2.341, 'parC': 6.785, '_model': 95, '_rfac': 0.453}
|
||||
model6 = {'parA': 4.123, 'parB': 8.567, '_model': 96, '_rfac': 0.354, '_gen': 3, '_particle': 5, '_secs': 0.5}
|
||||
model2.update(index)
|
||||
model3.update(index)
|
||||
model4.update(index)
|
||||
model5.update(index)
|
||||
model6.update(index)
|
||||
m2, r2 = db_ingest.insert_result(session, job, model2, model2, model2)
|
||||
m3, r3 = db_ingest.insert_result(session, job, model3, model3, model3)
|
||||
m4, r4 = db_ingest.insert_result(session, job, model4, model4, model4)
|
||||
m5, r5 = db_ingest.insert_result(session, job, model5, model5, model5)
|
||||
m6, r6 = db_ingest.insert_result(session, job, model6, model6, model6)
|
||||
session.commit()
|
||||
|
||||
models = [m3, m4, m5]
|
||||
result_values, result_deltas = db_query.query_model_results_array(session, models=models, include_params=True)
|
||||
|
||||
template = ['parA', 'parB', 'parC', '_model', '_rfac', '_gen', '_particle', '_secs']
|
||||
dt = [(field, db_util.field_to_numpy_type(field)) for field in template]
|
||||
expected = np.zeros((len(models),), dtype=dt)
|
||||
expected['parA'] = np.array([3.412, 4.123, 2.341])
|
||||
expected['parB'] = np.array([7.856, 8.567, None])
|
||||
expected['parC'] = np.array([None, None, 6.785])
|
||||
expected['_model'] = np.array([93, 94, 95])
|
||||
expected['_rfac'] = np.array([0.345, 0.354, 0.453])
|
||||
expected['_gen'] = np.array([2, 2, 0])
|
||||
expected['_particle'] = np.array([2, 3, 0])
|
||||
expected['_secs'] = np.array([0.2, 0.3, None])
|
||||
|
||||
self.assertEqual(result_values.shape, expected.shape)
|
||||
np.testing.assert_array_almost_equal(result_values['parA'], expected['parA'])
|
||||
np.testing.assert_array_almost_equal(result_values['parB'], expected['parB'])
|
||||
np.testing.assert_array_almost_equal(result_values['parC'], expected['parC'])
|
||||
np.testing.assert_array_almost_equal(result_values['_model'], expected['_model'])
|
||||
np.testing.assert_array_almost_equal(result_values['_gen'], expected['_gen'])
|
||||
np.testing.assert_array_almost_equal(result_values['_particle'], expected['_particle'])
|
||||
np.testing.assert_array_almost_equal(result_values['_rfac'], expected['_rfac'])
|
||||
np.testing.assert_array_almost_equal(result_values['_secs'], expected['_secs'])
|
||||
|
||||
self.assertEqual(result_deltas.shape, expected.shape)
|
||||
np.testing.assert_array_almost_equal(result_deltas['parA'], expected['parA'])
|
||||
np.testing.assert_array_almost_equal(result_deltas['parB'], expected['parB'])
|
||||
np.testing.assert_array_almost_equal(result_deltas['parC'], expected['parC'])
|
||||
np.testing.assert_array_almost_equal(result_deltas['_model'], expected['_model'])
|
||||
np.testing.assert_array_almost_equal(result_deltas['_gen'], expected['_gen'])
|
||||
np.testing.assert_array_almost_equal(result_deltas['_particle'], expected['_particle'])
|
||||
|
||||
def test_query_model_results_array_index(self):
|
||||
with self.db.session() as session:
|
||||
objs = setup_sample_database(session)
|
||||
job = objs['j1']
|
||||
|
||||
model = {'parA': 4.123, 'parB': 8.567, 'parC': 6.785}
|
||||
|
||||
index1 = {'_model': 99, '_scan': -1, '_domain': -1, '_emit': -1, '_region': -1}
|
||||
index2 = {'_model': 99, '_scan': 1, '_domain': -1, '_emit': -1, '_region': -1}
|
||||
index3 = {'_model': 99, '_scan': 1, '_domain': 1, '_emit': -1, '_region': -1}
|
||||
index4 = {'_model': 99, '_scan': 1, '_domain': 1, '_emit': 1, '_region': -1}
|
||||
index5 = {'_model': 99, '_scan': 1, '_domain': 1, '_emit': 1, '_region': 1}
|
||||
|
||||
result1 = {'_rfac': 0.154, '_gen': 1, '_particle': 1}
|
||||
result1.update(model)
|
||||
result2 = {'_rfac': 0.254, '_gen': 1, '_particle': 1}
|
||||
result2.update(model)
|
||||
result3 = {'_rfac': 0.354, '_gen': 1, '_particle': 1}
|
||||
result3.update(model)
|
||||
result4 = {'_rfac': 0.454, '_gen': 1, '_particle': 1}
|
||||
result4.update(model)
|
||||
result5 = {'_rfac': 0.554, '_gen': 1, '_particle': 1}
|
||||
result5.update(model)
|
||||
|
||||
m1, r1 = db_ingest.insert_result(session, job, index1, result1, result1)
|
||||
m2, r2 = db_ingest.insert_result(session, job, index2, result2, result2)
|
||||
m3, r3 = db_ingest.insert_result(session, job, index3, result3, result3)
|
||||
m4, r4 = db_ingest.insert_result(session, job, index4, result4, result4)
|
||||
m5, r5 = db_ingest.insert_result(session, job, index5, result5, result5)
|
||||
session.commit()
|
||||
|
||||
self.assertEqual(m1.id, m2.id)
|
||||
self.assertEqual(m1.id, m3.id)
|
||||
self.assertEqual(m1.id, m4.id)
|
||||
self.assertEqual(m1.id, m5.id)
|
||||
|
||||
result_values, result_deltas = db_query.query_model_results_array(session,
|
||||
model=99, domain=1, include_params=True)
|
||||
|
||||
pars = ['parA', 'parB', 'parC']
|
||||
dt = [(k, 'f8') for k in pars]
|
||||
controls = ['_model', '_scan', '_domain', '_emit', '_region', '_rfac']
|
||||
dt.extend(((k, db_util.field_to_numpy_type(k)) for k in controls))
|
||||
expected = np.zeros((3,), dtype=dt)
|
||||
expected['parA'] = np.array([4.123, 4.123, 4.123])
|
||||
expected['parB'] = np.array([8.567, 8.567, 8.567])
|
||||
expected['parC'] = np.array([6.785, 6.785, 6.785])
|
||||
expected['_model'] = np.array([99, 99, 99])
|
||||
expected['_scan'] = np.array([1, 1, 1])
|
||||
expected['_domain'] = np.array([1, 1, 1])
|
||||
expected['_emit'] = np.array([-1, 1, 1])
|
||||
expected['_region'] = np.array([-1, -1, 1])
|
||||
expected['_rfac'] = np.array([0.354, 0.454, 0.554])
|
||||
|
||||
self.assertEqual(result_values.shape, expected.shape)
|
||||
np.testing.assert_array_almost_equal(result_values['parA'], expected['parA'])
|
||||
np.testing.assert_array_almost_equal(result_values['parB'], expected['parB'])
|
||||
np.testing.assert_array_almost_equal(result_values['parC'], expected['parC'])
|
||||
np.testing.assert_array_almost_equal(result_values['_model'], expected['_model'])
|
||||
np.testing.assert_array_almost_equal(result_values['_scan'], expected['_scan'])
|
||||
np.testing.assert_array_almost_equal(result_values['_domain'], expected['_domain'])
|
||||
np.testing.assert_array_almost_equal(result_values['_emit'], expected['_emit'])
|
||||
np.testing.assert_array_almost_equal(result_values['_region'], expected['_region'])
|
||||
np.testing.assert_array_almost_equal(result_values['_rfac'], expected['_rfac'])
|
||||
|
||||
def test_query_model_results_hook(self):
|
||||
with self.db.session() as session:
|
||||
objs = setup_sample_database(session)
|
||||
job = objs['j1']
|
||||
|
||||
index = {'_scan': -1, '_domain': -1, '_emit': -1, '_region': -1}
|
||||
model2 = {'parA': 4.123, 'parB': 8.567, '_model': 92, '_rfac': 0.654, '_gen': 1, '_particle': 1}
|
||||
model3 = {'parA': 3.412, 'parB': 7.856, '_model': 93, '_rfac': 0.345, '_gen': 2, '_particle': 2}
|
||||
model4 = {'parA': 4.123, 'parB': 8.567, '_model': 94, '_rfac': 0.354, '_gen': 2, '_particle': 3}
|
||||
model5 = {'parA': 2.341, 'parC': 6.785, '_model': 95, '_rfac': 0.453}
|
||||
model6 = {'parA': 4.123, 'parB': 8.567, '_model': 96, '_rfac': 0.354, '_gen': 3, '_particle': 5}
|
||||
model2.update(index)
|
||||
model3.update(index)
|
||||
model4.update(index)
|
||||
model5.update(index)
|
||||
model6.update(index)
|
||||
m2, r2 = db_ingest.insert_result(session, job, model2, model2, model2)
|
||||
m3, r3 = db_ingest.insert_result(session, job, model3, model3, model3)
|
||||
m4, r4 = db_ingest.insert_result(session, job, model4, model4, model4)
|
||||
m5, r5 = db_ingest.insert_result(session, job, model5, model5, model5)
|
||||
m6, r6 = db_ingest.insert_result(session, job, model6, model6, model6)
|
||||
session.commit()
|
||||
|
||||
models = [m3, m4]
|
||||
hd = {'gen': 2}
|
||||
result_values, result_deltas = db_query.query_model_results_array(session, include_params=True,
|
||||
query_hook=pop_query_hook, hook_data=hd)
|
||||
|
||||
template = ['parA', 'parB', 'parC', '_model', '_rfac', '_gen', '_particle']
|
||||
dt = [(field, db_util.field_to_numpy_type(field)) for field in template]
|
||||
|
||||
expected = np.zeros((len(models),), dtype=dt)
|
||||
expected['parA'] = np.array([3.412, 4.123])
|
||||
expected['parB'] = np.array([7.856, 8.567])
|
||||
expected['_model'] = np.array([93, 94])
|
||||
expected['_rfac'] = np.array([0.345, 0.354])
|
||||
expected['_gen'] = np.array([2, 2])
|
||||
expected['_particle'] = np.array([2, 3])
|
||||
|
||||
self.assertEqual(result_values.shape, expected.shape)
|
||||
self.assertNotIn('parC', result_values.dtype.names)
|
||||
np.testing.assert_array_almost_equal(result_values['parA'], expected['parA'])
|
||||
np.testing.assert_array_almost_equal(result_values['parB'], expected['parB'])
|
||||
np.testing.assert_array_almost_equal(result_values['_model'], expected['_model'])
|
||||
np.testing.assert_array_almost_equal(result_values['_gen'], expected['_gen'])
|
||||
np.testing.assert_array_almost_equal(result_values['_particle'], expected['_particle'])
|
||||
|
||||
def test_query_best_task_models(self):
|
||||
with self.db.session() as session:
|
||||
objs = setup_sample_database(session)
|
||||
job = objs['j1']
|
||||
|
||||
model0xxx = {'_model': 0, '_scan': -1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567,
|
||||
'_rfac': 0.01}
|
||||
model00xx = {'_model': 1, '_scan': 0, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567,
|
||||
'_rfac': 0.02}
|
||||
model000x = {'_model': 2, '_scan': 0, '_domain': 0, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567,
|
||||
'_rfac': 0.03}
|
||||
model01xx = {'_model': 3, '_scan': 1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567,
|
||||
'_rfac': 0.04}
|
||||
model010x = {'_model': 4, '_scan': 1, '_domain': 0, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567,
|
||||
'_rfac': 0.05}
|
||||
|
||||
model1xxx = {'_model': 5, '_scan': -1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4.123,
|
||||
'parB': 8.567, '_rfac': 0.09}
|
||||
model10xx = {'_model': 6, '_scan': 0, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4.123,
|
||||
'parB': 8.567, '_rfac': 0.08}
|
||||
model100x = {'_model': 7, '_scan': 0, '_domain': 0, '_emit': -1, '_region': -1, 'parA': 4.123,
|
||||
'parB': 8.567, '_rfac': 0.07}
|
||||
model11xx = {'_model': 8, '_scan': 1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4.123,
|
||||
'parB': 8.567, '_rfac': 0.06}
|
||||
model110x = {'_model': 9, '_scan': 1, '_domain': 0, '_emit': -1, '_region': -1, 'parA': 4.123,
|
||||
'parB': 8.567, '_rfac': 0.05}
|
||||
|
||||
model2xxx = {'_model': 10, '_scan': -1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4.123,
|
||||
'parB': 8.567, '_rfac': 0.01}
|
||||
|
||||
db_ingest.insert_result(session, job, model0xxx, model0xxx)
|
||||
db_ingest.insert_result(session, job, model00xx, model00xx)
|
||||
db_ingest.insert_result(session, job, model000x, model000x)
|
||||
db_ingest.insert_result(session, job, model01xx, model01xx)
|
||||
db_ingest.insert_result(session, job, model010x, model010x)
|
||||
|
||||
db_ingest.insert_result(session, job, model1xxx, model1xxx)
|
||||
db_ingest.insert_result(session, job, model10xx, model10xx)
|
||||
db_ingest.insert_result(session, job, model100x, model100x)
|
||||
db_ingest.insert_result(session, job, model11xx, model11xx)
|
||||
db_ingest.insert_result(session, job, model110x, model110x)
|
||||
|
||||
db_ingest.insert_result(session, job, model2xxx, model2xxx)
|
||||
|
||||
result = db_query.query_best_task_models(session, job.id, level=1, count=2)
|
||||
|
||||
expected = {0, 1, 3, 6, 8, 10}
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_query_best_models_per_job(self):
|
||||
with self.db.session() as session:
|
||||
objs = setup_sample_database(session)
|
||||
job = objs['j2']
|
||||
|
||||
model2 = {'parA': 4.123, 'parB': 8.567, '_model': 92, '_rfac': 0.654, '_gen': 1, '_particle': 2}
|
||||
model3 = {'parA': 3.412, 'parB': 7.856, '_model': 93, '_rfac': 0.345, '_gen': 1, '_particle': 3}
|
||||
model4 = {'parA': 4.123, 'parB': 8.567, '_model': 94, '_rfac': 0.354, '_gen': 1, '_particle': 4}
|
||||
model5 = {'parA': 2.341, 'parC': 6.785, '_model': 95, '_rfac': 0.453, '_gen': 1, '_particle': 5}
|
||||
model6 = {'parA': 4.123, 'parB': 8.567, '_model': 96, '_rfac': 0.354, '_gen': 1, '_particle': 6}
|
||||
model7 = {'parA': 5.123, 'parB': 6.567, '_model': 97, '_rfac': 0.154, '_gen': 1, '_particle': 7}
|
||||
|
||||
model2.update({'_scan': -1, '_domain': -1, '_emit': -1, '_region': -1})
|
||||
model3.update({'_scan': 1, '_domain': -1, '_emit': -1, '_region': -1})
|
||||
model4.update({'_scan': 2, '_domain': 11, '_emit': 23, '_region': 33})
|
||||
model5.update({'_scan': 3, '_domain': 11, '_emit': -1, '_region': -1})
|
||||
model6.update({'_scan': 4, '_domain': 11, '_emit': 25, '_region': -1})
|
||||
model7.update({'_scan': 5, '_domain': -1, '_emit': -1, '_region': -1})
|
||||
m2, r2 = db_ingest.insert_result(session, job, model2, model2)
|
||||
m3, r3 = db_ingest.insert_result(session, job, model3, model3)
|
||||
m4, r4 = db_ingest.insert_result(session, job, model4, model4)
|
||||
m5, r5 = db_ingest.insert_result(session, job, model5, model5)
|
||||
m6, r6 = db_ingest.insert_result(session, job, model6, model6)
|
||||
m7, r7 = db_ingest.insert_result(session, job, model7, model7)
|
||||
|
||||
lim = 3
|
||||
query = db_query.query_best_models_per_job(session, task_level='domain', limit=lim)
|
||||
expected_models = [91, 97]
|
||||
self.assertEqual(len(query), len(expected_models))
|
||||
for model, result in query:
|
||||
self.assertIn(model.model, expected_models)
|
||||
|
||||
lim = 3
|
||||
query = db_query.query_best_models_per_job(session, jobs=[job], task_level='domain', limit=lim)
|
||||
expected_models = [97]
|
||||
self.assertEqual(len(query), len(expected_models))
|
||||
for model, result in query:
|
||||
self.assertIn(model.model, expected_models)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
83
tests/database/test_util.py
Normal file
83
tests/database/test_util.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
@package tests.test_database
|
||||
unit tests for pmsco.database
|
||||
|
||||
the purpose of these tests is to help debugging the code.
|
||||
|
||||
to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
|
||||
|
||||
@pre nose must be installed (python-nose package on Debian).
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2016 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
import pmsco.database.util as util
|
||||
import pmsco.dispatch as dispatch
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
# before any methods in this class
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
# teardown_class() after any methods in this class
|
||||
pass
|
||||
|
||||
def test_regular_params(self):
|
||||
d1 = {'parA': 1.234, 'par_B': 5.678, '_model': 91, '_rfac': 0.534}
|
||||
d2 = util.regular_params(d1)
|
||||
d3 = {'parA': d1['parA'], 'par_B': d1['par_B']}
|
||||
self.assertEqual(d2, d3)
|
||||
self.assertIsNot(d2, d1)
|
||||
|
||||
def test_special_params(self):
|
||||
d1 = {'parA': 1.234, 'par_B': 5.678, '_model': 91, '_rfac': 0.534, '_db_model_id': 99}
|
||||
d2 = util.special_params(d1)
|
||||
d3 = {'model': d1['_model'], 'rfac': d1['_rfac']}
|
||||
self.assertEqual(d2, d3)
|
||||
self.assertIsNot(d2, d1)
|
||||
|
||||
dt = [('parA', 'f4'), ('par_B', 'f4'), ('_model', 'i4'), ('_rfac', 'f4'), ('_db_model_id', 'f4')]
|
||||
arr = np.zeros(1, dtype=dt)
|
||||
for k, v in d1.items():
|
||||
arr[0][k] = v
|
||||
d4 = util.special_params(arr[0])
|
||||
self.assertEqual(d4.keys(), d3.keys())
|
||||
for k in d4:
|
||||
self.assertAlmostEqual(d4[k], d3[k])
|
||||
|
||||
cid1 = dispatch.CalcID(1, 2, 3, 4, -1)
|
||||
cid2 = util.special_params(cid1)
|
||||
cid3 = {'model': 1, 'scan': 2, 'domain': 3, 'emit': 4, 'region': -1}
|
||||
self.assertEqual(cid2, cid3)
|
||||
|
||||
l1 = d1.keys()
|
||||
l2 = util.special_params(l1)
|
||||
l3 = d3.keys()
|
||||
self.assertEqual(list(l2), list(l3))
|
||||
|
||||
t1 = tuple(l1)
|
||||
t2 = util.special_params(t1)
|
||||
t3 = tuple(l3)
|
||||
self.assertEqual(t2, t3)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
71
tests/reports/test_results.py
Normal file
71
tests/reports/test_results.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import datetime
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import unittest
|
||||
import pmsco.database.access as db_access
|
||||
import pmsco.database.orm as db_orm
|
||||
import pmsco.reports.results as rp_results
|
||||
import pmsco.dispatch as dispatch
|
||||
|
||||
|
||||
def setup_sample_database(session):
|
||||
p1 = db_orm.Project(name="oldproject", code="oldcode")
|
||||
p2 = db_orm.Project(name="unittest", code="testcode")
|
||||
j1 = db_orm.Job(project=p1, name="oldjob", mode="oldmode", machine="oldhost", datetime=datetime.datetime.now())
|
||||
j2 = db_orm.Job(project=p2, name="testjob", mode="testmode", machine="testhost", datetime=datetime.datetime.now())
|
||||
pk1 = db_orm.Param(key='parA')
|
||||
pk2 = db_orm.Param(key='parB')
|
||||
pk3 = db_orm.Param(key='parC')
|
||||
m1 = db_orm.Model(job=j1, model=91)
|
||||
m2 = db_orm.Model(job=j2, model=92)
|
||||
r1 = db_orm.Result(calc_id=dispatch.CalcID(91, -1, -1, -1, -1), rfac=0.534, secs=37.9)
|
||||
r1.model = m1
|
||||
pv1 = db_orm.ParamValue(model=m1, param=pk1, value=1.234, delta=0.1234)
|
||||
pv2 = db_orm.ParamValue(model=m1, param=pk2, value=5.678, delta=-0.5678)
|
||||
pv3 = db_orm.ParamValue(model=m2, param=pk3, value=6.785, delta=0.6785)
|
||||
objects = {'p1': p1, 'p2': p2, 'j1': j1, 'j2': j2, 'm1': m1, 'm2': m2, 'r1': r1,
|
||||
'pv1': pv1, 'pv2': pv2, 'pv3': pv3, 'pk1': pk1, 'pk2': pk2, 'pk3': pk3}
|
||||
session.add_all(objects.values())
|
||||
session.commit()
|
||||
return objects
|
||||
|
||||
|
||||
class TestResultsMethods(unittest.TestCase):
|
||||
def test_array_range(self):
|
||||
dtype = [('A', 'f8'), ('B', 'f8'), ('C', 'f8')]
|
||||
data = np.array([(1.5, 3.5, 3.5),
|
||||
(1.6, 2.6, 2.6),
|
||||
(1.7, 2.7, 3.7),
|
||||
(1.8, 2.8, 3.8)], dtype=dtype)
|
||||
exp_rmin = {'A': 1.5, 'B': 2.6, 'C': 2.6}
|
||||
exp_rmax = {'A': 1.8, 'B': 3.5, 'C': 3.8}
|
||||
rmin, rmax = rp_results.array_range(data)
|
||||
self.assertEqual(exp_rmin, rmin)
|
||||
self.assertEqual(exp_rmax, rmax)
|
||||
|
||||
|
||||
class TestResultData(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.db = db_access.DatabaseAccess()
|
||||
self.db.connect(":memory:")
|
||||
|
||||
def test_update_collections(self):
|
||||
data_dir = Path(__file__).parent.parent
|
||||
data_file = data_dir / "test_swarm.setup_with_results.1.dat"
|
||||
raw_values = np.atleast_1d(np.genfromtxt(data_file, names=True))
|
||||
|
||||
rd = rp_results.ResultData()
|
||||
rd.values = raw_values
|
||||
rd.update_collections()
|
||||
np.testing.assert_array_equal(rd.generations, np.array((1, 2, 3, 4)))
|
||||
|
||||
def test_load_from_db(self):
|
||||
with self.db.session() as session:
|
||||
setup_sample_database(session)
|
||||
rd = rp_results.ResultData()
|
||||
rd.levels = {'scan': -1}
|
||||
rd.load_from_db(session)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
69
tests/reports/test_rfactor.py
Normal file
69
tests/reports/test_rfactor.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import numpy as np
|
||||
import unittest
|
||||
import pmsco.reports.rfactor as rp_rfactor
|
||||
|
||||
|
||||
class TestGridMethods(unittest.TestCase):
|
||||
def test_triplet_to_grid__basic(self):
|
||||
x = np.array([-1, 0, 1, 1, 0, -1])
|
||||
y = np.array([2, 2, 2, 3, 3, 3])
|
||||
z = np.array([0.1, 0.2, 0.3, 0.6, 0.5, 0.4])
|
||||
gx, gy, gz = rp_rfactor.triplet_to_grid(x, y, z)
|
||||
expected_gx = np.array([[-1, 0, 1], [-1, 0, 1]]).T
|
||||
expected_gy = np.array([[2, 2, 2], [3, 3, 3]]).T
|
||||
expected_gz = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]).T
|
||||
np.testing.assert_array_almost_equal(gx, expected_gx, 1, "grid_x")
|
||||
np.testing.assert_array_almost_equal(gy, expected_gy, 1, "grid_y")
|
||||
np.testing.assert_array_almost_equal(gz, expected_gz, 2, "grid_z")
|
||||
|
||||
def test_triplet_to_grid__imprecise(self):
|
||||
x = np.array([-0.99, 0, 1, 1.001, 0, -1])
|
||||
y = np.array([1.999, 2.00001, 2, 3.01, 2.98, 3])
|
||||
z = np.array([0.1, 0.2, 0.3, 0.6, 0.5, 0.4])
|
||||
gx, gy, gz = rp_rfactor.triplet_to_grid(x, y, z)
|
||||
expected_gx = np.array([[-1, 0, 1], [-1, 0, 1]]).T
|
||||
expected_gy = np.array([[2, 2, 2], [3, 3, 3]]).T
|
||||
expected_gz = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]).T
|
||||
np.testing.assert_array_almost_equal(gx, expected_gx, 1, "grid_x")
|
||||
np.testing.assert_array_almost_equal(gy, expected_gy, 1, "grid_y")
|
||||
np.testing.assert_array_almost_equal(gz, expected_gz, 2, "grid_z")
|
||||
|
||||
def test_triplet_to_grid__missing(self):
|
||||
x = np.array([-1, 0, 1, 0, -1])
|
||||
y = np.array([2, 2, 3, 3, 3])
|
||||
z = np.array([0.1, 0.2, 0.6, 0.5, 0.4])
|
||||
gx, gy, gz = rp_rfactor.triplet_to_grid(x, y, z)
|
||||
expected_gx = np.array([[-1, 0, 1], [-1, 0, 1]]).T
|
||||
expected_gy = np.array([[2, 2, 2], [3, 3, 3]]).T
|
||||
expected_gz = np.array([[0.1, 0.2, 0.2], [0.4, 0.5, 0.6]]).T
|
||||
np.testing.assert_array_almost_equal(gx, expected_gx, 1, "grid_x")
|
||||
np.testing.assert_array_almost_equal(gy, expected_gy, 1, "grid_y")
|
||||
np.testing.assert_array_almost_equal(gz, expected_gz, 2, "grid_z")
|
||||
|
||||
def test_triplet_to_grid__extra(self):
|
||||
x = np.array([-1, 0, 1, 1, 1, 0, -1])
|
||||
y = np.array([2, 2, 2, 2.01, 3, 3, 3])
|
||||
z = np.array([0.1, 0.2, 0.3, 0.35, 0.6, 0.5, 0.4])
|
||||
gx, gy, gz = rp_rfactor.triplet_to_grid(x, y, z)
|
||||
expected_gx = np.array([[-1, 0, 1], [-1, 0, 1]]).T
|
||||
expected_gy = np.array([[2, 2, 2], [3, 3, 3]]).T
|
||||
expected_gz = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]).T
|
||||
np.testing.assert_array_almost_equal(gx, expected_gx, 1, "grid_x")
|
||||
np.testing.assert_array_almost_equal(gy, expected_gy, 1, "grid_y")
|
||||
np.testing.assert_array_almost_equal(gz, expected_gz, 2, "grid_z")
|
||||
|
||||
def test_triplet_to_grid__split_column(self):
|
||||
x = np.array([-1, 0, 0.5, 1, 1, 0.5, 0, -1])
|
||||
y = np.array([2, 2, 2, 2, 3, 3, 3, 3])
|
||||
z = np.array([0.1, 0.2, 0.24, 0.3, 0.6, 0.45, 0.5, 0.4])
|
||||
gx, gy, gz = rp_rfactor.triplet_to_grid(x, y, z)
|
||||
expected_gx = np.array([[-1, -0.5, 0, 0.5, 1], [-1, -0.5, 0, 0.5, 1]]).T
|
||||
expected_gy = np.array([[2, 2, 2, 2, 2], [3, 3, 3, 3, 3]]).T
|
||||
expected_gz = np.array([[0.1, 0.1, 0.2, 0.24, 0.3], [0.4, 0.5, 0.5, 0.45, 0.6]]).T
|
||||
np.testing.assert_array_almost_equal(gx, expected_gx, 1, "grid_x")
|
||||
np.testing.assert_array_almost_equal(gy, expected_gy, 1, "grid_y")
|
||||
np.testing.assert_array_almost_equal(gz, expected_gz, 2, "grid_z")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -459,9 +459,9 @@ class TestClusterFunctions(unittest.TestCase):
|
||||
line = f.readline()
|
||||
self.assertEqual(line, b"# index element symbol class x y z emitter charge\n", b"line 1: " + line)
|
||||
line = f.readline()
|
||||
self.assertRegexpMatches(line, b"[0-9]+ +1 +H +[0-9]+ +[0.]+ +[0.]+ +[0.]+ +1 +[0.]", b"line 3: " + line)
|
||||
self.assertRegex(line, b"[0-9]+ +1 +H +[0-9]+ +[0.]+ +[0.]+ +[0.]+ +1 +[0.]", b"line 3: " + line)
|
||||
line = f.readline()
|
||||
self.assertRegexpMatches(line, b"[0-9]+ +14 +Si +[0-9]+ +[01.-]+ +[01.-]+ +[0.]+ +1 +[0.]", b"line 4: " + line)
|
||||
self.assertRegex(line, b"[0-9]+ +14 +Si +[0-9]+ +[01.-]+ +[01.-]+ +[0.]+ +1 +[0.]", b"line 4: " + line)
|
||||
line = f.readline()
|
||||
self.assertEqual(b"", line, b"end of file")
|
||||
|
||||
@@ -473,9 +473,9 @@ class TestClusterFunctions(unittest.TestCase):
|
||||
line = f.readline()
|
||||
self.assertEqual(b"qwerty\n", line, b"line 2: " + line)
|
||||
line = f.readline()
|
||||
self.assertRegexpMatches(line, b"H +[0.]+ +[0.]+ +[0.]+", b"line 3: " + line)
|
||||
self.assertRegex(line, b"H +[0.]+ +[0.]+ +[0.]+", b"line 3: " + line)
|
||||
line = f.readline()
|
||||
self.assertRegexpMatches(line, b"Si +[01.-]+ +[01.-]+ +[0.]+", b"line 4: " + line)
|
||||
self.assertRegex(line, b"Si +[01.-]+ +[01.-]+ +[0.]+", b"line 4: " + line)
|
||||
line = f.readline()
|
||||
self.assertEqual(b"", line, b"end of file")
|
||||
|
||||
|
||||
@@ -20,12 +20,11 @@ Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import unittest
|
||||
import math
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
import unittest
|
||||
|
||||
import pmsco.data as md
|
||||
|
||||
|
||||
@@ -140,6 +139,50 @@ class TestDataFunctions(unittest.TestCase):
|
||||
for dim in expected_positions:
|
||||
np.testing.assert_almost_equal(scan_positions[dim], expected_positions[dim], decimal=3)
|
||||
|
||||
def test_holo_array(self):
|
||||
args = {
|
||||
"theta_start": 90,
|
||||
"theta_step": 1,
|
||||
"theta_range": 90,
|
||||
"phi_start": 0,
|
||||
"phi_range": 360,
|
||||
"phi_refinement": 1
|
||||
}
|
||||
|
||||
result = md.holo_array(generator=md.holo_grid, generator_args=args, datatype="ETPI")
|
||||
result['e'] = 250.
|
||||
md.sort_data(result)
|
||||
|
||||
ref_path = Path(__file__).parent.parent / "pmsco" / "projects" / "twoatom" / "twoatom_hemi_250e.etpi"
|
||||
ref_array = md.load_data(ref_path)
|
||||
np.testing.assert_array_almost_equal(result['t'], ref_array['t'], decimal=1)
|
||||
np.testing.assert_array_almost_equal(result['p'], ref_array['p'], decimal=1)
|
||||
|
||||
def test_analyse_holoscan_steps(self):
|
||||
args = {
|
||||
"theta_start": 90.0,
|
||||
"theta_step": 2.0,
|
||||
"theta_range": 90.0,
|
||||
"phi_start": 0.0,
|
||||
"phi_range": 120.0,
|
||||
"phi_refinement": 1.0
|
||||
}
|
||||
|
||||
holo = md.holo_array(generator=md.holo_grid, generator_args=args, datatype="TPI")
|
||||
theta, dtheta, dphi = md.analyse_holoscan_steps(holo)
|
||||
|
||||
expected_theta = np.arange(args["theta_start"] - args["theta_range"],
|
||||
args["theta_start"] + args["theta_step"],
|
||||
args["theta_step"])
|
||||
expected_dtheta = np.ones_like(expected_theta) * args["theta_step"]
|
||||
|
||||
np.testing.assert_almost_equal(theta, expected_theta)
|
||||
np.testing.assert_almost_equal(dtheta, expected_dtheta)
|
||||
self.assertEqual(expected_theta.shape, dphi.shape)
|
||||
self.assertEqual(args["phi_range"], dphi[0])
|
||||
self.assertEqual(args["theta_step"], dphi[-1])
|
||||
np.testing.assert_array_less(np.ones_like(expected_theta) * args["theta_step"] * 0.999, dphi)
|
||||
|
||||
def test_calc_modfunc_mean_1d(self):
|
||||
modf = md.calc_modfunc_mean(self.e_scan)
|
||||
|
||||
@@ -176,6 +219,11 @@ class TestDataFunctions(unittest.TestCase):
|
||||
"""
|
||||
check that the result of msc_data.calc_modfunc_loess() is between -1 and 1.
|
||||
"""
|
||||
|
||||
# loess package not available
|
||||
if md.loess is None:
|
||||
return
|
||||
|
||||
modf = md.calc_modfunc_loess(self.e_scan)
|
||||
self.assertEqual(self.e_scan.shape, modf.shape)
|
||||
exp_modf = self.e_scan.copy()
|
||||
@@ -190,6 +238,11 @@ class TestDataFunctions(unittest.TestCase):
|
||||
"""
|
||||
check that data.calc_modfunc_loess() ignores NaNs gracefully.
|
||||
"""
|
||||
|
||||
# loess package not available
|
||||
if md.loess is None:
|
||||
return
|
||||
|
||||
modified_index = 2
|
||||
self.e_scan['i'][modified_index] = np.nan
|
||||
modf = md.calc_modfunc_loess(self.e_scan)
|
||||
@@ -207,6 +260,11 @@ class TestDataFunctions(unittest.TestCase):
|
||||
"""
|
||||
check that the msc_data.calc_modfunc_loess() function does approximately what we want for a two-dimensional dataset.
|
||||
"""
|
||||
|
||||
# loess package not available
|
||||
if md.loess is None:
|
||||
return
|
||||
|
||||
n_e = 10
|
||||
n_a = 15
|
||||
shape = (n_e * n_a, )
|
||||
@@ -236,8 +294,7 @@ class TestDataFunctions(unittest.TestCase):
|
||||
# this is rough estimate of the result, manually optimized by trial and error in Igor.
|
||||
# the R factor should be sensitive enough to detect mixed-up axes.
|
||||
exp_modf['i'] = 0.03 * np.sin((scan['e'] - 150) / 50 * math.pi)
|
||||
rf = md.rfactor(modf, exp_modf)
|
||||
print(rf)
|
||||
rf = md.square_diff_rfactor(modf, exp_modf)
|
||||
self.assertLessEqual(rf, 0.50)
|
||||
|
||||
def test_alpha_mirror_average(self):
|
||||
@@ -377,11 +434,11 @@ class TestDataFunctions(unittest.TestCase):
|
||||
|
||||
weights = np.ones_like(exp_modf['i'])
|
||||
|
||||
r = md.scaled_rfactor(1.4, exp_modf['i'], weights, calc_modf['i'])
|
||||
r = md.scaled_rfactor_func(1.4, exp_modf['i'], weights, calc_modf['i'])
|
||||
self.assertGreater(r, 0.0)
|
||||
self.assertLess(r, 0.05)
|
||||
|
||||
def test_rfactor(self):
|
||||
def test_square_diff_rfactor(self):
|
||||
n = 20
|
||||
calc_modf = md.create_data((n,), dtype=md.DTYPE_ETPI)
|
||||
calc_modf['e'] = 0.0
|
||||
@@ -393,18 +450,18 @@ class TestDataFunctions(unittest.TestCase):
|
||||
exp_modf['i'] = 0.6 * np.sin(exp_modf['p'] + np.pi / 100.0)
|
||||
exp_modf['s'] = np.sqrt(np.abs(exp_modf['i']))
|
||||
|
||||
r1 = md.rfactor(exp_modf, calc_modf)
|
||||
r1 = md.square_diff_rfactor(exp_modf, calc_modf)
|
||||
self.assertAlmostEqual(r1, 0.95, delta=0.02)
|
||||
|
||||
# one nan should not make a big difference
|
||||
calc_modf['i'][3] = np.nan
|
||||
r2 = md.rfactor(exp_modf, calc_modf)
|
||||
r2 = md.square_diff_rfactor(exp_modf, calc_modf)
|
||||
self.assertAlmostEqual(r1, r2, delta=0.02)
|
||||
|
||||
# all values nan should raise an exception
|
||||
with self.assertRaises(ValueError):
|
||||
calc_modf['i'] = np.nan
|
||||
md.rfactor(exp_modf, calc_modf)
|
||||
md.square_diff_rfactor(exp_modf, calc_modf)
|
||||
|
||||
def test_optimize_rfactor(self):
|
||||
n = 20
|
||||
|
||||
@@ -1,570 +0,0 @@
|
||||
"""
|
||||
@package tests.test_database
|
||||
unit tests for pmsco.database
|
||||
|
||||
the purpose of these tests is to help debugging the code.
|
||||
|
||||
to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
|
||||
|
||||
@pre nose must be installed (python-nose package on Debian).
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2016 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import unittest
|
||||
import datetime
|
||||
import os.path
|
||||
import tempfile
|
||||
import shutil
|
||||
import numpy as np
|
||||
import pmsco.database as db
|
||||
import pmsco.dispatch as dispatch
|
||||
import pmsco.optimizers.population as population
|
||||
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
self.lock_filename = os.path.join(self.test_dir, "test_database.lock")
|
||||
self.db = db.ResultsDatabase()
|
||||
self.db.connect(":memory:", lock_filename=self.lock_filename)
|
||||
|
||||
def tearDown(self):
|
||||
self.db.disconnect()
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
# before any methods in this class
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
# teardown_class() after any methods in this class
|
||||
pass
|
||||
|
||||
def test_regular_params(self):
|
||||
d1 = {'parA': 1.234, 'par_B': 5.678, '_model': 91, '_rfac': 0.534}
|
||||
d2 = db.regular_params(d1)
|
||||
d3 = {'parA': d1['parA'], 'par_B': d1['par_B']}
|
||||
self.assertEqual(d2, d3)
|
||||
self.assertIsNot(d2, d1)
|
||||
|
||||
def test_special_params(self):
|
||||
d1 = {'parA': 1.234, 'par_B': 5.678, '_model': 91, '_rfac': 0.534, '_db_model': 99}
|
||||
d2 = db.special_params(d1)
|
||||
d3 = {'model': d1['_model'], 'rfac': d1['_rfac']}
|
||||
self.assertEqual(d2, d3)
|
||||
self.assertIsNot(d2, d1)
|
||||
|
||||
dt = [('parA', 'f4'), ('par_B', 'f4'), ('_model', 'i4'), ('_rfac', 'f4'), ('_db_model', 'f4')]
|
||||
arr = np.zeros(1, dtype=dt)
|
||||
for k, v in d1.items():
|
||||
arr[0][k] = v
|
||||
d4 = db.special_params(arr[0])
|
||||
self.assertEqual(d4.keys(), d3.keys())
|
||||
for k in d4:
|
||||
self.assertAlmostEqual(d4[k], d3[k])
|
||||
|
||||
cid1 = dispatch.CalcID(1, 2, 3, 4, -1)
|
||||
cid2 = db.special_params(cid1)
|
||||
cid3 = {'model': 1, 'scan': 2, 'domain': 3, 'emit': 4, 'region': -1}
|
||||
self.assertEqual(cid2, cid3)
|
||||
|
||||
l1 = d1.keys()
|
||||
l2 = db.special_params(l1)
|
||||
l3 = d3.keys()
|
||||
self.assertEqual(list(l2), list(l3))
|
||||
|
||||
t1 = tuple(l1)
|
||||
t2 = db.special_params(t1)
|
||||
t3 = tuple(l3)
|
||||
self.assertEqual(t2, t3)
|
||||
|
||||
def setup_sample_database(self):
|
||||
self.db.register_project("oldproject", "oldcode")
|
||||
self.db.register_project("unittest", "testcode")
|
||||
self.db.register_job(self.db.project_id, "testjob", "testmode", "testhost", None, datetime.datetime.now())
|
||||
self.ex_model = {'parA': 1.234, 'parB': 5.678, '_model': 91, '_rfac': 0.534}
|
||||
self.db.register_params(self.ex_model)
|
||||
self.db.insert_model(self.ex_model)
|
||||
self.db.create_models_view()
|
||||
|
||||
def test_register_project(self):
|
||||
id1 = self.db.register_project("unittest1", "Atest")
|
||||
self.assertIsInstance(id1, int)
|
||||
self.assertEqual(id1, self.db.project_id)
|
||||
id2 = self.db.register_project("unittest2", "Btest")
|
||||
self.assertIsInstance(id2, int)
|
||||
self.assertEqual(id2, self.db.project_id)
|
||||
id3 = self.db.register_project("unittest1", "Ctest")
|
||||
self.assertIsInstance(id3, int)
|
||||
self.assertEqual(id3, self.db.project_id)
|
||||
self.assertNotEqual(id1, id2)
|
||||
self.assertEqual(id1, id3)
|
||||
c = self.db._conn.cursor()
|
||||
c.execute("select count(*) from Projects")
|
||||
count = c.fetchone()
|
||||
self.assertEqual(count[0], 2)
|
||||
c.execute("select name, code from Projects where id=:id", {'id': id1})
|
||||
row = c.fetchone()
|
||||
self.assertIsNotNone(row)
|
||||
self.assertEqual(len(row), 2)
|
||||
self.assertEqual(row[0], "unittest1")
|
||||
self.assertEqual(row[1], "Atest")
|
||||
self.assertEqual(row['name'], "unittest1")
|
||||
self.assertEqual(row['code'], "Atest")
|
||||
|
||||
def test_register_job(self):
|
||||
pid1 = self.db.register_project("unittest1", "Acode")
|
||||
pid2 = self.db.register_project("unittest2", "Bcode")
|
||||
dt1 = datetime.datetime.now()
|
||||
|
||||
# insert new job
|
||||
id1 = self.db.register_job(pid1, "Ajob", "Amode", "local", "Ahash", dt1, "Adesc")
|
||||
self.assertIsInstance(id1, int)
|
||||
self.assertEqual(id1, self.db.job_id)
|
||||
# insert another job
|
||||
id2 = self.db.register_job(pid1, "Bjob", "Amode", "local", "Ahash", dt1, "Adesc")
|
||||
self.assertIsInstance(id2, int)
|
||||
self.assertEqual(id2, self.db.job_id)
|
||||
# update first job
|
||||
id3 = self.db.register_job(pid1, "Ajob", "Cmode", "local", "Chash", dt1, "Cdesc")
|
||||
self.assertIsInstance(id3, int)
|
||||
self.assertEqual(id3, self.db.job_id)
|
||||
# insert another job with same name but in other project
|
||||
id4 = self.db.register_job(pid2, "Ajob", "Dmode", "local", "Dhash", dt1, "Ddesc")
|
||||
self.assertIsInstance(id4, int)
|
||||
self.assertEqual(id4, self.db.job_id)
|
||||
|
||||
self.assertNotEqual(id1, id2)
|
||||
self.assertEqual(id1, id3)
|
||||
self.assertNotEqual(id1, id4)
|
||||
|
||||
c = self.db._conn.cursor()
|
||||
c.execute("select count(*) from Jobs")
|
||||
count = c.fetchone()
|
||||
self.assertEqual(count[0], 3)
|
||||
c.execute("select name, mode, machine, git_hash, datetime, description from Jobs where id=:id", {'id': id1})
|
||||
row = c.fetchone()
|
||||
self.assertIsNotNone(row)
|
||||
self.assertEqual(len(row), 6)
|
||||
self.assertEqual(row[0], "Ajob")
|
||||
self.assertEqual(row[1], "Amode")
|
||||
self.assertEqual(row['machine'], "local")
|
||||
self.assertEqual(str(row['datetime']), str(dt1))
|
||||
self.assertEqual(row['git_hash'], "Ahash")
|
||||
self.assertEqual(row['description'], "Adesc")
|
||||
|
||||
def test_register_params(self):
|
||||
self.setup_sample_database()
|
||||
model5 = {'parA': 2.341, 'parC': 6.785, '_model': 92, '_rfac': 0.453}
|
||||
self.db.register_params(model5)
|
||||
expected = ['parA', 'parB', 'parC']
|
||||
|
||||
c = self.db._conn.cursor()
|
||||
c.execute("select * from Params order by key")
|
||||
results = c.fetchall()
|
||||
self.assertEqual(len(results), 3)
|
||||
result_params = [row['key'] for row in results]
|
||||
self.assertEqual(result_params, expected)
|
||||
|
||||
def test_query_project_params(self):
|
||||
self.setup_sample_database()
|
||||
project1 = self.db.project_id
|
||||
self.db.register_project("unittest2", "testcode2")
|
||||
self.db.register_job(self.db.project_id, "testjob2", "test", "localhost", None, datetime.datetime.now())
|
||||
model5 = {'parA': 2.341, 'parC': 6.785, '_model': 92, '_rfac': 0.453}
|
||||
self.db.register_params(model5)
|
||||
self.db.insert_model(model5)
|
||||
results = self.db.query_project_params(project_id=project1)
|
||||
expected = ['parA', 'parB']
|
||||
self.assertEqual(expected, sorted(list(results.keys())))
|
||||
|
||||
def test_insert_model(self):
|
||||
self.setup_sample_database()
|
||||
c = self.db._conn.cursor()
|
||||
c.execute("select count(*) from Models")
|
||||
count = c.fetchone()
|
||||
self.assertEqual(count[0], 1)
|
||||
c.execute("select * from Models")
|
||||
row = c.fetchone()
|
||||
model_id = row['id']
|
||||
self.assertIsInstance(model_id, int)
|
||||
self.assertEqual(row['job_id'], self.db.job_id)
|
||||
self.assertEqual(row['model'], self.ex_model['_model'])
|
||||
self.assertIsNone(row['gen'])
|
||||
self.assertIsNone(row['particle'])
|
||||
sql = "select key, value from ParamValues " + \
|
||||
"join Params on ParamValues.param_id = Params.id " + \
|
||||
"where model_id = :model_id"
|
||||
c.execute(sql, {'model_id': model_id})
|
||||
result = c.fetchall() # list of Row objects
|
||||
self.assertEqual(len(result), 2)
|
||||
for row in result:
|
||||
self.assertAlmostEqual(row['value'], self.ex_model[row['key']])
|
||||
|
||||
def test_query_model(self):
|
||||
self.setup_sample_database()
|
||||
c = self.db._conn.cursor()
|
||||
c.execute("select * from Models")
|
||||
row = c.fetchone()
|
||||
model_id = row['id']
|
||||
model = self.db.query_model(model_id)
|
||||
del self.ex_model['_model']
|
||||
del self.ex_model['_rfac']
|
||||
self.assertEqual(model, self.ex_model)
|
||||
|
||||
def test_query_model_array(self):
|
||||
self.setup_sample_database()
|
||||
index = {'_scan': -1, '_domain': -1, '_emit': -1, '_region': -1}
|
||||
model2 = {'parA': 4.123, 'parB': 8.567, '_model': 92, '_rfac': 0.654}
|
||||
model3 = {'parA': 3.412, 'parB': 7.856, '_model': 93, '_rfac': 0.345}
|
||||
model4 = {'parA': 4.123, 'parB': 8.567, '_model': 94, '_rfac': 0.354}
|
||||
model5 = {'parA': 2.341, 'parC': 6.785, '_model': 95, '_rfac': 0.453}
|
||||
model6 = {'parA': 4.123, 'parB': 8.567, '_model': 96, '_rfac': 0.354}
|
||||
self.db.register_params(model5)
|
||||
self.db.create_models_view()
|
||||
model2.update(index)
|
||||
model3.update(index)
|
||||
model4.update(index)
|
||||
model5.update(index)
|
||||
model6.update(index)
|
||||
self.db.insert_result(model2, model2)
|
||||
self.db.insert_result(model3, model3)
|
||||
self.db.insert_result(model4, model4)
|
||||
self.db.insert_result(model5, model5)
|
||||
self.db.insert_result(model6, model6)
|
||||
|
||||
# only model3, model4 and model5 fulfill all conditions and limits
|
||||
fil = ['mode = "testmode"', 'rfac <= 0.6']
|
||||
lim = 3
|
||||
result = self.db.query_model_array(filter=fil, limit=lim)
|
||||
|
||||
template = ['parA', 'parB', 'parC', '_model', '_rfac', '_gen', '_particle']
|
||||
dt = population.Population.get_pop_dtype(template)
|
||||
expected = np.zeros((lim,), dtype=dt)
|
||||
expected['parA'] = np.array([3.412, 4.123, 2.341])
|
||||
expected['parB'] = np.array([7.856, 8.567, None])
|
||||
expected['parC'] = np.array([None, None, 6.785])
|
||||
expected['_model'] = np.array([93, 94, 95])
|
||||
expected['_rfac'] = np.array([0.345, 0.354, 0.453])
|
||||
expected['_gen'] = np.array([0, 0, 0])
|
||||
expected['_particle'] = np.array([0, 0, 0])
|
||||
|
||||
self.assertEqual(result.shape, expected.shape)
|
||||
np.testing.assert_array_almost_equal(result['parA'], expected['parA'])
|
||||
np.testing.assert_array_almost_equal(result['parB'], expected['parB'])
|
||||
np.testing.assert_array_almost_equal(result['parC'], expected['parC'])
|
||||
np.testing.assert_array_almost_equal(result['_model'], expected['_model'])
|
||||
np.testing.assert_array_almost_equal(result['_gen'], expected['_gen'])
|
||||
np.testing.assert_array_almost_equal(result['_particle'], expected['_particle'])
|
||||
|
||||
def test_query_best_results(self):
|
||||
self.setup_sample_database()
|
||||
model2 = {'parA': 4.123, 'parB': 8.567, '_model': 92, '_rfac': 0.654, '_gen': 1, '_particle': 2}
|
||||
model3 = {'parA': 3.412, 'parB': 7.856, '_model': 93, '_rfac': 0.345, '_gen': 1, '_particle': 3}
|
||||
model4 = {'parA': 4.123, 'parB': 8.567, '_model': 94, '_rfac': 0.354, '_gen': 1, '_particle': 4}
|
||||
model5 = {'parA': 2.341, 'parC': 6.785, '_model': 95, '_rfac': 0.453, '_gen': 1, '_particle': 5}
|
||||
model6 = {'parA': 4.123, 'parB': 8.567, '_model': 96, '_rfac': 0.354, '_gen': 1, '_particle': 6}
|
||||
model7 = {'parA': 5.123, 'parB': 6.567, '_model': 97, '_rfac': 0.154, '_gen': 1, '_particle': 7}
|
||||
self.db.register_params(model5)
|
||||
self.db.create_models_view()
|
||||
model2.update({'_scan': -1, '_domain': 11, '_emit': 21, '_region': 31})
|
||||
model3.update({'_scan': 1, '_domain': 12, '_emit': 22, '_region': 32})
|
||||
model4.update({'_scan': 2, '_domain': 11, '_emit': 23, '_region': 33})
|
||||
model5.update({'_scan': 3, '_domain': 11, '_emit': 24, '_region': 34})
|
||||
model6.update({'_scan': 4, '_domain': 11, '_emit': 25, '_region': 35})
|
||||
model7.update({'_scan': 5, '_domain': -1, '_emit': -1, '_region': -1})
|
||||
self.db.insert_result(model2, model2)
|
||||
self.db.insert_result(model3, model3)
|
||||
self.db.insert_result(model4, model4)
|
||||
self.db.insert_result(model5, model5)
|
||||
self.db.insert_result(model6, model6)
|
||||
self.db.insert_result(model7, model7)
|
||||
|
||||
# only model3, model4 and model5 fulfill all conditions and limits
|
||||
fil = ['mode = "testmode"', 'domain = 11']
|
||||
lim = 3
|
||||
result = self.db.query_best_results(filter=fil, limit=lim)
|
||||
|
||||
ifields = ['_db_job', '_db_model', '_db_result',
|
||||
'_model', '_scan', '_domain', '_emit', '_region',
|
||||
'_gen', '_particle']
|
||||
ffields = ['_rfac']
|
||||
dt = [(f, 'i8') for f in ifields]
|
||||
dt.extend([(f, 'f8') for f in ffields])
|
||||
expected = np.zeros((lim,), dtype=dt)
|
||||
expected['_rfac'] = np.array([0.354, 0.354, 0.453])
|
||||
expected['_model'] = np.array([94, 96, 95])
|
||||
expected['_scan'] = np.array([2, 4, 3])
|
||||
expected['_domain'] = np.array([11, 11, 11])
|
||||
expected['_emit'] = np.array([23, 25, 24])
|
||||
expected['_region'] = np.array([33, 35, 34])
|
||||
expected['_gen'] = np.array([1, 1, 1])
|
||||
expected['_particle'] = np.array([4, 6, 5])
|
||||
|
||||
self.assertEqual(result.shape, expected.shape)
|
||||
np.testing.assert_array_almost_equal(result['_rfac'], expected['_rfac'])
|
||||
np.testing.assert_array_equal(result['_model'], expected['_model'])
|
||||
np.testing.assert_array_equal(result['_scan'], expected['_scan'])
|
||||
np.testing.assert_array_equal(result['_domain'], expected['_domain'])
|
||||
np.testing.assert_array_equal(result['_emit'], expected['_emit'])
|
||||
np.testing.assert_array_equal(result['_region'], expected['_region'])
|
||||
np.testing.assert_array_equal(result['_gen'], expected['_gen'])
|
||||
np.testing.assert_array_equal(result['_particle'], expected['_particle'])
|
||||
|
||||
def test_insert_result(self):
|
||||
self.setup_sample_database()
|
||||
index = dispatch.CalcID(15, 16, 17, 18, -1)
|
||||
result = {'parA': 4.123, 'parB': 8.567, '_rfac': 0.654, '_particle': 21}
|
||||
result_id = self.db.insert_result(index, result)
|
||||
|
||||
c = self.db._conn.cursor()
|
||||
c.execute("select count(*) from Results")
|
||||
count = c.fetchone()
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
c.execute("select * from Results")
|
||||
row = c.fetchone()
|
||||
self.assertIsInstance(row['id'], int)
|
||||
self.assertEqual(row['id'], result_id)
|
||||
model_id = row['model_id']
|
||||
self.assertIsInstance(model_id, int)
|
||||
self.assertEqual(row['scan'], index.scan)
|
||||
self.assertEqual(row['domain'], index.domain)
|
||||
self.assertEqual(row['emit'], index.emit)
|
||||
self.assertEqual(row['region'], index.region)
|
||||
self.assertEqual(row['rfac'], result['_rfac'])
|
||||
|
||||
c.execute("select * from Models where id = :model_id", {'model_id': model_id})
|
||||
row = c.fetchone()
|
||||
model_id = row['id']
|
||||
self.assertIsInstance(model_id, int)
|
||||
self.assertEqual(row['job_id'], self.db.job_id)
|
||||
self.assertEqual(row['model'], index.model)
|
||||
self.assertIsNone(row['gen'])
|
||||
self.assertEqual(row['particle'], result['_particle'])
|
||||
|
||||
sql = "select key, value from ParamValues " + \
|
||||
"join Params on ParamValues.param_id = Params.id " + \
|
||||
"where model_id = :model_id"
|
||||
c.execute(sql, {'model_id': model_id})
|
||||
rows = c.fetchall() # list of Row objects
|
||||
self.assertEqual(len(rows), 2)
|
||||
for row in rows:
|
||||
self.assertAlmostEqual(row['value'], result[row['key']])
|
||||
|
||||
def test_update_result(self):
|
||||
self.setup_sample_database()
|
||||
index = dispatch.CalcID(15, 16, 17, 18, -1)
|
||||
result1 = {'parA': 4.123, 'parB': 8.567, '_rfac': 0.654, '_particle': 21}
|
||||
result_id1 = self.db.insert_result(index, result1)
|
||||
result2 = {'parA': 5.456, '_rfac': 0.254, '_particle': 11}
|
||||
result_id2 = self.db.insert_result(index, result2)
|
||||
result3 = result1.copy()
|
||||
result3.update(result2)
|
||||
self.assertEqual(result_id1, result_id2)
|
||||
|
||||
c = self.db._conn.cursor()
|
||||
c.execute("select count(*) from Results")
|
||||
count = c.fetchone()
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
c.execute("select * from Results")
|
||||
row = c.fetchone()
|
||||
self.assertIsInstance(row['id'], int)
|
||||
self.assertEqual(row['id'], result_id1)
|
||||
model_id = row['model_id']
|
||||
self.assertIsInstance(model_id, int)
|
||||
self.assertEqual(row['scan'], index.scan)
|
||||
self.assertEqual(row['domain'], index.domain)
|
||||
self.assertEqual(row['emit'], index.emit)
|
||||
self.assertEqual(row['region'], index.region)
|
||||
self.assertEqual(row['rfac'], result2['_rfac'])
|
||||
|
||||
c.execute("select * from Models where id = :model_id", {'model_id': model_id})
|
||||
row = c.fetchone()
|
||||
model_id = row['id']
|
||||
self.assertIsInstance(model_id, int)
|
||||
self.assertEqual(row['job_id'], self.db.job_id)
|
||||
self.assertEqual(row['model'], index.model)
|
||||
self.assertIsNone(row['gen'])
|
||||
self.assertEqual(row['particle'], result2['_particle'])
|
||||
|
||||
sql = "select key, value from ParamValues " + \
|
||||
"join Params on ParamValues.param_id = Params.id " + \
|
||||
"where model_id = :model_id"
|
||||
c.execute(sql, {'model_id': model_id})
|
||||
rows = c.fetchall() # list of Row objects
|
||||
self.assertEqual(len(rows), 2)
|
||||
for row in rows:
|
||||
self.assertAlmostEqual(row['value'], result3[row['key']])
|
||||
|
||||
def test_update_result_dict(self):
|
||||
"""
|
||||
test update result with index as dictionary
|
||||
|
||||
@return:
|
||||
"""
|
||||
self.setup_sample_database()
|
||||
index = {'_model': 15, '_scan': 16, '_domain': 17, '_emit': 18, '_region': -1}
|
||||
result1 = {'parA': 4.123, 'parB': 8.567, '_rfac': 0.654, '_particle': 21}
|
||||
result_id1 = self.db.insert_result(index, result1)
|
||||
result2 = {'parA': 5.456, '_rfac': 0.254, '_particle': 11}
|
||||
result_id2 = self.db.insert_result(index, result2)
|
||||
result3 = result1.copy()
|
||||
result3.update(result2)
|
||||
self.assertEqual(result_id1, result_id2)
|
||||
|
||||
c = self.db._conn.cursor()
|
||||
c.execute("select count(*) from Results")
|
||||
count = c.fetchone()
|
||||
self.assertEqual(count[0], 1)
|
||||
|
||||
c.execute("select * from Results")
|
||||
row = c.fetchone()
|
||||
self.assertIsInstance(row['id'], int)
|
||||
self.assertEqual(row['id'], result_id1)
|
||||
model_id = row['model_id']
|
||||
self.assertIsInstance(model_id, int)
|
||||
self.assertEqual(row['scan'], index['_scan'])
|
||||
self.assertEqual(row['domain'], index['_domain'])
|
||||
self.assertEqual(row['emit'], index['_emit'])
|
||||
self.assertEqual(row['region'], index['_region'])
|
||||
self.assertEqual(row['rfac'], result2['_rfac'])
|
||||
|
||||
c.execute("select * from Models where id = :model_id", {'model_id': model_id})
|
||||
row = c.fetchone()
|
||||
model_id = row['id']
|
||||
self.assertIsInstance(model_id, int)
|
||||
self.assertEqual(row['job_id'], self.db.job_id)
|
||||
self.assertEqual(row['model'], index['_model'])
|
||||
self.assertIsNone(row['gen'])
|
||||
self.assertEqual(row['particle'], result2['_particle'])
|
||||
|
||||
sql = "select key, value from ParamValues " + \
|
||||
"join Params on ParamValues.param_id = Params.id " + \
|
||||
"where model_id = :model_id"
|
||||
c.execute(sql, {'model_id': model_id})
|
||||
rows = c.fetchall() # list of Row objects
|
||||
self.assertEqual(len(rows), 2)
|
||||
for row in rows:
|
||||
self.assertAlmostEqual(row['value'], result3[row['key']])
|
||||
|
||||
def test_query_best_task_models(self):
|
||||
self.setup_sample_database()
|
||||
model0xxx = {'_model': 0, '_scan': -1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567, '_rfac': 0.01}
|
||||
model00xx = {'_model': 1, '_scan': 0, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567, '_rfac': 0.02}
|
||||
model000x = {'_model': 2, '_scan': 0, '_domain': 0, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567, '_rfac': 0.03}
|
||||
model01xx = {'_model': 3, '_scan': 1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567, '_rfac': 0.04}
|
||||
model010x = {'_model': 4, '_scan': 1, '_domain': 0, '_emit': -1, '_region': -1, 'parA': 4., 'parB': 8.567, '_rfac': 0.05}
|
||||
|
||||
model1xxx = {'_model': 5, '_scan': -1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4.123, 'parB': 8.567, '_rfac': 0.09}
|
||||
model10xx = {'_model': 6, '_scan': 0, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4.123, 'parB': 8.567, '_rfac': 0.08}
|
||||
model100x = {'_model': 7, '_scan': 0, '_domain': 0, '_emit': -1, '_region': -1, 'parA': 4.123, 'parB': 8.567, '_rfac': 0.07}
|
||||
model11xx = {'_model': 8, '_scan': 1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4.123, 'parB': 8.567, '_rfac': 0.06}
|
||||
model110x = {'_model': 9, '_scan': 1, '_domain': 0, '_emit': -1, '_region': -1, 'parA': 4.123, 'parB': 8.567, '_rfac': 0.05}
|
||||
|
||||
model2xxx = {'_model': 10, '_scan': -1, '_domain': -1, '_emit': -1, '_region': -1, 'parA': 4.123, 'parB': 8.567, '_rfac': 0.01}
|
||||
|
||||
self.db.insert_result(model0xxx, model0xxx)
|
||||
self.db.insert_result(model00xx, model00xx)
|
||||
self.db.insert_result(model000x, model000x)
|
||||
self.db.insert_result(model01xx, model01xx)
|
||||
self.db.insert_result(model010x, model010x)
|
||||
|
||||
self.db.insert_result(model1xxx, model1xxx)
|
||||
self.db.insert_result(model10xx, model10xx)
|
||||
self.db.insert_result(model100x, model100x)
|
||||
self.db.insert_result(model11xx, model11xx)
|
||||
self.db.insert_result(model110x, model110x)
|
||||
|
||||
self.db.insert_result(model2xxx, model2xxx)
|
||||
|
||||
result = self.db.query_best_task_models(level=1, count=2)
|
||||
|
||||
expected = {0, 1, 3, 6, 8, 10}
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
def test_sample_project(self):
|
||||
"""
|
||||
test ingestion of two results
|
||||
|
||||
this test uses the same call sequence as the actual pmsco code.
|
||||
it has been used to debug a problem in the main code
|
||||
where prevous results were overwritten.
|
||||
"""
|
||||
db_filename = os.path.join(self.test_dir, "sample_database.db")
|
||||
lock_filename = os.path.join(self.test_dir, "sample_database.lock")
|
||||
|
||||
# project
|
||||
project_name = self.__class__.__name__
|
||||
project_module = self.__class__.__module__
|
||||
|
||||
# job 1
|
||||
job_name1 = "job1"
|
||||
result1 = {'parA': 1.234, 'parB': 5.678, '_model': 91, '_rfac': 0.534}
|
||||
task1 = dispatch.CalcID(91, -1, -1, -1, -1)
|
||||
|
||||
# ingest job 1
|
||||
_db = db.ResultsDatabase()
|
||||
_db.connect(db_filename, lock_filename=lock_filename)
|
||||
project_id1 = _db.register_project(project_name, project_module)
|
||||
job_id1 = _db.register_job(project_id1, job_name1, "test", "localhost", "", datetime.datetime.now(), "")
|
||||
# _db.insert_jobtags(job_id, self.job_tags)
|
||||
_db.register_params(result1.keys())
|
||||
_db.create_models_view()
|
||||
result_id1 = _db.insert_result(task1, result1)
|
||||
_db.disconnect()
|
||||
|
||||
# job 2
|
||||
job_name2 = "job2"
|
||||
result2 = {'parA': 1.345, 'parB': 5.789, '_model': 91, '_rfac': 0.654}
|
||||
task2 = dispatch.CalcID(91, -1, -1, -1, -1)
|
||||
|
||||
# ingest job 2
|
||||
_db = db.ResultsDatabase()
|
||||
_db.connect(db_filename, lock_filename=lock_filename)
|
||||
project_id2 = _db.register_project(project_name, project_module)
|
||||
job_id2 = _db.register_job(project_id2, job_name2, "test", "localhost", "", datetime.datetime.now(), "")
|
||||
# _db.insert_jobtags(job_id, self.job_tags)
|
||||
_db.register_params(result2.keys())
|
||||
_db.create_models_view()
|
||||
result_id2 = _db.insert_result(task2, result2)
|
||||
_db.disconnect()
|
||||
|
||||
# check jobs
|
||||
_db = db.ResultsDatabase()
|
||||
_db.connect(db_filename, lock_filename=lock_filename)
|
||||
sql = "select * from Jobs "
|
||||
c = _db._conn.execute(sql)
|
||||
rows = c.fetchall()
|
||||
self.assertEqual(len(rows), 2)
|
||||
|
||||
# check models
|
||||
sql = "select * from Models "
|
||||
c = _db._conn.execute(sql)
|
||||
rows = c.fetchall()
|
||||
self.assertEqual(len(rows), 2)
|
||||
|
||||
# check results
|
||||
sql = "select * from Results "
|
||||
c = _db._conn.execute(sql)
|
||||
rows = c.fetchall()
|
||||
self.assertEqual(len(rows), 2)
|
||||
|
||||
_db.disconnect()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -17,11 +17,6 @@ Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
import six
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import os.path
|
||||
|
||||
@@ -10,7 +10,7 @@ to run the tests, change to the directory which contains the tests directory, an
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2015-21 by Paul Scherrer Institut @n
|
||||
@copyright (c) 2015-25 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
@@ -19,13 +19,13 @@ Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
|
||||
import mock
|
||||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
import unittest
|
||||
|
||||
import pmsco.data as data
|
||||
import pmsco.dispatch as dispatch
|
||||
import pmsco.project as project
|
||||
from pmsco.scan import Scan
|
||||
|
||||
|
||||
class TestModelSpace(unittest.TestCase):
|
||||
@@ -37,6 +37,20 @@ class TestModelSpace(unittest.TestCase):
|
||||
"C": {"start": 22.0, "min": 15.0, "max": 25.0, "step": 1.0},
|
||||
"D": {"start": 1.5, "min": 0.5, "max": 2.0, "step": 0.25}}
|
||||
|
||||
def test_eval_param_value(self):
|
||||
dummy_value = 15.3
|
||||
ms = project.ModelSpace()
|
||||
ms.project_symbols = {'numpy': np}
|
||||
self.assertAlmostEqual(ms._eval_param_value(0.01), 0.01)
|
||||
self.assertAlmostEqual(ms._eval_param_value('0.01'), 0.01)
|
||||
self.assertAlmostEqual(ms._eval_param_value('numpy.sin(0.1)'), np.sin(0.1))
|
||||
self.assertAlmostEqual(ms._eval_param_value('abs(-0.1)'), 0.1)
|
||||
self.assertTrue(np.isnan(ms._eval_param_value(None)))
|
||||
self.assertTrue(np.isnan(ms._eval_param_value(np.nan)))
|
||||
self.assertRaises(ValueError, ms._eval_param_value, '')
|
||||
# should not have access to local symbols
|
||||
self.assertRaises(NameError, ms._eval_param_value, 'dummy_value')
|
||||
|
||||
def test_add_param(self):
|
||||
ms = project.ModelSpace()
|
||||
ms.start['A'] = 2.1
|
||||
@@ -80,105 +94,6 @@ class TestModelSpace(unittest.TestCase):
|
||||
self.assertDictEqual(ms.step, d_step)
|
||||
|
||||
|
||||
class TestScanCreator(unittest.TestCase):
|
||||
"""
|
||||
test case for @ref pmsco.project.ScanCreator class
|
||||
|
||||
"""
|
||||
def test_load_1(self):
|
||||
"""
|
||||
test the load method, case 1
|
||||
|
||||
test for:
|
||||
- correct array expansion of an ['e', 'a'] scan.
|
||||
- correct file name expansion with place holders and pathlib.Path objects.
|
||||
"""
|
||||
sc = project.ScanCreator()
|
||||
sc.filename = Path("{test_p}", "twoatom_energy_alpha.etpai")
|
||||
sc.positions = {
|
||||
"e": "np.arange(10, 400, 5)",
|
||||
"t": "0",
|
||||
"p": "0",
|
||||
"a": "np.linspace(-30, 30, 31)"
|
||||
}
|
||||
sc.emitter = "Cu"
|
||||
sc.initial_state = "2p3/2"
|
||||
|
||||
p = Path(__file__).parent / ".." / "projects" / "twoatom"
|
||||
dirs = {"test_p": p,
|
||||
"test_s": str(p)}
|
||||
|
||||
result = sc.load(dirs=dirs)
|
||||
|
||||
self.assertEqual(result.mode, ['e', 'a'])
|
||||
self.assertEqual(result.emitter, sc.emitter)
|
||||
self.assertEqual(result.initial_state, sc.initial_state)
|
||||
|
||||
e = np.arange(10, 400, 5)
|
||||
a = np.linspace(-30, 30, 31)
|
||||
t = p = np.asarray([0])
|
||||
np.testing.assert_array_equal(result.energies, e)
|
||||
np.testing.assert_array_equal(result.thetas, t)
|
||||
np.testing.assert_array_equal(result.phis, p)
|
||||
np.testing.assert_array_equal(result.alphas, a)
|
||||
|
||||
self.assertTrue(Path(result.filename).is_file(), msg=f"file {result.filename} not found")
|
||||
|
||||
|
||||
class TestScan(unittest.TestCase):
|
||||
"""
|
||||
test case for @ref pmsco.project.Scan class
|
||||
|
||||
"""
|
||||
def test_import_scan_file(self):
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
test_file = os.path.join(base_dir, "..", "projects", "twoatom", "twoatom_energy_alpha.etpai")
|
||||
|
||||
scan = project.Scan()
|
||||
scan.import_scan_file(test_file, "C", "1s")
|
||||
|
||||
mode = ['e', 'a']
|
||||
self.assertEqual(scan.mode, mode)
|
||||
|
||||
ae = np.arange(10, 1005, 5)
|
||||
at = np.asarray([0])
|
||||
ap = np.asarray([0])
|
||||
aa = np.arange(-90, 91, 1)
|
||||
|
||||
np.testing.assert_array_almost_equal(scan.energies, ae)
|
||||
np.testing.assert_array_almost_equal(scan.thetas, at)
|
||||
np.testing.assert_array_almost_equal(scan.phis, ap)
|
||||
np.testing.assert_array_almost_equal(scan.alphas, aa)
|
||||
|
||||
def test_define_scan(self):
|
||||
scan = project.Scan()
|
||||
p0 = np.asarray([20])
|
||||
p1 = np.linspace(1, 4, 4)
|
||||
p2 = np.linspace(11, 13, 3)
|
||||
d = {'t': p1, 'e': p0, 'p': p2}
|
||||
scan.define_scan(d, "C", "1s")
|
||||
|
||||
ae = np.asarray([20])
|
||||
at = np.asarray([1, 2, 3, 4])
|
||||
ap = np.asarray([11, 12, 13])
|
||||
aa = np.asarray([0])
|
||||
|
||||
np.testing.assert_array_almost_equal(scan.energies, ae)
|
||||
np.testing.assert_array_almost_equal(scan.thetas, at)
|
||||
np.testing.assert_array_almost_equal(scan.phis, ap)
|
||||
np.testing.assert_array_almost_equal(scan.alphas, aa)
|
||||
|
||||
re = np.ones(12) * 20
|
||||
rt = np.asarray([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])
|
||||
rp = np.asarray([11, 12, 13, 11, 12, 13, 11, 12, 13, 11, 12, 13])
|
||||
ra = np.ones(12) * 0
|
||||
|
||||
np.testing.assert_array_almost_equal(scan.raw_data['e'], re)
|
||||
np.testing.assert_array_almost_equal(scan.raw_data['t'], rt)
|
||||
np.testing.assert_array_almost_equal(scan.raw_data['p'], rp)
|
||||
np.testing.assert_array_almost_equal(scan.raw_data['a'], ra)
|
||||
|
||||
|
||||
class TestProject(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# before each test method
|
||||
@@ -198,10 +113,41 @@ class TestProject(unittest.TestCase):
|
||||
# teardown_class() after any methods in this class
|
||||
pass
|
||||
|
||||
def test_resolve_directories(self):
|
||||
self.project.job_name = "jn1"
|
||||
self.project.job_tags['jt1'] = 'tag1'
|
||||
self.project.directories['report'] = Path("${output}/reports")
|
||||
self.project.directories['output'] = Path("${home}/test/output")
|
||||
self.project.directories.resolve_directories(check=True)
|
||||
|
||||
expected = Path(Path.home(), "test", "output")
|
||||
self.assertEqual(expected, self.project.directories['output'])
|
||||
expected = Path(Path.home(), "test", "output", "reports")
|
||||
self.assertEqual(expected, self.project.directories['report'])
|
||||
|
||||
def test_resolve_path(self):
|
||||
self.project.job_name = "jn1"
|
||||
self.project.job_tags['jt1'] = 'tag1'
|
||||
self.project.directories['output'] = Path.home() / "test" / "output"
|
||||
self.project.directories['report'] = self.project.directories['output'] / "reports"
|
||||
extra = {'param_name': 'A', 'value': 25.6}
|
||||
template = "${report}/${job_name}-${jt1}-${param_name}"
|
||||
rps = self.project.directories.resolve_path(template, extra)
|
||||
rpp = self.project.directories.resolve_path(Path(template), extra)
|
||||
expected = Path(Path.home(), "test", "output", "reports", "jn1-tag1-A")
|
||||
self.assertEqual(str(expected), rps)
|
||||
self.assertEqual(expected, rpp)
|
||||
|
||||
fdict = {'base': rpp.stem, 'gen': 14, 'param0': 'A', 'param1': 'B'}
|
||||
template = "my_calc gen ${gen}"
|
||||
title = self.project.directories.resolve_path(template, fdict)
|
||||
expected = "my_calc gen 14"
|
||||
self.assertEqual(expected, title)
|
||||
|
||||
@mock.patch('pmsco.data.load_data')
|
||||
@mock.patch('pmsco.data.save_data')
|
||||
def test_combine_domains(self, save_data_mock, load_data_mock):
|
||||
self.project.scans.append(project.Scan())
|
||||
self.project.scans.append(Scan())
|
||||
|
||||
parent_task = dispatch.CalculationTask()
|
||||
parent_task.change_id(model=0, scan=0)
|
||||
|
||||
141
tests/test_scan.py
Normal file
141
tests/test_scan.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
@package tests.test_scan
|
||||
unit tests for pmsco.scan.
|
||||
|
||||
the purpose of these tests is to help debugging the code.
|
||||
|
||||
to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
|
||||
|
||||
@pre nose and mock must be installed.
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2015-21 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
from pathlib import Path
|
||||
import unittest
|
||||
|
||||
from pmsco.data import holo_grid
|
||||
from pmsco.scan import Scan, ScanKey, ScanLoader, ScanCreator
|
||||
|
||||
|
||||
class TestScanCreator(unittest.TestCase):
|
||||
"""
|
||||
test case for @ref pmsco.project.ScanCreator class
|
||||
|
||||
"""
|
||||
def test_load_1(self):
|
||||
"""
|
||||
test the load method, case 1
|
||||
|
||||
test for:
|
||||
- correct array expansion of an ['e', 'a'] scan.
|
||||
- correct file name expansion with place holders and pathlib.Path objects.
|
||||
"""
|
||||
sc = ScanCreator()
|
||||
sc.filename = Path("${test_p}", "twoatom_energy_alpha.etpai")
|
||||
sc.positions = {
|
||||
"e": "np.arange(10, 400, 5)",
|
||||
"t": "0",
|
||||
"p": "0",
|
||||
"a": "np.linspace(-30, 30, 31)"
|
||||
}
|
||||
sc.emitter = "Cu"
|
||||
sc.initial_state = "2p3/2"
|
||||
|
||||
p = Path(__file__).parent.parent / "pmsco" / "projects" / "twoatom"
|
||||
dirs = {"test_p": p,
|
||||
"test_s": str(p)}
|
||||
|
||||
result = sc.load(dirs=dirs)
|
||||
|
||||
self.assertEqual(result.mode, ['e', 'a'])
|
||||
self.assertEqual(result.emitter, sc.emitter)
|
||||
self.assertEqual(result.initial_state, sc.initial_state)
|
||||
|
||||
e = np.arange(10, 400, 5)
|
||||
a = np.linspace(-30, 30, 31)
|
||||
t = p = np.asarray([0])
|
||||
np.testing.assert_array_equal(result.energies, e)
|
||||
np.testing.assert_array_equal(result.thetas, t)
|
||||
np.testing.assert_array_equal(result.phis, p)
|
||||
np.testing.assert_array_equal(result.alphas, a)
|
||||
|
||||
self.assertTrue(Path(result.filename).is_file(), msg=f"file {result.filename} not found")
|
||||
|
||||
|
||||
class TestScan(unittest.TestCase):
|
||||
"""
|
||||
test case for @ref pmsco.project.Scan class
|
||||
|
||||
"""
|
||||
def test_import_scan_file(self):
|
||||
base_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
test_file = os.path.join(base_dir, "..", "pmsco", "projects", "twoatom", "twoatom_energy_alpha.etpai")
|
||||
|
||||
scan = Scan()
|
||||
scan.import_scan_file(test_file, "C", "1s")
|
||||
|
||||
mode = ['e', 'a']
|
||||
self.assertEqual(scan.mode, mode)
|
||||
|
||||
ae = np.arange(10, 1005, 5)
|
||||
at = np.asarray([0])
|
||||
ap = np.asarray([0])
|
||||
aa = np.arange(-90, 91, 1)
|
||||
|
||||
np.testing.assert_array_almost_equal(scan.energies, ae)
|
||||
np.testing.assert_array_almost_equal(scan.thetas, at)
|
||||
np.testing.assert_array_almost_equal(scan.phis, ap)
|
||||
np.testing.assert_array_almost_equal(scan.alphas, aa)
|
||||
|
||||
def test_define_scan(self):
|
||||
scan = Scan()
|
||||
p0 = np.asarray([20])
|
||||
p1 = np.linspace(1, 4, 4)
|
||||
p2 = np.linspace(11, 13, 3)
|
||||
d = {'t': p1, 'e': p0, 'p': p2}
|
||||
scan.define_scan(d, "C", "1s")
|
||||
|
||||
ae = np.asarray([20])
|
||||
at = np.asarray([1, 2, 3, 4])
|
||||
ap = np.asarray([11, 12, 13])
|
||||
aa = np.asarray([0])
|
||||
|
||||
np.testing.assert_array_almost_equal(scan.energies, ae)
|
||||
np.testing.assert_array_almost_equal(scan.thetas, at)
|
||||
np.testing.assert_array_almost_equal(scan.phis, ap)
|
||||
np.testing.assert_array_almost_equal(scan.alphas, aa)
|
||||
|
||||
re = np.ones(12) * 20
|
||||
rt = np.asarray([1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4])
|
||||
rp = np.asarray([11, 12, 13, 11, 12, 13, 11, 12, 13, 11, 12, 13])
|
||||
ra = np.ones(12) * 0
|
||||
|
||||
np.testing.assert_array_almost_equal(scan.raw_data['e'], re)
|
||||
np.testing.assert_array_almost_equal(scan.raw_data['t'], rt)
|
||||
np.testing.assert_array_almost_equal(scan.raw_data['p'], rp)
|
||||
np.testing.assert_array_almost_equal(scan.raw_data['a'], ra)
|
||||
|
||||
def test_generate_holo_scan(self):
|
||||
scan = Scan()
|
||||
scan.generate_holo_scan(generator=holo_grid,
|
||||
generator_args={},
|
||||
other_positions={"e": 250, "a": 5},
|
||||
emitter="C", initial_state="1s")
|
||||
|
||||
self.assertEqual(scan.thetas.shape, (16376,))
|
||||
self.assertEqual(scan.phis.shape, (16376,))
|
||||
np.testing.assert_array_almost_equal(scan.alphas, np.asarray((5,)))
|
||||
np.testing.assert_array_almost_equal(scan.energies, np.asarray((250,)))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
76
tests/test_schedule.py
Normal file
76
tests/test_schedule.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
import unittest
|
||||
|
||||
from pmsco.schedule import JobSchedule, SlurmSchedule, PsiRaSchedule
|
||||
|
||||
|
||||
class TestSlurmSchedule(unittest.TestCase):
|
||||
def test_parse_timedelta(self):
|
||||
"""
|
||||
@param td:
|
||||
str: [days-]hours[:minutes[:seconds]]
|
||||
dict: days, hours, minutes, seconds - at least one needs to be defined. values must be numeric.
|
||||
datetime.timedelta - native type
|
||||
@return: datetime.timedelta
|
||||
|
||||
"""
|
||||
|
||||
input = "1-15:20:23"
|
||||
expected = datetime.timedelta(days=1, hours=15, minutes=20, seconds=23)
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
input = {"days": 1, "hours": 15, "minutes": 20, "seconds": 23}
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
input = {"days": "1", "hours": "15", "minutes": "20", "seconds": "23"}
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
input = "15"
|
||||
expected = datetime.timedelta(hours=15)
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
input = {"hours": 15}
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
input = "12:00"
|
||||
expected = datetime.timedelta(hours=12, minutes=0)
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
input = "15:20:23"
|
||||
expected = datetime.timedelta(hours=15, minutes=20, seconds=23)
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
input = "1-15"
|
||||
expected = datetime.timedelta(days=1, hours=15)
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
input = "1-15:20"
|
||||
expected = datetime.timedelta(days=1, hours=15, minutes=20)
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
input = {"days": 1}
|
||||
expected = datetime.timedelta(days=1)
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
input = "24:00"
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
expected = datetime.timedelta(days=2)
|
||||
input = "48:00"
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
input = {"minutes": 20}
|
||||
expected = datetime.timedelta(minutes=20)
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
input = {"seconds": 23}
|
||||
expected = datetime.timedelta(seconds=23)
|
||||
self.assertEqual(SlurmSchedule.parse_timedelta(input), expected)
|
||||
|
||||
def test_detect_env(self):
|
||||
result = SlurmSchedule.detect_env()
|
||||
self.assertTrue(result, "undetectable environment")
|
||||
for key, value in result.items():
|
||||
self.assertTrue(key in {"conda", "venv", "system"}, "unknown environment type")
|
||||
self.assertTrue(Path(value).is_dir())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -96,7 +96,7 @@ class TestSwarmPopulation(unittest.TestCase):
|
||||
pos1 = self.pop.pos['A'][0]
|
||||
self.assertNotAlmostEqual(pos0, pos1, delta=0.001)
|
||||
|
||||
for key in ['A','B','C']:
|
||||
for key in ['A', 'B', 'C']:
|
||||
for pos in self.pop.pos[key]:
|
||||
self.assertGreaterEqual(pos, self.model_space.min[key])
|
||||
self.assertLessEqual(pos, self.model_space.max[key])
|
||||
|
||||
75
tests/transforms/test_multipoles.py
Normal file
75
tests/transforms/test_multipoles.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
@package tests.test_data
|
||||
unit tests for pmsco.data
|
||||
|
||||
the purpose of these tests is to mainly to check the syntax, and correct data types,
|
||||
i.e. anything that could cause a run-time error.
|
||||
calculation results are sometimes checked for plausibility but not exact values,
|
||||
depending on the level of debugging required for a specific part of the code.
|
||||
|
||||
to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
|
||||
|
||||
@pre nose must be installed (python-nose package on Debian).
|
||||
|
||||
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
||||
|
||||
@copyright (c) 2015-24 by Paul Scherrer Institut @n
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
import pmsco.data as md
|
||||
from pmsco.transforms.multipoles import MultipoleExpansion
|
||||
|
||||
|
||||
class TestMultipoleExpansion(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# before each test method
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
# after each test method
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
# before any methods in this class
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
# teardown_class() after any methods in this class
|
||||
pass
|
||||
|
||||
def test_generate_expand(self):
|
||||
orig_data = md.holo_array(md.holo_grid, {"theta_step": 2.}, datatype="ITP")
|
||||
orig_data["i"] = np.cos(np.deg2rad(orig_data["t"])) * np.sin(np.deg2rad(orig_data["p"]) * 3) ** 2
|
||||
|
||||
lmax = 12
|
||||
me = MultipoleExpansion()
|
||||
me.holoscan = orig_data
|
||||
me.lmax = lmax
|
||||
alm = me.generate()
|
||||
|
||||
self.assertEqual((lmax / 2 + 1, lmax * 2 + 1), alm.shape)
|
||||
self.assertTrue(np.any(np.real(alm) > 0.), "real part non-zero?")
|
||||
self.assertTrue(np.any(np.imag(alm) > 0.), "imaginary part non-zero?")
|
||||
|
||||
me.expand()
|
||||
expanded_data = me.expansion
|
||||
|
||||
self.assertEqual(expanded_data.shape, orig_data.shape)
|
||||
self.assertTrue(np.any(np.abs(expanded_data['i']) > 0.), "output array non-zero?")
|
||||
|
||||
rf = md.square_diff_rfactor(expanded_data, orig_data)
|
||||
print(rf)
|
||||
self.assertLess(rf, 0.25)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user