Skip to content

Commit 294687b

Browse files
committed
compiler: Create dummy op for petscgetargs
1 parent d7da286 commit 294687b

9 files changed

Lines changed: 102 additions & 29 deletions

File tree

devito/petsc/iet/passes.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
PointerIS, Mat, CallbackVec, Vec, CallbackMat, SNES,
1717
DummyArg, PetscInt, PointerDM, PointerMat, MatReuse,
1818
CallbackPointerIS, CallbackPointerDM, JacobianStruct,
19-
SubMatrixStruct, Initialize, Finalize, ArgvSymbol)
19+
SubMatrixStruct, Initialize, Finalize, ArgvSymbol,
20+
GetArgs, ArgvSymbolPtr, ArgcPtr)
2021
from devito.petsc.types.macros import petsc_func_begin_user, Null
2122
from devito.petsc.iet.nodes import PetscMetaData
2223
from devito.petsc.utils import core_metadata, petsc_languages
@@ -51,6 +52,9 @@ def lower_petsc(iet, **kwargs):
5152
if any(filter(lambda i: isinstance(i.expr.rhs, Finalize), data)):
5253
return finalize(iet), core_metadata()
5354

55+
if any(filter(lambda i: isinstance(i.expr.rhs, GetArgs), data)):
56+
return get_args(iet), core_metadata()
57+
5458
unique_grids = {i.expr.rhs.grid for (i,) in inject_solve_mapper.values()}
5559
# Assumption is that all solves are on the same grid
5660
if len(unique_grids) > 1:
@@ -73,8 +77,6 @@ def lower_petsc(iet, **kwargs):
7377
prefixes = [d.expr.rhs.user_prefix for d in data if d.expr.rhs.user_prefix]
7478
duplicates = {p for p in prefixes if prefixes.count(p) > 1}
7579

76-
# TODO: Avoid the other exception raised - think due to exception being
77-
# raised inside the @iet_pass?
7880
if duplicates:
7981
dup_list = ", ".join(repr(p) for p in sorted(duplicates))
8082
raise ValueError(
@@ -136,6 +138,18 @@ def finalize(iet):
136138
return iet._rebuild(body=finalize_body)
137139

138140

141+
def get_args(iet):
142+
argc = ArgcPtr(name='argc', dtype=np.int32)
143+
argv = ArgvSymbolPtr(name='argv')
144+
145+
body = petsc_call('PetscGetArgs', [argc, argv])
146+
body = CallableBody(
147+
body=(petsc_func_begin_user, body),
148+
retstmt=(Call('PetscFunctionReturn', arguments=[0]),)
149+
)
150+
return iet._rebuild(body=body)
151+
152+
139153
def make_core_petsc_calls(objs, comm):
140154
call_mpi = petsc_call_mpi('MPI_Comm_size', [comm, Byref(objs['size'])])
141155
return call_mpi, BlankLine

devito/petsc/iet/routines.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import sys
21
from collections import OrderedDict
32
from functools import cached_property
43
import math
@@ -23,6 +22,7 @@
2322
VecScatter, DMCast, JacobianStruct, SubMatrixStruct,
2423
CallbackDM)
2524
from devito.petsc.types.macros import petsc_func_begin_user, Null
25+
# from devito.petsc.initialize import PetscGetArgs
2626

2727

2828
class CBBuilder:
@@ -149,14 +149,16 @@ def _make_options_callback(self):
149149

150150
for k, v in params.items():
151151
option = f'-{prefix}{k}'
152-
# if option in sys.argv:
153-
# TODO: Pre-build the KSPGetArgs operator and run it here
154-
# to drop the global _petsc_clargs
155-
# tmp = petsc_get_args_op.apply()
152+
153+
# TODO: drop the global variable _petsc_clargs..
154+
# from devito.petsc.initialize import PetscGetArgs
155+
# PetscGetArgs()
156+
156157
import devito.petsc.initialize
157158
if option in devito.petsc.initialize._petsc_clargs:
158159
# Ensures that the command line args take priority
159160
continue
161+
160162
option_name = String(option)
161163
# For options without a value e.g `ksp_view`, pass Null
162164
option_value = Null if v is None else String(str(v))

devito/petsc/initialize.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,34 @@
11
import os
22
import sys
3-
from ctypes import POINTER, cast, c_char
3+
from ctypes import POINTER, cast, c_char, c_int, c_char_p, byref
44
import atexit
55

