280 lines
9.4 KiB
Python
280 lines
9.4 KiB
Python
"""
|
|
@package tests.database.test_orm
|
|
unit tests for pmsco.database.orm
|
|
|
|
the purpose of these tests is to help debugging the code.
|
|
|
|
to run the tests, change to the directory which contains the tests directory, and execute =nosetests=.
|
|
|
|
@pre nose must be installed (python-nose package on Debian).
|
|
|
|
@author Matthias Muntwiler, matthias.muntwiler@psi.ch
|
|
|
|
@copyright (c) 2021 by Paul Scherrer Institut @n
|
|
Licensed under the Apache License, Version 2.0 (the "License"); @n
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
"""
|
|
|
|
import unittest
|
|
import pmsco.database.access as db
|
|
import pmsco.database.orm as orm
|
|
import pmsco.database.util as util
|
|
import pmsco.dispatch as dispatch
|
|
|
|
|
|
class TestDatabase(unittest.TestCase):
|
|
def setUp(self):
|
|
self.db = db.DatabaseAccess()
|
|
self.db.connect(":memory:")
|
|
|
|
def tearDown(self):
|
|
pass
|
|
|
|
@classmethod
|
|
def setup_class(cls):
|
|
# before any methods in this class
|
|
pass
|
|
|
|
@classmethod
|
|
def teardown_class(cls):
|
|
# teardown_class() after any methods in this class
|
|
pass
|
|
|
|
def test_orm_1(self):
|
|
with self.db.session() as session:
|
|
prj = orm.Project(name="test1", code=__file__)
|
|
session.add(prj)
|
|
job = orm.Job(name="test_database")
|
|
job.project = prj
|
|
session.add(job)
|
|
tag1 = orm.Tag(key="phase")
|
|
tag2 = orm.Tag(key="scatter")
|
|
session.add_all([tag1, tag2])
|
|
jt1 = orm.JobTag()
|
|
jt1.tag = tag1
|
|
jt1.job = job
|
|
jt1.value = 'phagen'
|
|
jt2 = orm.JobTag()
|
|
jt2.tag = tag2
|
|
jt2.job = job
|
|
jt2.value = 'edac'
|
|
session.commit()
|
|
|
|
qprj = session.query(orm.Project).filter_by(id=1).one()
|
|
self.assertEqual(prj.name, qprj.name)
|
|
qjob = session.query(orm.Job).filter_by(id=1).one()
|
|
self.assertEqual(job.name, qjob.name)
|
|
self.assertEqual(job.project.name, prj.name)
|
|
self.assertEqual(len(qprj.jobs), 1)
|
|
self.assertEqual(len(qjob.job_tags), 2)
|
|
self.assertEqual(qjob.tags['phase'], 'phagen')
|
|
self.assertEqual(qjob.tags['scatter'], 'edac')
|
|
|
|
def test_orm_2(self):
|
|
with self.db.session() as session:
|
|
prj = orm.Project(name="project 1", code=__file__)
|
|
session.add(prj)
|
|
|
|
job = orm.Job(name="job 1")
|
|
job.project = prj
|
|
session.add(job)
|
|
|
|
jt1 = orm.JobTag('phase', 'phagen')
|
|
session.add(jt1)
|
|
job.job_tags[jt1.tag_key] = jt1
|
|
job.tags['scatter'] = 'edac'
|
|
|
|
mod = orm.Model(model=1111, gen=111, particle=11)
|
|
session.add(mod)
|
|
|
|
pv1 = orm.ParamValue(key='dAB', value=123.456, delta=7.543)
|
|
session.add(pv1)
|
|
mod.param_values[pv1.param_key] = pv1
|
|
mod.values['dBC'] = 234.567
|
|
|
|
cid = dispatch.CalcID(1111, 2, 3, 4, 5)
|
|
res = orm.Result(calc_id=cid, rfac=0.123)
|
|
res.model = mod
|
|
session.add(res)
|
|
|
|
session.commit()
|
|
|
|
qprj = session.query(orm.Project).filter_by(id=1).one()
|
|
self.assertEqual(qprj.name, prj.name)
|
|
self.assertEqual(len(qprj.jobs), 1)
|
|
job_names = [k for k in qprj.jobs.keys()]
|
|
self.assertEqual(job_names[0], job.name)
|
|
self.assertEqual(qprj.jobs[job.name], job)
|
|
|
|
qjob = session.query(orm.Job).filter_by(id=1).one()
|
|
self.assertEqual(qjob.name, job.name)
|
|
self.assertEqual(qjob.project.name, prj.name)
|
|
self.assertEqual(len(qjob.job_tags), 2)
|
|
self.assertEqual(qjob.job_tags['phase'].value, 'phagen')
|
|
self.assertEqual(qjob.job_tags['scatter'].value, 'edac')
|
|
self.assertEqual(len(qjob.tags), 2)
|
|
self.assertEqual(qjob.tags['phase'], 'phagen')
|
|
self.assertEqual(qjob.tags['scatter'], 'edac')
|
|
|
|
qmod = session.query(orm.Model).filter_by(id=1).one()
|
|
self.assertEqual(qmod.model, mod.model)
|
|
self.assertEqual(len(qmod.param_values), 2)
|
|
self.assertEqual(qmod.values['dAB'], 123.456)
|
|
self.assertEqual(qmod.deltas['dAB'], 7.543)
|
|
self.assertEqual(qmod.values['dBC'], 234.567)
|
|
|
|
self.assertEqual(len(qmod.results), 1)
|
|
self.assertEqual(qmod.results[0].rfac, 0.123)
|
|
|
|
def test_job_tags(self):
|
|
with self.db.session() as session:
|
|
prj = orm.Project(name="project 1", code=__file__)
|
|
session.add(prj)
|
|
|
|
job1 = orm.Job(name="job 1")
|
|
job1.project = prj
|
|
session.add(job1)
|
|
job2 = orm.Job(name="job 2")
|
|
job2.project = prj
|
|
session.add(job2)
|
|
|
|
job1.tags['color'] = 'blue'
|
|
job1.tags['shape'] = 'round'
|
|
session.flush()
|
|
job2.tags['color'] = 'red'
|
|
job1.tags['color'] = 'green'
|
|
|
|
session.commit()
|
|
|
|
qjob1 = session.query(orm.Job).filter_by(name='job 1').one()
|
|
self.assertEqual(qjob1.tags['color'], 'green')
|
|
qjob2 = session.query(orm.Job).filter_by(name='job 2').one()
|
|
self.assertEqual(qjob2.tags['color'], 'red')
|
|
|
|
def test_job_jobtags(self):
|
|
with self.db.session() as session:
|
|
prj = orm.Project(name="project 1", code=__file__)
|
|
session.add(prj)
|
|
|
|
job1 = orm.Job(name="job 1")
|
|
job1.project = prj
|
|
session.add(job1)
|
|
job2 = orm.Job(name="job 2")
|
|
job2.project = prj
|
|
session.add(job2)
|
|
|
|
jt1 = orm.JobTag('color', 'blue')
|
|
job1.job_tags[jt1.tag_key] = jt1
|
|
session.flush()
|
|
jt2 = orm.JobTag('color', 'red')
|
|
job2.job_tags[jt2.tag_key] = jt2
|
|
|
|
session.commit()
|
|
|
|
qjob1 = session.query(orm.Job).filter_by(name='job 1').one()
|
|
self.assertIsInstance(qjob1.job_tags['color'], orm.JobTag)
|
|
self.assertEqual(qjob1.job_tags['color'].value, 'blue')
|
|
qjob2 = session.query(orm.Job).filter_by(name='job 2').one()
|
|
self.assertIsInstance(qjob2.job_tags['color'], orm.JobTag)
|
|
self.assertEqual(qjob2.job_tags['color'].value, 'red')
|
|
|
|
def test_param_values(self):
|
|
with self.db.session() as session:
|
|
prj = orm.Project(name="project 1", code=__file__)
|
|
session.add(prj)
|
|
job = orm.Job(name="job 1")
|
|
job.project = prj
|
|
session.add(job)
|
|
|
|
mod1 = orm.Model(model=1, gen=11, particle=111)
|
|
session.add(mod1)
|
|
mod2 = orm.Model(model=2, gen=22, particle=222)
|
|
session.add(mod2)
|
|
|
|
mod1.values['dBC'] = 234.567
|
|
# note: this flush is necessary before accessing the same param in another model
|
|
session.flush()
|
|
mod2.values['dBC'] = 345.678
|
|
|
|
session.commit()
|
|
|
|
qmod1 = session.query(orm.Model).filter_by(model=1).one()
|
|
self.assertEqual(qmod1.values['dBC'], 234.567)
|
|
qmod2 = session.query(orm.Model).filter_by(model=2).one()
|
|
self.assertEqual(qmod2.values['dBC'], 345.678)
|
|
|
|
def test_filter_job(self):
|
|
"""
|
|
test sqlalchemy filter syntax
|
|
|
|
@return: None
|
|
"""
|
|
with self.db.session() as session:
|
|
p1 = orm.Project(name="p1")
|
|
p2 = orm.Project(name="p2")
|
|
j11 = orm.Job(name="j1")
|
|
j11.project = p1
|
|
j12 = orm.Job(name="j2")
|
|
j12.project = p1
|
|
j21 = orm.Job(name="j1")
|
|
j21.project = p2
|
|
j22 = orm.Job(name="j2")
|
|
j22.project = p2
|
|
session.add_all([p1, p2, j11, j12, j21, j22])
|
|
session.commit()
|
|
|
|
q1 = session.query(orm.Job).join(orm.Project)
|
|
q1 = q1.filter(orm.Project.name == 'p1')
|
|
q1 = q1.filter(orm.Job.name == 'j1')
|
|
jobs1 = q1.all()
|
|
|
|
sql = """
|
|
select Projects.name project_name, Jobs.name job_name
|
|
from Projects join Jobs on Projects.id = Jobs.project_id
|
|
where Jobs.name = 'j1' and Projects.name = 'p1'
|
|
"""
|
|
jobs2 = session.execute(sql)
|
|
|
|
n = 0
|
|
for j in jobs2:
|
|
self.assertEqual(j.project_name, 'p1')
|
|
self.assertEqual(j.job_name, 'j1')
|
|
n += 1
|
|
self.assertEqual(n, 1)
|
|
|
|
for j in jobs1:
|
|
self.assertEqual(j.project.name, 'p1')
|
|
self.assertEqual(j.name, 'j1')
|
|
self.assertEqual(len(jobs1), 1)
|
|
|
|
def test_filter_in(self):
|
|
"""
|
|
test sqlalchemy filter syntax: in_ operator
|
|
|
|
@return: None
|
|
"""
|
|
with self.db.session() as session:
|
|
p1 = orm.Project(name="p1")
|
|
p2 = orm.Project(name="p2")
|
|
j11 = orm.Job(name="j1")
|
|
j11.project = p1
|
|
j12 = orm.Job(name="j2")
|
|
j12.project = p1
|
|
j21 = orm.Job(name="j1")
|
|
j21.project = p2
|
|
j22 = orm.Job(name="j2")
|
|
j22.project = p2
|
|
session.add_all([p1, p2, j11, j12, j21, j22])
|
|
session.commit()
|
|
|
|
q1 = session.query(orm.Job)
|
|
q1 = q1.filter(orm.Job.id.in_([2, 3, 7]))
|
|
jobs1 = q1.all()
|
|
self.assertEqual(len(jobs1), 2)
|
|
|
|
|
|
# run all tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|