Skip to content

Commit 735e60a

Browse files
committed
dsl: Fix hashing for solveexpr
1 parent eb8043f commit 735e60a

6 files changed

Lines changed: 63 additions & 150 deletions

File tree

devito/petsc/iet/passes.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def lower_petsc(iet, **kwargs):
7575
# individual PetscOptions
7676
# set_solver_option(efuncs)
7777
# List of all callbacks that clear PetscOptions
78-
# from IPython import embed; embed()
78+
7979
# TODO: throw a warning/error if the user passes a solver in with the same options_prefix
8080
# it's going to lead to weird solver option behaviour. Note, if you use the options_prefix across
8181
# different Operator runs, it will not be an issue
@@ -92,8 +92,6 @@ def lower_petsc(iet, **kwargs):
9292

9393
efuncs.update(builder.cbbuilder.efuncs)
9494

95-
# clear_options.append(builder.cbbuilder._clear_options_efunc)
96-
# from IPython import embed; embed() # noqa: E402
9795
clear_options.extend((petsc_call(
9896
builder.cbbuilder._clear_options_efunc.name, []
9997
),))

devito/petsc/solve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def build_expr(self):
8383
user_prefix=self.user_prefix,
8484
formatted_prefix=self.formatted_prefix
8585
)
86-
86+
# from IPython import embed; embed()
8787
return PetscEq(target, linear_solve)
8888

8989
def linear_solve_args(self):

devito/petsc/solver_parameters.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,16 +45,14 @@ def linear_solver_parameters(solver_parameters):
4545
return processed
4646

4747

48-
# _options_prefix_counter = itertools.count()
48+
_options_prefix_counter = itertools.count()
4949

5050
# TODO: add a default options prefix if not provided
5151
def format_options_prefix(options_prefix):
5252
# NOTE: Modified from the `OptionsManager` inside petsctools
5353
if options_prefix is None:
54-
# options_prefix = f"devito_{next(_options_prefix_counter)}_"
55-
options_prefix = ""
54+
options_prefix = f"devito_{next(_options_prefix_counter)}_"
5655
else:
5756
if len(options_prefix) and not options_prefix.endswith("_"):
5857
options_prefix += "_"
59-
# options_prefix = options_prefix
6058
return options_prefix

devito/petsc/types/types.py

Lines changed: 19 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,10 @@
1717
class MetaData(sympy.Function, Reconstructable):
1818
def __new__(cls, expr, **kwargs):
1919
with sympy_mutex:
20-
obj = sympy.Basic.__new__(cls, expr)
21-
obj._expr = expr
20+
obj = sympy.Function.__new__(cls, expr)
21+
obj.expr = expr
2222
return obj
2323

24-
@property
25-
def expr(self):
26-
return self._expr
27-
2824

2925
class Initialize(MetaData):
3026
pass
@@ -48,59 +44,36 @@ def __new__(cls, expr, solver_parameters=None,
4844
user_prefix=None, formatted_prefix=None, **kwargs):
4945

5046
with sympy_mutex:
51-
obj = sympy.Function.__new__(cls, expr)
47+
if isinstance(expr, tuple):
48+
expr = sympy.Tuple(*expr)
49+
obj = sympy.Basic.__new__(cls, expr)
5250

53-
obj._expr = expr
54-
obj._solver_parameters = solver_parameters
55-
obj._field_data = field_data if field_data else FieldData()
56-
obj._time_mapper = time_mapper
57-
obj._localinfo = localinfo
58-
obj._user_prefix = user_prefix
59-
obj._formatted_prefix = formatted_prefix
51+
obj.expr = expr
52+
obj.solver_parameters = solver_parameters
53+
obj.field_data = field_data if field_data else FieldData()
54+
obj.time_mapper = time_mapper
55+
obj.localinfo = localinfo
56+
obj.user_prefix = user_prefix
57+
obj.formatted_prefix = formatted_prefix
6058
return obj
6159

6260
def __repr__(self):
63-
return "%s(%s)" % (self.__class__.__name__, self.expr)
61+
return f"{self.__class__.__name__}{self.expr}"
6462

6563
__str__ = __repr__
6664

6765
def _sympystr(self, printer):
6866
return str(self)
6967

70-
def __hash__(self):
71-
return hash(self.expr)
68+
__hash__ = sympy.Basic.__hash__
69+
70+
def _hashable_content(self):
71+
return (self.expr, self.formatted_prefix)
7272

7373
def __eq__(self, other):
7474
return (isinstance(other, SolveExpr) and
75-
self.expr == other.expr)
76-
77-
@property
78-
def expr(self):
79-
return self._expr
80-
81-
@property
82-
def field_data(self):
83-
return self._field_data
84-
85-
@property
86-
def solver_parameters(self):
87-
return self._solver_parameters
88-
89-
@property
90-
def time_mapper(self):
91-
return self._time_mapper
92-
93-
@property
94-
def localinfo(self):
95-
return self._localinfo
96-
97-
@property
98-
def user_prefix(self):
99-
return self._user_prefix
100-
101-
@property
102-
def formatted_prefix(self):
103-
return self._formatted_prefix
75+
self.expr == other.expr and
76+
self.formatted_prefix == other.formatted_prefix)
10477