66
from devito import Operator, switchconfig
77
from devito.types import Symbol
88
from devito.types.equation import PetscEq
9-
from devito.petsc.types import Initialize, Finalize
9+
from devito.petsc.types import Initialize, Finalize, GetArgs
1010

1111
global _petsc_initialized
1212
_petsc_initialized = False
1313

1414

1515
global _petsc_clargs
1616

17+
18+
dummy = Symbol(name='d')
19+
20+
1721
def PetscInitialize(clargs=sys.argv):
1822
global _petsc_initialized
1923
global _petsc_clargs
2024

2125
if not _petsc_initialized:
2226
if clargs is not sys.argv:
2327
clargs = [sys.argv[0], *clargs]
24-
# TODO: drop this
28+
29+
# TODO: Drop this global variable
2530
_petsc_clargs = clargs
26-
dummy = Symbol(name='d')
31+
2732
# TODO: Potentially just use cgen + the compiler machinery in Devito
2833
# to generate these "dummy_ops" instead of using the Operator class.
2934
# This would prevent circular imports when initializing during import
@@ -46,6 +51,18 @@ def PetscInitialize(clargs=sys.argv):
4651
*map(lambda s: cast(s, POINTER(c_char)), argv_bytes)
4752
)
4853
op_init.apply(argc=len(clargs), argv=argv_pointer)
49-
5054
atexit.register(op_finalize.apply)
5155
_petsc_initialized = True
56+
57+
58+
with switchconfig(language='petsc'):
59+
op_get_args = Operator(
60+
[PetscEq(dummy, GetArgs(dummy))],
61+
name='kernel_get_args', opt='noop'
62+
)
63+
64+
65+
def PetscGetArgs():
66+
argc_ptr = c_int()
67+
argv_ptr = (POINTER(c_char_p))()
68+
op_get_args.apply(argc=byref(argc_ptr), argv=byref(argv_ptr))

devito/petsc/logging.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def __repr__(self):
2323

2424
class PetscSummary(dict):
2525
"""
26+
# TODO: Actually print to screen when DEBUG of PERF is enabled
2627
A summary of PETSc statistics collected for all solver runs
2728
associated with a single operator during execution.
2829
"""

devito/petsc/types/object.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from ctypes import POINTER, c_char, c_char_p
1+
from ctypes import POINTER, c_char, c_char_p, c_int
22

