Skip to content

Commit e911da9

Browse files
committed
compiler: Revamp allocation of transient Bundles
1 parent 8d03952 commit e911da9

1 file changed

Lines changed: 21 additions & 5 deletions

File tree

devito/passes/iet/definitions.py

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -256,6 +256,12 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
256256

257257
arity_param = Symbol(name='arity', dtype=size_t)
258258
arity_arg = SizeOf(obj.indexed._C_typedata)
259+
ndims_param = Symbol(name='ndims', dtype=size_t)
260+
ndims_arg = obj.ndim
261+
shape_param = Array(name=f'{obj.name}_shape', dtype=np.uint64,
262+
dimensions=(Dimension(name='d'),), scope='rvalue')
263+
shape_arg = ListInitializer(obj.c0.symbolic_shape, dtype=shape_param.dtype)
264+
259265
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)
260266
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
261267
ffp3 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
@@ -271,11 +277,18 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
271277
nbytes = SizeOf(obj._C_size_type)*obj.ndim
272278
alloc1 = self.langbb['host-alloc'](memptr, alignment, nbytes)
273279

274-
# Initialize the Bundle struct
275-
init = [*[DummyExpr(IndexedPointer(ffp1, i), s)
276-
for i, s in enumerate(obj.c0.symbolic_shape)],
277-
DummyExpr(ffp2, obj.size),
278-
DummyExpr(ffp3, ffp2*arity_param)]
280+
# Initialize the Bundle metadata
281+
dim, = shape_param.dimensions
282+
init = [DummyExpr(ffp2, 1)]
283+
init.append(Iteration(
284+
List(body=(
285+
DummyExpr(IndexedPointer(ffp1, dim), shape_param[dim]),
286+
DummyExpr(ffp2, ffp2*shape_param[dim])
287+
)),
288+
dim,
289+
ndims_param - 1,
290+
))
291+
init.append(DummyExpr(ffp3, ffp2*arity_param))
279292

280293
# Free all of the allocated data
281294
frees = [self.langbb['host-free'](ffp1),
@@ -290,6 +303,8 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
290303
efunc0 = make_callable(name, body, retval=obj)
291304
args = list(efunc0.parameters)
292305
args[args.index(arity_param)] = arity_arg
306+
args[args.index(ndims_param)] = ndims_arg
307+
args[args.index(shape_param)] = shape_arg
293308
alloc = Call(name, args, retobj=obj)
294309

295310
# Same story for the frees
@@ -298,6 +313,7 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
298313
free = Call(name, efunc1.parameters)
299314

300315
storage.update(obj, site, allocs=alloc, frees=free, efuncs=(efunc0, efunc1))
316+
storage.include(self.langbb['header-array'])
301317

302318
def _alloc_object_array_on_low_lat_mem(self, site, obj, storage):
303319
"""

0 commit comments

Comments
 (0)