233 lines
10 KiB
Python
233 lines
10 KiB
Python
"""
|
|
@package tests.database.test_common
unit tests for pmsco.database.common

the purpose of these tests is to help debug the code.

to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
the test class is a plain unittest.TestCase, so `python -m unittest` or pytest can run it as well.

@pre nose must be installed (python-nose package on Debian).

@author Matthias Muntwiler, matthias.muntwiler@psi.ch

@copyright (c) 2016 by Paul Scherrer Institut @n
Licensed under the Apache License, Version 2.0 (the "License"); @n
    you may not use this file except in compliance with the License.
    You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
|
|
"""
|
|
|
|
import unittest
|
|
import datetime
|
|
import sqlalchemy.exc
|
|
import pmsco.database.access as db
|
|
import pmsco.database.common as db_common
|
|
import pmsco.database.orm as orm
|
|
import pmsco.dispatch as dispatch
|
|
|
|
|
|
def setup_sample_database(session):
    """
    populate the session with a small, fixed sample data set and commit it.

    the sample contains two projects, one job per project, one model per job,
    three parameter keys (parA, parB, parC), three parameter values
    (parA and parB on the first model, parC on the second model)
    and one result attached to the first model.

    @param session: open database session.
        all sample objects are added to it, and the session is committed.

    @return: dictionary mapping the short handles
        'p1', 'p2', 'j1', 'j2', 'm1', 'm2', 'r1', 'pv1', 'pv2', 'pv3', 'pk1', 'pk2', 'pk3'
        to the created ORM objects.
    """
    old_project = orm.Project(name="oldproject", code="oldcode")
    test_project = orm.Project(name="unittest", code="testcode")

    old_job = orm.Job(project=old_project, name="oldjob", mode="oldmode", machine="oldhost",
                      datetime=datetime.datetime.now())
    test_job = orm.Job(project=test_project, name="testjob", mode="testmode", machine="testhost",
                       datetime=datetime.datetime.now())

    # parameter keys are created in the order parA, parB, parC
    key_a, key_b, key_c = (orm.Param(key=name) for name in ('parA', 'parB', 'parC'))

    old_model = orm.Model(job=old_job, model=91)
    test_model = orm.Model(job=test_job, model=92)

    result = orm.Result(calc_id=dispatch.CalcID(91, -1, -1, -1, -1), rfac=0.534, secs=37.9)
    result.model = old_model

    value_a = orm.ParamValue(model=old_model, param=key_a, value=1.234, delta=0.1234)
    value_b = orm.ParamValue(model=old_model, param=key_b, value=5.678, delta=-0.5678)
    value_c = orm.ParamValue(model=test_model, param=key_c, value=6.785, delta=0.6785)

    sample = dict(p1=old_project, p2=test_project,
                  j1=old_job, j2=test_job,
                  m1=old_model, m2=test_model,
                  r1=result,
                  pv1=value_a, pv2=value_b, pv3=value_c,
                  pk1=key_a, pk2=key_b, pk3=key_c)

    session.add_all(sample.values())
    session.commit()
    return sample