33
from devito.tools import CustomDtype, dtype_to_ctype, as_tuple, CustomIntType
44
from devito.types import (LocalObject, LocalCompositeObject, ModuloDimension,
@@ -305,6 +305,18 @@ def _C_ctype(self):
305305
return POINTER(POINTER(c_char))
306306

307307

308+
class ArgvSymbolPtr(DataSymbol):
309+
@property
310+
def _C_ctype(self):
311+
return POINTER(POINTER(c_char_p))
312+
313+
314+
class ArgcPtr(DataSymbol):
315+
@property
316+
def _C_ctype(self):
317+
return POINTER(c_int)
318+
319+
308320
class CharPtr(DataSymbol):
309321
@property
310322
def _C_ctype(self):

devito/petsc/types/types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ class Finalize(MetaData):
2828
pass
2929

3030

31+
class GetArgs(MetaData):
32+
pass
33+
34+
3135
class SolveExpr(MetaData):
3236
"""
3337
A symbolic expression passed through the Operator, containing the metadata

devito/petsc/utils.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import ctypes
33
from pathlib import Path
4-
from petsctools import get_petscvariables
54

65
from devito.tools import memoized_func, filter_ordered
76
from devito.types import Symbol, SteppingDimension
@@ -48,7 +47,32 @@ def core_metadata():
4847
}
4948

5049

51-
petsc_variables = get_petscvariables()
50+
@memoized_func
51+
def get_petsc_variables():
52+
"""
53+
Taken from https://www.firedrakeproject.org/_modules/firedrake/petsc.html
54+
Get a dict of PETSc environment variables from the file:
55+
$PETSC_DIR/$PETSC_ARCH/lib/petsc/conf/petscvariables
56+
"""
57+
try:
58+
petsc_dir = get_petsc_dir()
59+
except PetscOSError:
60+
petsc_variables = {}
61+
else:
62+
path = [petsc_dir[-1], 'lib', 'petsc', 'conf', 'petscvariables']
63+
variables_path = Path(*path)
64+
65+
with open(variables_path) as fh:
66+
# Split lines on first '=' (assignment)
67+
splitlines = (line.split("=", maxsplit=1) for line in fh.readlines())
68+
petsc_variables = {k.strip(): v.strip() for k, v in splitlines}
69+
70+
return petsc_variables
71+
72+
73+
petsc_variables = get_petsc_variables()
74+
# TODO: use petsctools get_petscvariables() instead?
75+
# petsc_variables = get_petscvariables()
5276

5377

5478
def get_petsc_type_mappings():

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ multidict<6.3
99
anytree>=2.4.3,<=2.13.0
1010
cloudpickle<3.1.2
1111
packaging<25.1
12-
petsctools
12+
petsctools==2025.1.dev0

tests/test_petsc.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
import os
55
import re
6-
import sys
76
from collections import OrderedDict
87

98
from conftest import skipif
@@ -24,26 +23,26 @@
2423
from devito.petsc.logging import PetscSummary
2524
from devito.petsc.solver_parameters import linear_solve_defaults
2625

26+
2727
@pytest.fixture(scope='session')
2828
def command_line():
29-
# one prefix per test
29+
# One random prefix to use per test that "tests" the command line args
3030
prefix = ('d17weqroegn', 'riabfodkj')
3131

3232
petsc_option = (
3333
('ksp_rtol',),
34-
('ksp_rtol','ksp_atol')
34+
('ksp_rtol', 'ksp_atol')
3535
)
3636
value = (
37-
('1e-8',),
38-
('1e-11','1e-15'),
37+
(1e-8,),
38+
(1e-11, 1e-15),
3939
)
4040
argv = []
41-
4241
expected = {}
4342
for p, opt, val in zip(prefix, petsc_option, value, strict=True):
4443
for o, v in zip(opt, val, strict=True):
45-
argv.extend([f'-{p}_{o}', v])
46-
expected[p] = zip(opt,val)
44+
argv.extend([f'-{p}_{o}', str(v)])
45+
expected[p] = zip(opt, val)
4746
return argv, expected
4847

4948

@@ -1448,7 +1447,7 @@ def test_logging(self, log_level):
14481447
# the tolerances should match the default linear values.
14491448
tols = entry0.KSPGetTolerances
14501449
assert tols['rtol'] == linear_solve_defaults['ksp_rtol']
1451-
assert tols['abstol'] == linear_solve_defaults['ksp_atol']
1450+
assert tols['atol'] == linear_solve_defaults['ksp_atol']
14521451
assert tols['dtol'] == linear_solve_defaults['ksp_divtol']
14531452
assert tols['maxits'] == linear_solve_defaults['ksp_max_it']
14541453

@@ -1655,7 +1654,7 @@ def test_tolerances(self, log_level):
16551654
# Test that the tolerances have been set correctly and therefore
16561655
# appear as expected in the `PetscSummary`.
16571656
assert tolerances['rtol'] == params['ksp_rtol']
1658-
assert tolerances['abstol'] == params['ksp_atol']
1657+
assert tolerances['atol'] == params['ksp_atol']
16591658
assert tolerances['dtol'] == params['ksp_divtol']
16601659
assert tolerances['maxits'] == params['ksp_max_it']
16611660

@@ -1754,7 +1753,7 @@ def test_command_line_priority_1(self, command_line):
17541753
petsc_summary = summary.petsc
17551754
entry = petsc_summary.get_entry('section0', prefix)
17561755
for opt, val in expected[prefix]:
1757-
assert str(entry.KSPGetTolerances[opt.removeprefix('ksp_')]) == val
1756+
assert entry.KSPGetTolerances[opt.removeprefix('ksp_')] == val
17581757

17591758
@skipif('petsc')
17601759
def test_command_line_priority_2(self, command_line):
@@ -1774,7 +1773,7 @@ def test_command_line_priority_2(self, command_line):
17741773
petsc_summary = summary.petsc
17751774
entry = petsc_summary.get_entry('section0', prefix)
17761775
for opt, val in expected[prefix]:
1777-
assert str(entry.KSPGetTolerances[opt.removeprefix('ksp_')]) == val
1776+
assert entry.KSPGetTolerances[opt.removeprefix('ksp_')] == val
17781777

17791778

17801779
class TestHashing:

0 commit comments

Comments
 (0)