Skip to content

Commit 81e4036

Browse files
committed
rffraq
1 parent 3089dc5 commit 81e4036

4 files changed

Lines changed: 33 additions & 20 deletions

File tree

devito/passes/iet/linearization.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def linearize_accesses(iet, key0, tracker=None):
212212
indexeds = FindSymbols('indexeds').visit(iet)
213213
needs = filter_ordered(i.function for i in indexeds if key0(i.function))
214214
needs = sorted(needs, key=lambda f: len(f.dimensions), reverse=True)
215-
# from IPython import embed; embed()
215+
216216
# Update unique sizes and strides
217217
tracker.update(needs)
218218

@@ -230,16 +230,17 @@ def linearize_accesses(iet, key0, tracker=None):
230230
continue
231231

232232
v = generate_linearization(f, i, tracker)
233-
# from IPython import embed; embed()
233+
234234
if v is not None:
235235
subs[i] = v
236-
# from IPython import embed; embed()
236+
237237
iet = Uxreplace(subs).visit(iet)
238238

239239
# 2) What `iet` *offers*
240240
# E.g. `{x_fsz0 -> u_vec->size[1]}`
241241
defines = FindSymbols('defines').visit(iet)
242242
offers = filter_ordered(i for i in defines if key0(i.function))
243+
# from IPython import embed; embed()
243244
instances = {}
244245
for i in offers:
245246
f = i.function
@@ -294,7 +295,7 @@ def linearize_accesses(iet, key0, tracker=None):
294295
if stmts:
295296
body = iet.body._rebuild(strides=stmts)
296297
iet = iet._rebuild(body=body)
297-
298+
# from IPython import embed; embed()
298299
return iet
299300

300301

devito/petsc/iet/builder.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ def snes_ctx(self):
4040
def _setup(self):
4141
sobjs = self.solver_objs
4242
dmda = sobjs['dmda']
43+
mainctx = sobjs['userctx']
44+
45+
call_struct_callback = petsc_call(
46+
self.callback_builder.user_struct_callback.name, [Byref(mainctx)]
47+
)
48+
49+
calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)])
4350

4451
snes_create = petsc_call('SNESCreate', [sobjs['comm'], Byref(sobjs['snes'])])
4552

@@ -105,18 +112,12 @@ def _setup(self):
105112

106113
dmda_calls = self._create_dmda_calls(dmda)
107114

108-
mainctx = sobjs['userctx']
109-
110-
call_struct_callback = petsc_call(
111-
self.callback_builder.user_struct_callback.name, [Byref(mainctx)]
112-
)
113-
114115
# TODO: maybe don't need to explictly set this
115116
mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda])
116117

117-
calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)])
118-
119118
base_setup = dmda_calls + (
119+
call_struct_callback,
120+
calls_set_app_ctx,
120121
snes_create,
121122
snes_options_prefix,
122123
set_options,
@@ -131,9 +132,7 @@ def _setup(self):
131132
matvec_operation,
132133
formfunc_operation,
133134
snes_set_options,
134-
call_struct_callback,
135135
mat_set_dm,
136-
calls_set_app_ctx,
137136
BlankLine
138137
)
139138
extended_setup = self._extend_setup()
@@ -257,8 +256,6 @@ def _setup(self):
257256
# TODO: maybe don't need to explictly set this
258257
mat_set_dm = petsc_call('MatSetDM', [sobjs['Jac'], dmda])
259258

260-
calls_set_app_ctx = petsc_call('DMSetApplicationContext', [dmda, Byref(mainctx)])
261-
262259
create_field_decomp = petsc_call(
263260
'DMCreateFieldDecomposition',
264261
[dmda, Byref(sobjs['nfields']), Null, Byref(sobjs['fields']),

devito/petsc/iet/callbacks.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from devito.types.misc import PostIncrementIndex
1313
from devito.types import Dimension, Temp, TempArray
1414
from devito.tools import filter_ordered
15-
from devito.passes.iet.linearization import linearize_accesses
15+
from devito.passes.iet.linearization import linearize_accesses, Stride
1616

1717
from devito.petsc.iet.nodes import PETScCallable, MatShellSetOp, petsc_call
1818
from devito.petsc.types import DMCast, MainUserStruct, CallbackUserStruct
@@ -771,6 +771,7 @@ def _create_set_point_bc_body(self, body):
771771
findexeds = FindSymbols('indexeds').visit(body)
772772
mapper_findexeds = {i: i.linear_index for i in findexeds}
773773

774+
774775
# from IPython import embed; embed()
775776

776777
# findexeds =
@@ -791,7 +792,13 @@ def _create_set_point_bc_body(self, body):
791792
malloc = petsc_call(
792793
'PetscMalloc1', [1, sobjs['bcPoints']]
793794
)
794-
body = body._rebuild(body=body.body + (is_create_general,malloc))
795+
796+
dummy_expr = DummyExpr(sobjs['bcPoints'].indexed[0], sobjs['bcPointsIS'])
797+
798+
set_point_bc = petsc_call(
799+
'DMDASetPointBC', [dmda, 1, sobjs['bcPoints'], Null]
800+
)
801+
body = body._rebuild(body=body.body + (is_create_general, malloc, dummy_expr, set_point_bc))
795802

796803
stacks = (
797804
dm_get_local_info,
@@ -811,7 +818,6 @@ def _create_set_point_bc_body(self, body):
811818
# Replace non-function data with pointer to data in struct
812819
subs = {i._C_symbol: FieldFromPointer(i._C_symbol, ctx) for
813820
i in fields if not isinstance(i.function, AbstractFunction)}
814-
815821

816822
subs[Counter._C_symbol] = sobjs['bcPointsArr'].indexed[sobjs['k_iter']]
817823

@@ -1303,7 +1309,7 @@ def zero_vector(vec):
13031309
def get_user_struct_fields(iet):
13041310
fields = [f.function for f in FindSymbols('basics').visit(iet)]
13051311
from devito.types.basic import LocalType
1306-
avoid = (Temp, TempArray, LocalType, PostIncrementIndex)
1312+
avoid = (Temp, TempArray, LocalType, PostIncrementIndex, Stride)
13071313
fields = [f for f in fields if not isinstance(f.function, avoid)]
13081314
fields = [
13091315
f for f in fields if not (f.is_Dimension and not (f.is_Time or f.is_Modulo))

devito/petsc/types/object.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from devito.petsc.iet.nodes import petsc_call
1313

1414

15+
# TODO: unnecessary use of "CALLBACK" types - just create a simple way of destroying or not destroying a certain type
16+
17+
1518
class PetscMixin:
1619
@property
1720
def _C_free_priority(self):
@@ -208,6 +211,12 @@ class SingleIS(PetscObject):
208211
dtype = CustomDtype('IS')
209212

210213

214+
# class SingleISDestroy(SingleIS):
215+
# @property
216+
# def _C_free(self):
217+
# return petsc_call('ISDestroy', [Byref(self.function)])
218+
219+
211220
class PetscSectionGlobal(PetscObject):
212221
dtype = CustomDtype('PetscSection')
213222

0 commit comments

Comments
 (0)