diff --git a/vaspy/atomco.py b/vaspy/atomco.py index 464015c..11f527a 100644 --- a/vaspy/atomco.py +++ b/vaspy/atomco.py @@ -173,7 +173,7 @@ def get_poscar_content(self, **kwargs): tf = self.tf except AttributeError: # Initialize tf with 'T's. - default_tf = np.full(self.data.shape, 'T', dtype=np.str) + default_tf = np.full(self.data.shape, 'T', dtype=str) tf = kwargs.get("tf", default_tf) data_tf = '' if coord_type == 'direct': diff --git a/vaspy/tests/electro_test.py b/vaspy/tests/electro_test.py new file mode 100644 index 0000000..63e01ee --- /dev/null +++ b/vaspy/tests/electro_test.py @@ -0,0 +1,120 @@ +# -*- coding:utf-8 -*- +''' +Unit tests for vaspy.electro module. +''' + +import unittest +import os +import copy + +import numpy as np + +from ..electro import DosX, ElfCar, ChgCar +from . import path + + +class DosXTest(unittest.TestCase): + + def setUp(self): + self.filename = os.path.join(path, "DOS_SUM") + + def test_load(self): + dosx = DosX(self.filename) + self.assertIsNotNone(dosx.data) + self.assertGreater(dosx.data.shape[0], 0) + + def test_reset_data(self): + dosx = DosX(self.filename) + dosx.reset_data() + self.assertTrue(np.all(dosx.data[:, 1:] == 0.0)) + + def test_add(self): + dosx1 = DosX(self.filename) + dosx2 = DosX(self.filename) + dos_sum = dosx1 + dosx2 + self.assertEqual(dos_sum.filename, "DOS_SUM") + + def test_deepcopy(self): + dosx = DosX(self.filename) + dosx_copy = copy.deepcopy(dosx) + self.assertTrue(np.all(dosx.data == dosx_copy.data)) + self.assertIsNot(dosx.data, dosx_copy.data) + + def test_tofile(self): + dosx = DosX(self.filename) + outfile = os.path.join(path, "_test_dos_output.txt") + try: + dosx.tofile(filename=outfile) + self.assertTrue(os.path.exists(outfile)) + finally: + if os.path.exists(outfile): + os.remove(outfile) + + def test_get_dband_center(self): + dosx = DosX(self.filename) + dbc = dosx.get_dband_center(d_cols=(5, 10)) + self.assertIsNotNone(dbc) + self.assertEqual(dosx.dband_center, dbc) + + def test_get_dband_center_int_arg(self): + dosx = DosX(self.filename) + dbc = dosx.get_dband_center(d_cols=5) + self.assertIsNotNone(dbc) + + def test_add_mismatched_energy_raises(self): + dosx1 = DosX(self.filename) + dosx2 = DosX(self.filename) + dosx2.data[0, 0] = 999.0 + with self.assertRaises(ValueError): + dosx1 + dosx2 + + +class ElfCarTest(unittest.TestCase): + + def setUp(self): + self.filename = os.path.join(path, "ELFCAR") + + def test_load(self): + elf = ElfCar(self.filename) + self.assertIsNotNone(elf.elf_data) + self.assertEqual(len(elf.elf_data.shape), 3) + self.assertIsNotNone(elf.grid) + + def test_expand_data(self): + elf = ElfCar(self.filename) + expanded_data, expanded_grid = elf.expand_data(elf.elf_data, elf.grid, (2, 1, 1)) + self.assertEqual(expanded_data.shape[0], elf.elf_data.shape[0] * 2) + self.assertEqual(expanded_grid[0], elf.grid[0] * 2) + + def test_contour_bad_distance(self): + elf = ElfCar(self.filename) + with self.assertRaises(ValueError): + elf.plot_contour(distance=1.5) + + def test_contour_bad_show_mode(self): + elf = ElfCar(self.filename) + with self.assertRaises(ValueError): + elf.plot_contour(show_mode='bad') + + def test_contour_cut_x(self): + elf = ElfCar(self.filename) + elf.plot_contour(axis_cut='x', show_mode='save') + + def test_contour_cut_y(self): + elf = ElfCar(self.filename) + elf.plot_contour(axis_cut='y', show_mode='save') + + def test_contour_cut_z(self): + elf = ElfCar(self.filename) + elf.plot_contour(axis_cut='z', show_mode='save') + + +class ChgCarTest(unittest.TestCase): + + def setUp(self): + self.filename = os.path.join(path, "ELFCAR") + + def test_init(self): + chg = ChgCar(self.filename) + self.assertIsNotNone(chg.elf_data) + self.assertIsNotNone(chg.grid) diff --git a/vaspy/tests/elements_test.py b/vaspy/tests/elements_test.py new file mode 100644 index 0000000..4c7bdce --- /dev/null +++ b/vaspy/tests/elements_test.py @@ -0,0 +1,28 @@ +# -*- coding:utf-8 -*- +''' +Unit tests for vaspy.elements module. +''' + +import unittest + +from .. import elements + + +class ElementsTest(unittest.TestCase): + + def test_C12(self): + self.assertAlmostEqual(elements.C12, 1.99264648e-26) + + def test_amu(self): + self.assertAlmostEqual(elements.amu, 1.66053904e-27) + + def test_chem_elements_has_H(self): + self.assertIn('H', elements.chem_elements) + self.assertEqual(elements.chem_elements['H']['index'], 1) + + def test_chem_elements_has_Ni(self): + self.assertIn('Ni', elements.chem_elements) + self.assertEqual(elements.chem_elements['Ni']['index'], 28) + + def test_chem_elements_count(self): + self.assertEqual(len(elements.chem_elements), 9) diff --git a/vaspy/tests/errors_test.py b/vaspy/tests/errors_test.py new file mode 100644 index 0000000..2ab4e20 --- /dev/null +++ b/vaspy/tests/errors_test.py @@ -0,0 +1,30 @@ +# -*- coding:utf-8 -*- +''' +Unit tests for vaspy.errors module. +''' + +import unittest + +from ..errors import CarfileValueError, UnmatchedDataShape + + +class CarfileValueErrorTest(unittest.TestCase): + + def test_raise(self): + with self.assertRaises(CarfileValueError): + raise CarfileValueError("test error") + + def test_message(self): + err = CarfileValueError("bad value") + self.assertEqual(str(err), "bad value") + + +class UnmatchedDataShapeTest(unittest.TestCase): + + def test_raise(self): + with self.assertRaises(UnmatchedDataShape): + raise UnmatchedDataShape("shape mismatch") + + def test_message(self): + err = UnmatchedDataShape("shape mismatch") + self.assertEqual(str(err), "shape mismatch") diff --git a/vaspy/tests/functions_test.py b/vaspy/tests/functions_test.py new file mode 100644 index 0000000..b4b6c4c --- /dev/null +++ b/vaspy/tests/functions_test.py @@ -0,0 +1,114 @@ +# -*- coding:utf-8 -*- +''' +Unit tests for vaspy.functions module. +''' + +import unittest +import numpy as np + +from ..functions import (str2list, line2list, array2str, + combine_atomco_dict, atomdict2str, + get_combinations, get_angle) + + +class Str2listTest(unittest.TestCase): + + def test_str2list(self): + result = str2list(' 1.0 2.0 3.0 ') + self.assertListEqual(result, ['1.0', '2.0', '3.0']) + + def test_str2list_empty(self): + result = str2list('') + self.assertEqual(result, []) + + +class Line2listTest(unittest.TestCase): + + def test_line2list_float(self): + result = line2list('1.0 2.0 3.0', dtype=float) + self.assertListEqual(result, [1.0, 2.0, 3.0]) + + def test_line2list_int(self): + result = line2list('10 20 30', dtype=int) + self.assertListEqual(result, [10, 20, 30]) + + def test_line2list_str(self): + result = line2list('a b c', dtype=str) + self.assertListEqual(result, ['a', 'b', 'c']) + + def test_line2list_custom_field(self): + result = line2list('1.0,2.0,3.0', field=',', dtype=float) + self.assertListEqual(result, [1.0, 2.0, 3.0]) + + def test_line2list_empty_elements(self): + result = line2list(' 1.0 2.0 ', dtype=float) + self.assertListEqual(result, [1.0, 2.0]) + + def test_line2list_type_error(self): + with self.assertRaises(TypeError): + line2list('1.0 2.0', dtype=3.14) + + +class Array2strTest(unittest.TestCase): + + def test_array2str(self): + arr = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + result = array2str(arr) + self.assertIn('1.0000000000000000', result) + self.assertIn('2.0000000000000000', result) + self.assertEqual(result.count('\n'), 2) + + +class CombineAtomcoDictTest(unittest.TestCase): + + def test_combine_disjoint(self): + a = {'C': [[1.0, 2.0, 3.0]]} + b = {'O': [[4.0, 5.0, 6.0]]} + result = combine_atomco_dict(a, b) + self.assertEqual(set(result.keys()), {'C', 'O'}) + + def test_combine_overlap(self): + a = {'C': [[1.0, 2.0, 3.0]]} + b = {'C': [[4.0, 5.0, 6.0]]} + result = combine_atomco_dict(a, b) + self.assertEqual(len(result['C']), 2) + + def test_combine_empty(self): + result = combine_atomco_dict({}, {}) + self.assertEqual(result, {}) + + +class Atomdict2strTest(unittest.TestCase): + + def test_atomdict2str(self): + d = {'C': [[2.01115823704755, 2.33265069974919, 10.54948252493041]], + 'Co': [[0.28355818414485, 2.31976779057375, 2.34330019781397], + [2.76900337448991, 0.88479534087197, 2.34330019781397]]} + result = atomdict2str(d, ['C', 'Co']) + self.assertIn('C', result) + self.assertIn('Co', result) + self.assertEqual(result.count('\n'), 3) + + +class GetCombinationsTest(unittest.TestCase): + + def test_get_combinations(self): + result = get_combinations(3, 4, 5) + self.assertIsInstance(result, np.ndarray) + + +class GetAngleTest(unittest.TestCase): + + def test_get_angle_90(self): + v1 = np.array([1.0, 0.0, 0.0]) + v2 = np.array([0.0, 1.0, 0.0]) + self.assertAlmostEqual(get_angle(v1, v2), 90.0) + + def test_get_angle_0(self): + v1 = np.array([1.0, 0.0, 0.0]) + self.assertAlmostEqual(get_angle(v1, v1), 0.0) + + def test_get_angle_180(self): + v1 = np.array([1.0, 0.0, 0.0]) + v2 = np.array([-1.0, 0.0, 0.0]) + self.assertAlmostEqual(get_angle(v1, v2), 180.0) diff --git a/vaspy/tests/plotter_test.py b/vaspy/tests/plotter_test.py new file mode 100644 index 0000000..2b480bc --- /dev/null +++ b/vaspy/tests/plotter_test.py @@ -0,0 +1,28 @@ +# -*- coding:utf-8 -*- +''' +Unit tests for vaspy.plotter module. +''' + +import unittest +import os + +from ..plotter import DataPlotter +from . import path + + +class DataPlotterTest(unittest.TestCase): + + def setUp(self): + self.filename = os.path.join(path, "PLOTCON") + + def test_load(self): + plotter = DataPlotter(self.filename) + self.assertIsNotNone(plotter.data) + self.assertGreater(plotter.data.shape[0], 0) + self.assertGreater(plotter.data.shape[1], 0) + + def test_attributes(self): + plotter = DataPlotter(self.filename) + self.assertEqual(plotter.filename, self.filename) + self.assertEqual(plotter.field, ' ') + self.assertEqual(plotter.dtype, float) diff --git a/vaspy/tests/test_all.py b/vaspy/tests/test_all.py index 08be7b6..16ebde4 100644 --- a/vaspy/tests/test_all.py +++ b/vaspy/tests/test_all.py @@ -13,4 +13,14 @@ from .cif_test import CifFileTest from .ani_test import AniFileTest from .xdatcar_test import XdatCarTest +from .functions_test import (Str2listTest, Line2listTest, Array2strTest, + CombineAtomcoDictTest, Atomdict2strTest, + GetCombinationsTest, GetAngleTest) +from .plotter_test import DataPlotterTest +from .electro_test import DosXTest, ElfCarTest, ChgCarTest +from .elements_test import ElementsTest +from .errors_test import CarfileValueErrorTest, UnmatchedDataShapeTest + +if __name__ == '__main__': + unittest.main()