Skip to content

Commit a50a1d4

Browse files
committed
WIP WIP
1 parent 23eabda commit a50a1d4

3 files changed

Lines changed: 27 additions & 19 deletions

File tree

devito/passes/iet/definitions.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -177,21 +177,23 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
177177
"""
178178
decl = Definition(obj)
179179

180+
# NOTE: the `arity` is computed as a size ratio, e.g. `sizeof(float3)/sizeof(float)`,
181+
# for portability reasons (since we don't know the size of compound
182+
# types a priori)
180183
arity_param = Symbol(name='arity', dtype=size_t)
181-
arity_arg = SizeOf(obj.indexed._C_typedata)
184+
arity_arg = (SizeOf(obj.indexed._C_typedata) /
185+
SizeOf(obj.c0.indexed._C_typedata))
182186
ndims_param = Symbol(name='ndims', dtype=size_t)
183187
ndims_arg = obj.ndim
184188
shape_param = Array(name=f'{obj.name}_shape', dtype=np.uint64,
185189
dimensions=(Dimension(name='d'),), scope='rvalue')
186190
shape_arg = ListInitializer(obj.c0.symbolic_shape, dtype=shape_param.dtype)
187-
sizeofelem_param = Symbol(name='sizeofelem', dtype=size_t)
188-
sizeofelem_arg = 1 #TODO
189191

190192
ffp0 = FieldFromPointer(obj._C_field_data, obj._C_symbol)
191193
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)
192194
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
193195
ffp3 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
194-
ffp4 = FieldFromPointer(obj._C_field_sizeofelem, obj._C_symbol)
196+
ffp4 = FieldFromPointer(obj._C_field_arity, obj._C_symbol)
195197

196198
# Allocate the Array struct
197199
memptr = VOID(Byref(obj._C_symbol), '**')
@@ -216,7 +218,7 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
216218
ndims_param - 1,
217219
))
218220
init.append(DummyExpr(ffp3, ffp2*arity_param))
219-
init.append(DummyExpr(ffp4, sizeofelem_param))
221+
init.append(DummyExpr(ffp4, arity_param))
220222

221223
# Allocate the underlying host data
222224
memptr = VOID(Byref(ffp0), '**')
@@ -241,7 +243,6 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
241243
args[args.index(arity_param)] = arity_arg
242244
args[args.index(ndims_param)] = ndims_arg
243245
args[args.index(shape_param)] = shape_arg
244-
args[args.index(sizeofelem_param)] = sizeofelem_arg #TODO
245246
alloc = Call(name, args, retobj=obj)
246247

247248
# Same story for the frees
@@ -259,21 +260,22 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
259260
"""
260261
decl = Definition(obj)
261262

263+
# NOTE: the `arity` is computed as a size ratio, e.g. `sizeof(float3)/sizeof(float)`,
264+
# for portability reasons (since we don't know the size of compound
265+
# types a priori)
262266
arity_param = Symbol(name='arity', dtype=size_t)
263-
arity_arg = SizeOf(obj.indexed._C_typedata)
267+
arity_arg = (SizeOf(obj.indexed._C_typedata) /
268+
SizeOf(obj.c0.indexed._C_typedata))
264269
ndims_param = Symbol(name='ndims', dtype=size_t)
265270
ndims_arg = obj.ndim
266271
shape_param = Array(name=f'{obj.name}_shape', dtype=np.uint64,
267272
dimensions=(Dimension(name='d'),), scope='rvalue')
268273
shape_arg = ListInitializer(obj.c0.symbolic_shape, dtype=shape_param.dtype)
269-
sizeofelem_param = Symbol(name='sizeofelem', dtype=size_t)
270-
sizeofelem_arg = 1 #TODO
271-
from IPython import embed; embed()
272274

