update public distribution
based on internal repository c9a2ac8 2019-01-03 16:04:57 +0100 tagged rev-master-2.0.0
This commit is contained in:
@ -17,9 +17,14 @@ Licensed under the Apache License, Version 2.0 (the "License"); @n
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
"""
|
||||
|
||||
import unittest
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from io import BytesIO
|
||||
import math
|
||||
import numpy as np
|
||||
import unittest
|
||||
import pmsco.cluster as mc
|
||||
|
||||
|
||||
@ -153,7 +158,7 @@ class TestClusterFunctions(unittest.TestCase):
|
||||
self.assertEqual(2, clu.get_emitter_count())
|
||||
result = clu.get_emitters()
|
||||
expect = [(0., 0., 0., 1), (1., 0., 1., 10)]
|
||||
self.assertItemsEqual(expect, result)
|
||||
self.assertEqual(expect, result)
|
||||
|
||||
def test_get_z_layers(self):
|
||||
clu = mc.Cluster()
|
||||
@ -288,13 +293,21 @@ class TestClusterFunctions(unittest.TestCase):
|
||||
v_lat3 = np.asarray([0, 0, 1])
|
||||
clu.add_bulk(7, v_pos, v_lat1, v_lat2, v_lat3)
|
||||
clu.set_emitter(pos=v_pos)
|
||||
clu.trim_cylinder(2.3, 4.2)
|
||||
r0 = 2.3
|
||||
z0 = 4.2
|
||||
clu.trim_cylinder(r0, z0)
|
||||
self.assertEqual(clu.data.dtype, clu.dtype)
|
||||
self.assertEqual(clu.data.shape[0], 21 * 5)
|
||||
self.assertEqual(clu.data[1]['i'], 2)
|
||||
self.assertEqual(clu.data[1]['s'], 'N')
|
||||
self.assertEqual(clu.data[1]['t'], 7)
|
||||
self.assertEqual(clu.get_emitter_count(), 1)
|
||||
n_low = np.sum(clu.data['z'] < -z0)
|
||||
self.assertEqual(0, n_low)
|
||||
n_high = np.sum(clu.data['z'] > z0)
|
||||
self.assertEqual(0, n_high)
|
||||
n_out = np.sum(clu.data['x']**2 + clu.data['y']**2 > r0**2)
|
||||
self.assertEqual(0, n_out)
|
||||
|
||||
def test_trim_sphere(self):
|
||||
clu = mc.Cluster()
|
||||
@ -305,13 +318,39 @@ class TestClusterFunctions(unittest.TestCase):
|
||||
v_lat3 = np.asarray([0, 0, 1])
|
||||
clu.add_bulk(7, v_pos, v_lat1, v_lat2, v_lat3)
|
||||
clu.set_emitter(pos=v_pos)
|
||||
clu.trim_sphere(2.3)
|
||||
r0 = 2.3
|
||||
clu.trim_sphere(r0)
|
||||
self.assertEqual(clu.data.dtype, clu.dtype)
|
||||
self.assertEqual(clu.data.shape[0], 39)
|
||||
self.assertEqual(clu.data[1]['i'], 2)
|
||||
self.assertEqual(clu.data[1]['s'], 'N')
|
||||
self.assertEqual(clu.data[1]['t'], 7)
|
||||
self.assertEqual(clu.get_emitter_count(), 1)
|
||||
n_out = np.sum(clu.data['x']**2 + clu.data['y']**2 + clu.data['z'] > r0**2)
|
||||
self.assertEqual(0, n_out)
|
||||
|
||||
def test_trim_paraboloid(self):
|
||||
clu = mc.Cluster()
|
||||
clu.set_rmax(10.0)
|
||||
v_pos = np.asarray([0, 0, 0])
|
||||
v_lat1 = np.asarray([1, 0, 0])
|
||||
v_lat2 = np.asarray([0, 1, 0])
|
||||
v_lat3 = np.asarray([0, 0, 1])
|
||||
clu.add_bulk(7, v_pos, v_lat1, v_lat2, v_lat3)
|
||||
clu.set_emitter(pos=v_pos)
|
||||
r0 = 3.5
|
||||
z0 = -2.3
|
||||
clu.trim_paraboloid(r0, z0)
|
||||
self.assertEqual(clu.data.dtype, clu.dtype)
|
||||
self.assertEqual(63, clu.data.shape[0])
|
||||
self.assertEqual(2, clu.data[1]['i'])
|
||||
self.assertEqual('N', clu.data[1]['s'])
|
||||
self.assertEqual(7, clu.data[1]['t'])
|
||||
self.assertEqual(1, clu.get_emitter_count())
|
||||
n_low = np.sum(clu.data['z'] < z0)
|
||||
self.assertEqual(0, n_low)
|
||||
n_out = np.sum(clu.data['x']**2 + clu.data['y']**2 > r0**2)
|
||||
self.assertEqual(0, n_out)
|
||||
|
||||
def test_trim_slab(self):
|
||||
clu = self.create_cube()
|
||||
@ -319,3 +358,21 @@ class TestClusterFunctions(unittest.TestCase):
|
||||
self.assertEqual(clu.data.dtype, clu.dtype)
|
||||
self.assertEqual(clu.data.shape[0], 9 * 2)
|
||||
self.assertEqual(clu.get_emitter_count(), 1)
|
||||
|
||||
def test_save_to_file(self):
|
||||
clu = self.create_cube()
|
||||
f = BytesIO()
|
||||
pos = np.asarray((-1, -1, 0))
|
||||
clu.set_emitter(pos=pos)
|
||||
clu.save_to_file(f, mc.FMT_XYZ, "qwerty", emitters_only=True)
|
||||
f.seek(0)
|
||||
line = f.readline()
|
||||
self.assertEqual(line, b"2\n", b"line 1: " + line)
|
||||
line = f.readline()
|
||||
self.assertEqual(line, b"qwerty\n", b"line 2: " + line)
|
||||
line = f.readline()
|
||||
self.assertRegexpMatches(line, b"H +[0.]+ +[0.]+ +[0.]+", b"line 3: " + line)
|
||||
line = f.readline()
|
||||
self.assertRegexpMatches(line, b"Si +[01.-]+ +[01.-]+ +[0.]+", b"line 4: " + line)
|
||||
line = f.readline()
|
||||
self.assertEqual(line, b"", b"end of file")
|
||||
|
Reference in New Issue
Block a user