Skip to content

Commit 07bbfc2

Browse files
committed
compiler: Working petscsection ed bueler example apart from the byref with malloc
1 parent 81e4036 commit 07bbfc2

6 files changed

Lines changed: 111 additions & 37 deletions

File tree

devito/petsc/iet/builder.py

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,7 @@ def snes_ctx(self):
4040
def _setup(self):
4141
sobjs = self.solver_objs
4242
dmda = sobjs['dmda']
43-
mainctx = sobjs['userctx']
44-
45-
call_struct_callback = petsc_call(
46-
self.callback_builder.user_struct_callback.name, [Byref(mainctx)]
47-
)
48-
49-
calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)])
43+
# mainctx = sobjs['userctx']
5044

5145
snes_create = petsc_call('SNESCreate', [sobjs['comm'], Byref(sobjs['snes'])])
5246

@@ -116,8 +110,8 @@ def _setup(self):
116110
mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda])
117111

118112
base_setup = dmda_calls + (
119-
call_struct_callback,
120-
calls_set_app_ctx,
113+
# call_struct_callback,
114+
# calls_set_app_ctx,
121115
snes_create,
122116
snes_options_prefix,
123117
set_options,
@@ -150,7 +144,15 @@ def _create_dmda_calls(self, dmda):
150144
dm_set_from_opts = petsc_call('DMSetFromOptions', [dmda])
151145
dm_setup = petsc_call('DMSetUp', [dmda])
152146
dm_mat_type = petsc_call('DMSetMatType', [dmda, 'MATSHELL'])
153-
return dmda_create, dm_set_from_opts, dm_setup, dm_mat_type
147+
mainctx = self.solver_objs['userctx']
148+
149+
call_struct_callback = petsc_call(
150+
self.callback_builder.user_struct_callback.name, [Byref(mainctx)]
151+
)
152+
153+
calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)])
154+
155+
return dmda_create, dm_set_from_opts, dm_setup, dm_mat_type, call_struct_callback, calls_set_app_ctx
154156