273275
ffp1 = FieldFromPointer(obj._C_field_shape, obj._C_symbol)
274276
ffp2 = FieldFromPointer(obj._C_field_size, obj._C_symbol)
275277
ffp3 = FieldFromPointer(obj._C_field_nbytes, obj._C_symbol)
276-
ffp4 = FieldFromPointer(obj._C_field_sizeofelem, obj._C_symbol)
278+
ffp4 = FieldFromPointer(obj._C_field_arity, obj._C_symbol)
277279

278280
# Allocate the Bundle struct
279281
memptr = VOID(Byref(obj._C_symbol), '**')
@@ -298,6 +300,7 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
298300
ndims_param - 1,
299301
))
300302
init.append(DummyExpr(ffp3, ffp2*arity_param))
303+
init.append(DummyExpr(ffp4, arity_param))
301304

302305
# Free all of the allocated data
303306
frees = [self.langbb['host-free'](ffp1),
@@ -314,7 +317,6 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
314317
args[args.index(arity_param)] = arity_arg
315318
args[args.index(ndims_param)] = ndims_arg
316319
args[args.index(shape_param)] = shape_arg
317-
args[args.index(sizeofelem_param)] = sizeofelem_arg #TODO
318320
alloc = Call(name, args, retobj=obj)
319321

320322
# Same story for the frees

devito/passes/iet/engine.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -361,24 +361,30 @@ def abstract_component_accesses(root, efuncs, sregistry=None):
361361
ca, = compaccs
362362
f = ca.function
363363

364-
# Access the same entry as `ca` but via pointer arithmetic instead
364+
arity = Symbol(name='arity', dtype=np.uint32, is_const=True)
365365
index = Symbol(name='index', dtype=np.uint32, is_const=True)
366-
indices = [Mul(f.ncomp, i, evaluate=False) for i in ca.indices] #TODO
366+
367+
# Access the same entry as `ca` but via pointer arithmetic instead
368+
indices = [Mul(arity, i, evaluate=False) for i in ca.indices]
367369
indices[-1] += index
368370
kernel1 = Uxreplace({ca: f.c0.indexed[indices]}).visit(kernel)
369371

370372
# Update the parameters list
371-
parameters = [*kernel1.parameters, index]
373+
parameters = [*kernel1.parameters, index, arity]
372374
parameters[parameters.index(f.indexed)] = f.c0.indexed
373375
kernel1 = kernel1._rebuild(parameters=parameters)
374376

375377
# Update the Call site
376-
args = [*call.arguments, ca.index]
377-
args[args.index(f.dmap)] = Cast(f.dmap, dtype=f.c0.indexed._C_ctype,
378+
ffp0 = FieldFromPointer(f._C_field_dmap, f._C_symbol)
379+
ffp1 = FieldFromPointer(f._C_field_arity, f._C_symbol)
380+
args = [*call.arguments, ca.index, ffp1]
381+
args[args.index(f.dmap)] = Cast(ffp0, dtype=f.c0.indexed._C_ctype,
378382
reinterpret=True)
379383
call1 = call._rebuild(arguments=args)
380384
efunc1 = Transformer({call: call1}).visit(efunc)
381385

386+
# TODO: Propagate the index through the call stack until haloupdate...
387+
382388
# Store the new Callables
383389
processed.update({kernel1.name: kernel1, k: efunc1})
384390

devito/types/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ class MappedArrayMixin:
223223
_C_field_shape = 'shape'
224224
_C_field_size = 'size'
225225
_C_field_nbytes = 'nbytes'
226-
_C_field_sizeofelem = 'sizeofelem'
226+
_C_field_arity = 'arity'
227227

228228
_C_size_type = c_uint64
229229

@@ -233,7 +233,7 @@ class MappedArrayMixin:
233233
(_C_field_shape, POINTER(_C_size_type)),
234234
(_C_field_size, _C_size_type),
235235
(_C_field_nbytes, _C_size_type),
236-
(_C_field_sizeofelem, _C_size_type)]}))
236+
(_C_field_arity, _C_size_type)]}))
237237

238238

239239
class ArrayMapped(MappedArrayMixin, Array):

0 commit comments

Comments
 (0)