Skip to content

Commit 5cce561

Browse files
committed
compiler: Progress with the SetPointBCs0 callback petsc
1 parent 913b3b4 commit 5cce561

3 files changed

Lines changed: 53 additions & 5 deletions

File tree

devito/passes/iet/linearization.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def linearize(graph, **kwargs):
2525

2626
mode = options.get('linearize')
2727
maybe_callback = kwargs.pop('callback', mode)
28-
28+
# from IPython import embed; embed()
2929
if not maybe_callback:
3030
return
3131
elif callable(maybe_callback):
@@ -212,7 +212,7 @@ def linearize_accesses(iet, key0, tracker=None):
212212
indexeds = FindSymbols('indexeds').visit(iet)
213213
needs = filter_ordered(i.function for i in indexeds if key0(i.function))
214214
needs = sorted(needs, key=lambda f: len(f.dimensions), reverse=True)
215-
215+
# from IPython import embed; embed()
216216
# Update unique sizes and strides
217217
tracker.update(needs)
218218

@@ -230,9 +230,10 @@ def linearize_accesses(iet, key0, tracker=None):
230230
continue
231231

232232
v = generate_linearization(f, i, tracker)
233+
# from IPython import embed; embed()
233234
if v is not None:
234235
subs[i] = v
235-
236+
# from IPython import embed; embed()
236237
iet = Uxreplace(subs).visit(iet)
237238

238239
# 2) What `iet` *offers*

devito/petsc/iet/callbacks.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from devito.types.basic import AbstractFunction
1212
from devito.types import Dimension, Temp, TempArray
1313
from devito.tools import filter_ordered
14+
from devito.passes.iet.linearization import linearize_accesses
1415

1516
from devito.petsc.iet.nodes import PETScCallable, MatShellSetOp, petsc_call
1617
from devito.petsc.types import DMCast, MainUserStruct, CallbackUserStruct
@@ -28,6 +29,7 @@ def __init__(self, **kwargs):
2829

2930
self.rcompile = kwargs.get('rcompile', None)
3031
self.sregistry = kwargs.get('sregistry', None)
32+
self.options = kwargs.get('options', {})
3133
self.concretize_mapper = kwargs.get('concretize_mapper', {})
3234
self.time_dependence = kwargs.get('time_dependence')
3335
self.objs = kwargs.get('objs')
@@ -753,6 +755,25 @@ def _create_set_point_bc_body(self, body):
753755
'DMDAGetLocalInfo', [dmda, Byref(linsolve_expr.localinfo)]
754756
)
755757

758+
import numpy as np
759+
if self.options['index-mode'] == 'int32':
760+
dtype = np.int32
761+
else:
762+
dtype = np.int64
763+
from devito.passes.iet.linearization import Tracker
764+
765+
tracker = Tracker('basic', dtype, self.sregistry)
766+
767+
key = lambda f: f.name == 'bc'
768+
body = linearize_accesses(body, key0=key, tracker=tracker)
769+
770+
# will only be findexeds 'indexeds'
771+
findexeds = FindSymbols('indexeds').visit(body)
772+
mapper_findexeds = {i: i.linear_index for i in findexeds}
773+
774+
# from IPython import embed; embed()
775+
776+
# findexeds =
756777
body = self.time_dependence.uxreplace_time(body)
757778

758779
fields = get_user_struct_fields(body)
@@ -762,7 +783,12 @@ def _create_set_point_bc_body(self, body):
762783
'DMGetApplicationContext', [dmda, Byref(ctx._C_symbol)]
763784
)
764785

765-
# body = body._rebuild(body=body.body)
786+
comm = sobjs['comm']
787+
is_create_general = petsc_call(
788+
'ISCreateGeneral', [comm, sobjs['numBC'], sobjs['bcPointsArr'], 'PETSC_OWN_POINTER']
789+
)
790+
791+
body = body._rebuild(body=body.body + (is_create_general,))
766792

767793
stacks = (
768794
dm_get_local_info,
@@ -786,7 +812,10 @@ def _create_set_point_bc_body(self, body):
786812

787813
subs[Counter._C_symbol] = sobjs['bcPointsArr'].indexed[sobjs['k_iter']]
788814

789-
return Uxreplace(subs).visit(body)
815+
body = Uxreplace(mapper_findexeds).visit(body)
816+
body = Uxreplace(subs).visit(body)
817+
818+
return body
790819

791820
def _make_user_struct_callback(self):
792821
"""

devito/types/misc.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,31 @@ def bind(self, pname):
149149
findexed = self.func(accessor=accessor)
150150

151151
return ((define, expr), findexed)
152+
153+
154+
@property
155+
def linear_index(self):
156+
f = self.function
157+
strides_map = self.strides_map
158+
indices = self.indices
159+
160+
items = [
161+
idx * strides_map[d]
162+
for idx, d in zip(indices, f.dimensions[1:])
163+
]
164+
items.append(indices[-1])
165+
166+
return sympy.Add(*items, evaluate=False)
152167

153168
func = Pickable._rebuild
154169

155170
# Pickling support
156171
__reduce_ex__ = Pickable.__reduce_ex__
157172

158173

174+
# the special postindex type sould live in this file i think
175+
176+
159177
class Global(Symbol):
160178

161179
"""

0 commit comments

Comments
 (0)