@@ -256,6 +256,12 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
256256
257257 arity_param = Symbol (name = 'arity' , dtype = size_t )
258258 arity_arg = SizeOf (obj .indexed ._C_typedata )
259+ ndims_param = Symbol (name = 'ndims' , dtype = size_t )
260+ ndims_arg = obj .ndim
261+ shape_param = Array (name = f'{ obj .name } _shape' , dtype = np .uint64 ,
262+ dimensions = (Dimension (name = 'd' ),), scope = 'rvalue' )
263+ shape_arg = ListInitializer (obj .c0 .symbolic_shape , dtype = shape_param .dtype )
264+
259265 ffp1 = FieldFromPointer (obj ._C_field_shape , obj ._C_symbol )
260266 ffp2 = FieldFromPointer (obj ._C_field_size , obj ._C_symbol )
261267 ffp3 = FieldFromPointer (obj ._C_field_nbytes , obj ._C_symbol )
@@ -271,11 +277,18 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
271277 nbytes = SizeOf (obj ._C_size_type )* obj .ndim
272278 alloc1 = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
273279
274- # Initialize the Bundle struct
275- init = [* [DummyExpr (IndexedPointer (ffp1 , i ), s )
276- for i , s in enumerate (obj .c0 .symbolic_shape )],
277- DummyExpr (ffp2 , obj .size ),
278- DummyExpr (ffp3 , ffp2 * arity_param )]
280+ # Initialize the Bundle metadata
281+ dim , = shape_param .dimensions
282+ init = [DummyExpr (ffp2 , 1 )]
283+ init .append (Iteration (
284+ List (body = (
285+ DummyExpr (IndexedPointer (ffp1 , dim ), shape_param [dim ]),
286+ DummyExpr (ffp2 , ffp2 * shape_param [dim ])
287+ )),
288+ dim ,
289+ ndims_param - 1 ,
290+ ))
291+ init .append (DummyExpr (ffp3 , ffp2 * arity_param ))
279292
280293 # Free all of the allocated data
281294 frees = [self .langbb ['host-free' ](ffp1 ),
@@ -290,6 +303,8 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
290303 efunc0 = make_callable (name , body , retval = obj )
291304 args = list (efunc0 .parameters )
292305 args [args .index (arity_param )] = arity_arg
306+ args [args .index (ndims_param )] = ndims_arg
307+ args [args .index (shape_param )] = shape_arg
293308 alloc = Call (name , args , retobj = obj )
294309
295310 # Same story for the frees
@@ -298,6 +313,7 @@ def _alloc_bundle_struct_on_high_bw_mem(self, site, obj, storage):
298313 free = Call (name , efunc1 .parameters )
299314
300315 storage .update (obj , site , allocs = alloc , frees = free , efuncs = (efunc0 , efunc1 ))
316+ storage .include (self .langbb ['header-array' ])
301317
302318 def _alloc_object_array_on_low_lat_mem (self , site , obj , storage ):
303319 """
0 commit comments