|
3 | 3 | import numpy as np |
4 | 4 | import os |
5 | 5 | import re |
| 6 | +import sys |
6 | 7 | from collections import OrderedDict |
7 | 8 |
|
8 | 9 | from conftest import skipif |
|
23 | 24 | from devito.petsc.logging import PetscSummary |
24 | 25 | from devito.petsc.solver_parameters import linear_solve_defaults |
25 | 26 |
|
| 27 | +@pytest.fixture(scope='session') |
| 28 | +def command_line(): |
| 29 | + # one prefix per test |
| 30 | + prefix = ('d17weqroegn', 'riabfodkj') |
| 31 | + |
| 32 | + petsc_option = ( |
| 33 | + ('ksp_rtol',), |
| 34 | + ('ksp_rtol','ksp_atol') |
| 35 | + ) |
| 36 | + value = ( |
| 37 | + ('1e-8',), |
| 38 | + ('1e-11','1e-15'), |
| 39 | + ) |
| 40 | + argv = [] |
| 41 | + |
| 42 | + expected = {} |
| 43 | + for p, opt, val in zip(prefix, petsc_option, value, strict=True): |
| 44 | + for o, v in zip(opt, val, strict=True): |
| 45 | + argv.extend([f'-{p}_{o}', v]) |
| 46 | + expected[p] = zip(opt,val) |
| 47 | + return argv, expected |
| 48 | + |
26 | 49 |
|
27 | 50 | @pytest.fixture(scope='session', autouse=True) |
28 | | -def petsc_initialization(): |
| 51 | +def petsc_initialization(command_line): |
| 52 | + argv, _ = command_line |
29 | 53 | # TODO: Temporary workaround until PETSc is automatically |
30 | 54 | # initialized |
31 | 55 | configuration['compiler'] = 'custom' |
32 | 56 | os.environ['CC'] = 'mpicc' |
33 | | - # PetscInitialize(argv) |
34 | | - PetscInitialize() |
| 57 | + PetscInitialize(argv) |
35 | 58 |
|
36 | 59 |
|
37 | 60 | @skipif('petsc') |
@@ -1710,26 +1733,48 @@ def test_multiple_operators(self, log_level): |
1710 | 1733 | # TODO: Add test to check that the command line args override anything set |
1711 | 1734 | # in the solver_parameters dictionary |
1712 | 1735 |
|
1713 | | - # @skipif('petsc') |
1714 | | - # def test_command_line_priority(self): |
1715 | | - # """ |
1716 | | - # Test solver parameters specifed via the command line |
1717 | | - # take precedence over those set in the solver_parameters |
1718 | | - # dictionary. |
1719 | | - # """ |
1720 | | - # prefix = 'd17weqroegn' |
1721 | | - |
1722 | | - # solver1 = PETScSolve( |
1723 | | - # self.eq1, target=self.e, solver_parameters={ |
1724 | | - # 'ksp_rtol': '1e-9', |
1725 | | - # 'snes_view': None} |
1726 | | - # ) |
1727 | | - # with switchconfig(language='petsc'): |
1728 | | - # op = Operator(solver1) |
1729 | | - # op.apply() |
1730 | | - |
1731 | | - # # Check that the command line option specifying the ksp_rtol took |
1732 | | - # # priorty over the solver |
| 1736 | + @skipif('petsc') |
| 1737 | + def test_command_line_priority_1(self, command_line): |
| 1738 | + """ |
| 1739 | + Test solver parameters specifed via the command line |
| 1740 | + take precedence over those set in the solver_parameters |
| 1741 | + dictionary. |
| 1742 | + """ |
| 1743 | + prefix = 'd17weqroegn' |
| 1744 | + _, expected = command_line |
| 1745 | + |
| 1746 | + solver1 = PETScSolve( |
| 1747 | + self.eq1, target=self.e, |
| 1748 | + options_prefix=prefix |
| 1749 | + ) |
| 1750 | + with switchconfig(language='petsc', log_level='DEBUG'): |
| 1751 | + op = Operator(solver1) |
| 1752 | + summary = op.apply() |
| 1753 | + |
| 1754 | + petsc_summary = summary.petsc |
| 1755 | + entry = petsc_summary.get_entry('section0', prefix) |
| 1756 | + for opt, val in expected[prefix]: |
| 1757 | + assert str(entry.KSPGetTolerances[opt.removeprefix('ksp_')]) == val |
| 1758 | + |
| 1759 | + @skipif('petsc') |
| 1760 | + def test_command_line_priority_2(self, command_line): |
| 1761 | + """ |
| 1762 | + """ |
| 1763 | + prefix = 'riabfodkj' |
| 1764 | + _, expected = command_line |
| 1765 | + |
| 1766 | + solver1 = PETScSolve( |
| 1767 | + self.eq1, target=self.e, |
| 1768 | + options_prefix=prefix |
| 1769 | + ) |
| 1770 | + with switchconfig(language='petsc', log_level='DEBUG'): |
| 1771 | + op = Operator(solver1) |
| 1772 | + summary = op.apply() |
| 1773 | + |
| 1774 | + petsc_summary = summary.petsc |
| 1775 | + entry = petsc_summary.get_entry('section0', prefix) |
| 1776 | + for opt, val in expected[prefix]: |
| 1777 | + assert str(entry.KSPGetTolerances[opt.removeprefix('ksp_')]) == val |
1733 | 1778 |
|
1734 | 1779 |
|
1735 | 1780 | class TestHashing: |
|
0 commit comments