10578
@property
10679
def grid(self):
@@ -114,48 +87,6 @@ def eval(cls, *args):
11487

11588

11689

117-
118-
119-
# class SolveExpr(MetaData):
120-
121-
# # __rargs__ = ('expr',)
122-
# # # __rkwargs__ = ('solver_parameters', 'field_data', 'time_mapper',
123-
# # # 'localinfo', 'user_prefix', 'formatted_prefix')
124-
125-
# # __rkwargs__ = ('user_prefix',)
126-
127-
# def __new__(cls, expr, solver_parameters=None,
128-
# field_data=None, time_mapper=None, localinfo=None,
129-
# user_prefix=None, formatted_prefix=None, **kwargs):
130-
131-
# # def __new__(cls, expr,
132-
# # user_prefix=None, **kwargs):
133-
134-
# obj = sympy.Basic.__new__(cls, expr)
135-
# obj.solver_parameters = solver_parameters or {}
136-
# obj.field_data = field_data if field_data else FieldData()
137-
# obj.time_mapper = time_mapper
138-
# obj.localinfo = localinfo
139-
# obj.user_prefix = user_prefix
140-
# obj.formatted_prefix = formatted_prefix
141-
# return obj
142-
143-
# # def __hash__(self):
144-
# # return hash((self.expr, self.user_prefix, self.solver_parameters))
145-
# def __hash__(self):
146-
# return hash(self.expr)
147-
148-
# def __eq__(self, other):
149-
# return (isinstance(other, LinearSolveExpr) and
150-
# self.expr == other.expr)
151-
152-
# @property
153-
# def expr(self):
154-
# return self.args[0]
155-
156-
157-
158-
15990
class LinearSolveExpr(SolveExpr):
16091
"""
16192
Linear problems are handled by setting the SNESType to 'ksponly',

examples/petsc/solver_options.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,17 @@
5555
petsc1 = PETScSolve(eq, target=e, solver_parameters=params1, options_prefix='pde1')
5656
petsc2 = PETScSolve(eq, target=e, solver_parameters=params2, options_prefix='pde1')
5757

58-
with switchconfig(language='petsc'):
58+
# from IPython import embed; embed()
59+
60+
# with switchconfig(language='petsc'):
5961

60-
op1 = Operator([petsc1])
61-
op2 = Operator(petsc2)
62-
summary1 = op1.apply()
63-
summary2 = op2.apply()
62+
# op1 = Operator([petsc1, petsc2])
63+
# # op2 = Operator(petsc2)
64+
# # summary1 = op1.apply()
65+
# # summary2 = op2.apply()
6466

65-
# print(op1.ccode)
66-
# print(op2.ccode)
67+
# print(op1.ccode)
68+
# # print(op2.ccode)
6769

6870

6971

tests/test_petsc.py

Lines changed: 31 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -208,26 +208,6 @@ def test_petsc_cast():
208208
'(PetscScalar (*)[info.gym][info.gxm]) x_f3_vec;' in str(op3.ccode)
209209

210210

211-
# # TODO: add more thorough tests for solver_parameters
212-
# # @skipif('petsc')
213-
# # def test_LinearSolveExpr():
214-
215-
# # grid = Grid((2, 2), dtype=np.float64)
216-
217-
# # f = Function(name='f', grid=grid, space_order=2)
218-
# # g = Function(name='g', grid=grid, space_order=2)
219-
220-
# # eqn = Eq(f, g.laplace)
221-
222-
# # linsolveexpr = LinearSolveExpr(eqn.rhs, field_data=FieldData(target=f))
223-
224-
# # # Check the solver parameters
225-
# # assert linsolveexpr.solver_parameters == \
226-
# # {'snes_type': 'ksponly', 'ksp_type': 'gmres', 'pc_type': 'none',
227-
# # 'ksp_rtol': 1e-05, 'ksp_atol': 1e-50, 'ksp_divtol': 100000.0,
228-
# # 'ksp_max_it': 10000}
229-
230-
231211
@skipif('petsc')
232212
def test_dmda_create():
233213

@@ -794,7 +774,7 @@ def test_coupled_vs_non_coupled(self, eq1, eq2, so):
794774
# less callback functions than solving them separately.
795775
# TODO: As noted in the other test, some efuncs are not reused
796776
# where reuse is possible, investigate.
797-
assert len(callbacks1) == 10
777+
assert len(callbacks1) == 12
798778
assert len(callbacks2) == 8
799779

800780
# Check field_data type
@@ -1494,8 +1474,11 @@ def test_logging_multiple_solves(self):
14941474
class TestSolverParameters:
14951475

14961476
@skipif('petsc')
1497-
def setup_method(self):
1498-
"""Setup grid and functions shared across tests in this class"""
1477+
def setup_class(self):
1478+
"""
1479+
Setup grid, functions and equations shared across
1480+
tests in this class
1481+
"""
14991482
grid = Grid(shape=(11, 11), dtype=np.float64)
15001483
self.e, self.f, self.g, self.h = [
15011484
Function(name=n, grid=grid, space_order=2)
@@ -1506,44 +1489,27 @@ def setup_method(self):
15061489

15071490
@skipif('petsc')
15081491
def test_differing_solver_params(self):
1509-
"""
1510-
"""
1492+
# Explicitly set the solver parameters
15111493
solver1 = PETScSolve(
15121494
self.eq1, target=self.e, solver_parameters={'ksp_rtol': '1e-10'}
15131495
)
1496+
# This solver uses the defaults
15141497
solver2 = PETScSolve(self.eq2, target=self.g)
15151498

15161499
with switchconfig(language='petsc'):
15171500
op = Operator([solver1, solver2])
15181501

15191502
# Check that there are two `SetPetscOptions` callbacks since the solver
1520-
# parameters are different for each solver.
1503+
# parameters are different for each solver
15211504
assert 'SetPetscOptions0' in op._func_table
15221505
assert 'SetPetscOptions1' in op._func_table
15231506

1524-
assert 'PetscCall(PetscOptionsSetValue(NULL,"-ksp_rtol","1e-10"));' \
1507+
assert '_ksp_rtol","1e-10"' \
15251508
in str(op._func_table['SetPetscOptions0'].root)
15261509

1527-
assert 'PetscCall(PetscOptionsSetValue(NULL,"-ksp_rtol","1e-05"));' \
1510+
assert '_ksp_rtol","1e-05"' \
15281511
in str(op._func_table['SetPetscOptions1'].root)
15291512

1530-
@skipif('petsc')
1531-
def test_options_with_no_value(self):
1532-
"""
1533-
Test solver parameters that do not require a value, such as
1534-
`snes_view`.
1535-
"""
1536-
solver1 = PETScSolve(
1537-
self.eq1, target=self.e, solver_parameters={'snes_view': None}
1538-
)
1539-
1540-
with switchconfig(language='petsc'):
1541-
op = Operator(solver1)
1542-
op.apply()
1543-
1544-
assert 'PetscCall(PetscOptionsSetValue(NULL,"-snes_view",NULL));' \
1545-
in str(op._func_table['SetPetscOptions0'].root)
1546-
15471513
@skipif('petsc')
15481514
def test_options_prefix(self):
15491515
solver1 = PETScSolve(self.eq1, self.e,
@@ -1556,7 +1522,7 @@ def test_options_prefix(self):
15561522
with switchconfig(language='petsc'):
15571523
op = Operator([solver1, solver2])
15581524

1559-
# Check the options prefix has been correctly set on each snes solver
1525+
# Check the options prefix has been correctly set for each snes solver
15601526
assert 'PetscCall(SNESSetOptionsPrefix(snes0,"poisson1_"));' in str(op)
15611527
assert 'PetscCall(SNESSetOptionsPrefix(snes1,"poisson2_"));' in str(op)
15621528

@@ -1567,6 +1533,24 @@ def test_options_prefix(self):
15671533
assert 'PetscCall(PetscOptionsSetValue(NULL,"-poisson2_ksp_rtol","1e-12"));' \
15681534
in str(op._func_table['SetPetscOptions1'].root)
15691535

1536+
1537+
@skipif('petsc')
1538+
def test_options_with_no_value(self):
1539+
"""
1540+
Test solver parameters that do not require a value, such as
1541+
`snes_view` and `ksp_view`
1542+
"""
1543+
solver = PETScSolve(
1544+
self.eq1, target=self.e, solver_parameters={'snes_view': None},
1545+
options_prefix='solver1'
1546+
)
1547+
with switchconfig(language='petsc'):
1548+
op = Operator(solver)
1549+
op.apply()
1550+
1551+
assert 'PetscCall(PetscOptionsSetValue(NULL,"-solver1_snes_view",NULL));' \
1552+
in str(op._func_table['SetPetscOptions0'].root)
1553+
15701554
@skipif('petsc')
15711555
@pytest.mark.parametrize('log_level', ['PERF', 'DEBUG'])
15721556
def test_tolerances(self, log_level):
@@ -1594,11 +1578,11 @@ def test_tolerances(self, log_level):
15941578
assert tolerances['maxits'] == params['ksp_max_it']
15951579

15961580

1597-
15981581
# TODO: ADD TEST TO CHECK FOR DEFAULT LINEAR SOLVER PARAMETERS
15991582
# TODO: add test to check that the correct options are unset properly
16001583
# TODO: Add a test to check that the command line args override anything set
16011584
# in the solver_parameters dictionary
1585+
# TODO: add hashing tests to petscsolve
16021586

16031587
# @skipif('petsc')
16041588
# def test_command_line_priority(self):

0 commit comments

Comments
 (0)