|
|
|
|
|
|
class TestDatabaseCommon(unittest.TestCase):
    """
    unit tests for the helper functions of pmsco.database.common.

    raw SQL checks wrap their statements in sqlalchemy.text() and read result columns
    by attribute (row.name) rather than by string index (row['name']).
    passing plain SQL strings to Session.execute() and indexing rows by string
    are deprecated in SQLAlchemy 1.4 and no longer supported in 2.0,
    whereas the forms used here work in both.
    """

    def setUp(self):
        # connect to ':memory:' for every test method
        # (presumably a fresh SQLite in-memory database - the tests assume it starts empty).
        self.db = db.DatabaseAccess()
        self.db.connect(":memory:")

    def tearDown(self):
        pass

    @classmethod
    def setup_class(cls):
        # nose/pytest hook: called once before any methods in this class
        pass

    @classmethod
    def teardown_class(cls):
        # nose/pytest hook: called once after all methods in this class
        pass

    def test_setup_database(self):
        """setup_sample_database() creates the expected number of rows in each table."""
        with self.db.session() as session:
            setup_sample_database(session)
            self.assertEqual(session.query(orm.Project).count(), 2)
            self.assertEqual(session.query(orm.Job).count(), 2)
            self.assertEqual(session.query(orm.Param).count(), 3)
            self.assertEqual(session.query(orm.Model).count(), 2)
            self.assertEqual(session.query(orm.Result).count(), 1)
            self.assertEqual(session.query(orm.ParamValue).count(), 3)

    def test_get_project(self):
        """get_project() resolves a project given as object, id or name."""
        with self.db.session() as session:
            p1 = orm.Project(name="p1")
            p2 = orm.Project(name="p2")
            p3 = orm.Project(name="p3")
            # p4 is deliberately not added to the session
            p4 = orm.Project(name="p4")
            session.add_all([p1, p2, p3])
            session.commit()

            q1 = db_common.get_project(session, p1)
            q2 = db_common.get_project(session, p2.id)
            q3 = db_common.get_project(session, p3.name)
            q4 = db_common.get_project(session, p4)
            self.assertIs(q1, p1, "by object")
            self.assertIs(q2, p2, "by id")
            self.assertIs(q3, p3, "by name")
            self.assertIs(q4, p4, "detached object by object")
            # a project that is not in the database cannot be found by name
            with self.assertRaises(sqlalchemy.exc.InvalidRequestError, msg="detached object by name"):
                db_common.get_project(session, p4.name)

    def test_get_job(self):
        """get_job() resolves a job given as object, id or name within a project."""
        with self.db.session() as session:
            p1 = orm.Project(name="p1")
            p2 = orm.Project(name="p2")
            p3 = orm.Project(name="p3")
            p4 = orm.Project(name="p4")
            j1 = orm.Job(name="j1")
            j1.project = p1
            j2 = orm.Job(name="j2")
            j2.project = p2
            # same job name as j1, but in a different project
            j3 = orm.Job(name="j1")
            j3.project = p3
            # j4 and p4 are deliberately not added to the session
            j4 = orm.Job(name="j4")
            j4.project = p4
            session.add_all([p1, p2, p3, j1, j2, j3])
            session.commit()

            self.assertIsNot(j3, j1, "jobs with same name")
            q1 = db_common.get_job(session, p1, j1)
            q2 = db_common.get_job(session, p2, j2.id)
            q3 = db_common.get_job(session, p3, j3.name)
            q4 = db_common.get_job(session, p4, j4)
            self.assertIs(q1, j1, "by object")
            self.assertIs(q2, j2, "by id")
            self.assertIs(q3, j3, "by name")
            self.assertIs(q4, j4, "detached object by object")
            with self.assertRaises(sqlalchemy.exc.InvalidRequestError, msg="detached object by name"):
                db_common.get_job(session, p4, j4.name)
            # a job object is returned as-is, even when combined with a different project
            q5 = db_common.get_job(session, p1, j4)
            self.assertIs(q5, j4)

    def test_register_project(self):
        """
        register_project() inserts new projects.

        with allow_existing=True, the existing project is returned and keeps its original code.
        """
        with self.db.session() as session:
            id1 = db_common.register_project(session, "unittest1", "Atest", allow_existing=True)
            self.assertIsInstance(id1, orm.Project)
            id2 = db_common.register_project(session, "unittest2", "Btest", allow_existing=True)
            self.assertIsInstance(id2, orm.Project)
            id3 = db_common.register_project(session, "unittest1", "Ctest", allow_existing=True)
            self.assertIsInstance(id3, orm.Project)
            self.assertNotEqual(id1, id2)
            self.assertEqual(id1, id3)
            session.commit()

            c = session.execute(sqlalchemy.text("select count(*) from Projects"))
            row = c.fetchone()
            self.assertEqual(row[0], 2)
            c = session.execute(sqlalchemy.text("select name, code from Projects where id=:id"),
                                {'id': id1.id})
            row = c.fetchone()
            self.assertIsNotNone(row)
            self.assertEqual(len(row), 2)
            self.assertEqual(row[0], "unittest1")
            self.assertEqual(row[1], "Atest")
            # columns are also accessible by name
            self.assertEqual(row.name, "unittest1")
            self.assertEqual(row.code, "Atest")

            # registering an existing project name without allow_existing is an error
            with self.assertRaises(ValueError):
                db_common.register_project(session, "unittest1", "Ctest")

    def test_register_job(self):
        """
        register_job() accepts the project as object, id or name,
        and keeps jobs with the same name in different projects apart.
        """
        with self.db.session() as session:
            pid1 = db_common.register_project(session, "unittest1", "Acode")
            pid2 = db_common.register_project(session, "unittest2", "Bcode")
            dt1 = datetime.datetime.now()

            # insert new job (project given as object)
            id1 = db_common.register_job(session, pid1, "Ajob", mode="Amode", machine="local", git_hash="Ahash",
                                         datetime=dt1, description="Adesc")
            self.assertIsInstance(id1, orm.Job)
            # insert another job (project given as id)
            id2 = db_common.register_job(session, pid1.id, "Bjob", mode="Bmode", machine="local", git_hash="Ahash",
                                         datetime=dt1, description="Adesc")
            self.assertIsInstance(id2, orm.Job)
            # register the first job again (project given by name) with allow_existing=True.
            # the existing job object is expected back, and its stored attributes stay unchanged
            # (see the assertIs and row checks below).
            id3 = db_common.register_job(session, "unittest1", "Ajob", mode="Cmode", machine="local", git_hash="Chash",
                                         datetime=dt1, description="Cdesc",
                                         allow_existing=True)
            self.assertIsInstance(id3, orm.Job)
            # insert another job with same name but in other project
            id4 = db_common.register_job(session, pid2, "Ajob", mode="Dmode", machine="local", git_hash="Dhash",
                                         datetime=dt1, description="Ddesc")
            self.assertIsInstance(id4, orm.Job)
            # existing job without allow_existing is an error
            with self.assertRaises(ValueError):
                db_common.register_job(session, pid1, "Ajob", mode="Emode", machine="local", git_hash="Dhash",
                                       datetime=dt1, description="Ddesc")

            self.assertIsNot(id1, id2)
            self.assertIs(id1, id3)
            self.assertIsNot(id1, id4)

            # NOTE(review): unlike the other tests, there is no session.commit() before the raw SQL
            # checks below, so they rely on register_job() having flushed the new jobs - confirm.
            c = session.execute(sqlalchemy.text("select count(*) from Jobs"))
            row = c.fetchone()
            self.assertEqual(row[0], 3)
            c = session.execute(sqlalchemy.text("select name, mode, machine, git_hash, datetime, description "
                                                "from Jobs where id=:id"),
                                {'id': id1.id})
            row = c.fetchone()
            self.assertIsNotNone(row)
            self.assertEqual(len(row), 6)
            self.assertEqual(row[0], "Ajob")
            self.assertEqual(row[1], "Amode")
            self.assertEqual(row.machine, "local")
            self.assertEqual(str(row.datetime), str(dt1))
            self.assertEqual(row.git_hash, "Ahash")
            self.assertEqual(row.description, "Adesc")

    def test_register_params(self):
        """
        register_params() adds missing parameter keys only.

        keys starting with an underscore ('_model', '_rfac') are not registered as parameters.
        """
        with self.db.session() as session:
            setup_sample_database(session)
            model5 = {'parA': 2.341, 'parC': 6.785, '_model': 92, '_rfac': 0.453}
            db_common.register_params(session, model5)
            expected = ['parA', 'parB', 'parC']
            session.commit()

            c = session.execute(sqlalchemy.text("select * from Params order by key"))
            results = c.fetchall()
            self.assertEqual(len(results), 3)
            result_params = [row.key for row in results]
            self.assertEqual(result_params, expected)

    def test_query_params(self):
        """query_params() returns the Param objects used by the models of a project given by id or name."""
        with self.db.session() as session:
            objs = setup_sample_database(session)
            results = db_common.query_params(session, project=objs['p1'].id)
            expected = ['parA', 'parB']
            self.assertEqual(expected, sorted(results.keys()))
            self.assertIsInstance(results['parA'], orm.Param)
            self.assertIsInstance(results['parB'], orm.Param)
            results = db_common.query_params(session, project=objs['p2'].name)
            expected = ['parC']
            self.assertEqual(expected, sorted(results.keys()))
            self.assertIsInstance(results['parC'], orm.Param)
|
|
|
|
|
|
if __name__ == "__main__":
    # run the test cases of this module when it is executed as a script
    unittest.main()
|