155157
def _create_dmda(self, dmda):
156158
sobjs = self.solver_objs
@@ -344,6 +346,8 @@ class ConstrainedBCMixin:
344346
"""
345347
def _create_dmda_calls(self, dmda):
346348
sobjs = self.solver_objs
349+
# mainctx = sobjs['mainctx']
350+
mainctx = sobjs['userctx']
347351
# TODO: CLEAN UP
348352
dmda_create = self._create_dmda(dmda)
349353
# TODO: probs need to set the dm options prefix the same as snes?
@@ -373,12 +377,21 @@ def _create_dmda_calls(self, dmda):
373377

374378
dm_create_section_sf = petsc_call('DMCreateSectionSF', [dmda, sobjs['lsection'], sobjs['gsection']])
375379

380+
381+
call_struct_callback = petsc_call(
382+
self.callback_builder.user_struct_callback.name, [Byref(mainctx)]
383+
)
384+
385+
calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)])
386+
376387
return (
377388
dmda_create,
378389
da_create_section,
379390
dm_set_from_opts,
380391
dm_setup,
381392
dm_mat_type,
393+
call_struct_callback,
394+
calls_set_app_ctx,
382395
count_bcs,
383396
set_point_bcs,
384397
get_local_section,

devito/petsc/iet/callbacks.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -755,23 +755,22 @@ def _create_set_point_bc_body(self, body):
755755
'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)]
756756
)
757757

758-
import numpy as np
759-
if self.options['index-mode'] == 'int32':
760-
dtype = np.int32
761-
else:
762-
dtype = np.int64
763-
from devito.passes.iet.linearization import Tracker
764-
765-
tracker = Tracker('basic', dtype, self.sregistry)
766-
767-
key = lambda f: f.name == 'bc'
768-
body = linearize_accesses(body, key0=key, tracker=tracker)
769-
770-
# will only be findexeds 'indexeds'
771-
findexeds = FindSymbols('indexeds').visit(body)
772-
mapper_findexeds = {i: i.linear_index for i in findexeds}
773-
774-
758+
# import numpy as np
759+
# if self.options['index-mode'] == 'int32':
760+
# dtype = np.int32
761+
# else:
762+
# dtype = np.int64
763+
# from devito.passes.iet.linearization import Tracker
764+
765+
# tracker = Tracker('basic', dtype, self.sregistry)
766+
# # from IPython import embed; embed()
767+
# key = lambda f: f.name == 'u'
768+
# body = linearize_accesses(body, key0=key, tracker=tracker)
769+
770+
# # will only be findexeds 'indexeds'
771+
# findexeds = FindSymbols('indexeds').visit(body)
772+
# mapper_findexeds = {i: i.linear_index for i in findexeds}
773+
775774
# from IPython import embed; embed()
776775

777776
# findexeds =
@@ -789,7 +788,11 @@ def _create_set_point_bc_body(self, body):
789788
'ISCreateGeneral', [comm, sobjs['numBC'], sobjs['bcPointsArr'], 'PETSC_OWN_POINTER', Byref(sobjs['bcPointsIS'])]
790789
)
791790

792-
malloc = petsc_call(
791+
malloc_bc_points_arr = petsc_call(
792+
'PetscMalloc1', [sobjs['numBC'], sobjs['bcPointsArr']]
793+
)
794+
795+
malloc_bc_points = petsc_call(
793796
'PetscMalloc1', [1, sobjs['bcPoints']]
794797
)
795798

@@ -798,7 +801,7 @@ def _create_set_point_bc_body(self, body):
798801
set_point_bc = petsc_call(
799802
'DMDASetPointBC', [dmda, 1, sobjs['bcPoints'], Null]
800803
)
801-
body = body._rebuild(body=body.body + (is_create_general, malloc, dummy_expr, set_point_bc))
804+
body = body._rebuild(body=(malloc_bc_points_arr,)+ body.body + (is_create_general, malloc_bc_points, dummy_expr, set_point_bc))
802805

803806
stacks = (
804807
dm_get_local_info,
@@ -821,7 +824,7 @@ def _create_set_point_bc_body(self, body):
821824

822825
subs[Counter._C_symbol] = sobjs['bcPointsArr'].indexed[sobjs['k_iter']]
823826

824-
body = Uxreplace(mapper_findexeds).visit(body)
827+
# body = Uxreplace(mapper_findexeds).visit(body)
825828
body = Uxreplace(subs).visit(body)
826829

827830
return body

devito/petsc/iet/passes.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
)
1010
from devito.symbolics import Byref, Macro, Null, FieldFromPointer
1111
from devito.types.basic import DataSymbol
12+
from devito.types.misc import FIndexed
1213
import devito.logger
14+
from devito.passes.iet.linearization import linearize_accesses
1315

1416
from devito.petsc.types import (
1517
MultipleFieldData, Initialize, Finalize, ArgvSymbol, MainUserStruct,
@@ -126,6 +128,61 @@ def lower_petsc_symbols(iet, **kwargs):
126128
# Rebuild `MainUserStruct` and update iet accordingly
127129
rebuild_parent_user_struct(iet, mapper=callback_struct_mapper)
128130

131+
# from IPython import embed; embed()
132+
133+
134+
iet = linear_indices(iet, **kwargs)
135+
136+
############ tmp
137+
# import numpy as np
138+
# if kwargs['options']['index-mode'] == 'int32':
139+
# dtype = np.int32
140+
# else:
141+
# dtype = np.int64
142+
# from devito.passes.iet.linearization import Tracker
143+
144+
# tracker = Tracker('basic', dtype, kwargs['sregistry'])
145+
146+
# key = lambda f: f.name == 'bc'
147+
# body = linearize_accesses(body, key0=key, tracker=tracker)
148+
149+
# # will only be findexeds 'indexeds'
150+
# findexeds = FindSymbols('indexeds').visit(body)
151+
# mapper_findexeds = {i: i.linear_index for i in findexeds}
152+
153+
# iet =
154+
155+
156+
@iet_pass
157+
def linear_indices(iet, **kwargs):
158+
159+
if not iet.name.startswith("SetPointBCs"):
160+
return iet, {}
161+
162+
import numpy as np
163+
if kwargs['options']['index-mode'] == 'int32':
164+
dtype = np.int32
165+
else:
166+
dtype = np.int64
167+
from devito.passes.iet.linearization import Tracker
168+
169+
tracker = Tracker('basic', dtype, kwargs['sregistry'])
170+
# from IPython import embed; embed()
171+
key = lambda f: f.name == 'u'
172+
iet = linearize_accesses(iet, key0=key, tracker=tracker)
173+
# from IPython import embed; embed()
174+
# will only be findexeds 'indexeds'
175+
findexeds = [i for i in FindSymbols('indexeds').visit(iet) if isinstance(i, FIndexed)]
176+
mapper_findexeds = {i: i.linear_index for i in findexeds}
177+
178+
179+
iet = Uxreplace(mapper_findexeds).visit(iet)
180+
181+
182+
# from IPython import embed; embed()
183+
184+
return iet, {}
185+
129186

130187
@iet_pass
131188
def rebuild_child_user_struct(iet, mapper, **kwargs):

devito/petsc/types/metadata.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -789,8 +789,9 @@ def _make_point_bc_expr(self, expr):
789789
# numBC = PetscInt(name='numBC2')
790790
if isinstance(expr, EssentialBC):
791791
assert expr.lhs == self.target
792+
# from IPython import embed; embed()
792793
return PointEssentialBC(
793-
Counter, expr.rhs,
794+
Counter, self.target,
794795
subdomain=expr.subdomain
795796
)
796797
else:

devito/types/misc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
# Moved in 1.13
99
from sympy.core.basic import ordering_of_classes
1010

11-
from devito.types import Array, CompositeObject, Indexed, Symbol, LocalObject
12-
from devito.types.basic import IndexedData
11+
from devito.types import Array, CompositeObject, Indexed, Symbol, LocalObject, ArrayObject
12+
from devito.types.basic import IndexedData, DataSymbol
1313
from devito.tools import CustomDtype, Pickable, as_tuple, frozendict
1414

1515
__all__ = ['Timer', 'Pointer', 'VolatileInt', 'FIndexed', 'Wildcard', 'Fence',
@@ -172,8 +172,8 @@ def linear_index(self):
172172

173173

174174
# the special postindex type sould live in this file i think
175-
class PostIncrementIndex(Symbol):
176-
pass
175+
class PostIncrementIndex(LocalObject):
176+
dtype = np.int32
177177

178178

179179
class Global(Symbol):

examples/petsc/Poisson/ed_bueler_2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,11 @@ def exact(x, y):
111111

112112
with switchconfig(log_level='DEBUG'):
113113
op = Operator(petsc, language='petsc')
114-
# summary = op.apply()
114+
summary = op.apply()
115115
# print(op.arguments())
116116

117117

118-
print(op.ccode)
118+
# print(op.ccode)
119119
# iters = summary.petsc[('section0', 'poisson_2d')].KSPGetIterationNumber
120120

121121
u_exact = Function(name='u_exact', grid=grid, space_order=2)

0 commit comments

Comments
 (0)