99
1010import numpy as np
1111
12- from devito .ir import (Block , Call , Definition , DummyExpr , Return , EntryFunction ,
13- FindNodes , FindSymbols , MapExprStmts , Transformer ,
14- make_callable )
12+ from devito .ir import (
13+ Block , Call , Definition , DummyExpr , Iteration , List , Return , EntryFunction ,
14+ FindNodes , FindSymbols , MapExprStmts , Transformer , make_callable
15+ )
1516from devito .passes import is_gpu_create
1617from devito .passes .iet .engine import iet_pass
1718from devito .passes .iet .langbase import LangBB
18- from devito .symbolics import (Byref , DefFunction , FieldFromPointer , IndexedPointer ,
19- SizeOf , VOID , pow_to_mul , unevaluate )
19+ from devito .symbolics import (
20+ Byref , DefFunction , FieldFromPointer , IndexedPointer , ListInitializer ,
21+ SizeOf , VOID , pow_to_mul , unevaluate
22+ )
2023from devito .tools import as_mapper , as_list , as_tuple , filter_sorted , flatten
21- from devito .types import (Array , ComponentAccess , CustomDimension , DeviceMap ,
22- DeviceRM , Eq , Symbol , size_t )
24+ from devito .types import (
25+ Array , ComponentAccess , CustomDimension , Dimension , DeviceMap , DeviceRM ,
26+ Eq , Symbol , size_t
27+ )
2328
2429__all__ = ['DataManager' , 'DeviceAwareDataManager' , 'Storage' ]
2530
@@ -40,6 +45,7 @@ def __init__(self, *args, **kwargs):
4045 super ().__init__ (* args , ** kwargs )
4146
4247 self .defined = set ()
48+ self .includes = set ()
4349
4450 def update (self , key , site , ** kwargs ):
4551 if key in self .defined :
@@ -67,6 +73,10 @@ def map(self, key, site, k, v):
6773
6874 self .defined .add ((site [- 1 ], key ))
6975
def include(self, v):
    """
    Record `v` as something that must be included by the generated code.

    Falsy values (e.g. `None` or an empty string) are ignored, so callers
    can forward a lookup result without checking it first.

    Parameters
    ----------
    v : hashable
        The entry to register. It is stored in the `self.includes` set, so
        registering the same value more than once has no further effect.
    """
    # NOTE(review): callers pass `self.langbb[...]` header entries here;
    # presumably a backend may leave those unset, hence the guard -- confirm.
    if v:
        self.includes.add(v)
79+
7080
7181class DataManager :
7282
@@ -132,8 +142,7 @@ def _alloc_array_on_global_mem(self, site, obj, storage):
132142 alloc = Call (name , efunc .parameters )
133143
134144 storage .update (obj , site , allocs = alloc , efuncs = efunc )
135-
136- return self .langbb ['header-memcpy' ]
145+ storage .include (self .langbb ['header-memcpy' ])
137146
138147 def _alloc_scalar_on_low_lat_mem (self , site , expr , storage ):
139148 """
@@ -170,6 +179,12 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
170179
171180 arity_param = Symbol (name = 'arity' , dtype = size_t )
172181 arity_arg = SizeOf (obj .indexed ._C_typedata )
182+ ndims_param = Symbol (name = 'ndims' , dtype = size_t )
183+ ndims_arg = obj .ndim
184+ shape_param = Array (name = f'{ obj .name } _shape' , dtype = np .uint64 ,
185+ dimensions = (Dimension (name = 'd' ),), scope = 'rvalue' )
186+ shape_arg = ListInitializer (obj .c0 .symbolic_shape , dtype = shape_param .dtype )
187+
173188 ffp0 = FieldFromPointer (obj ._C_field_data , obj ._C_symbol )
174189 ffp1 = FieldFromPointer (obj ._C_field_shape , obj ._C_symbol )
175190 ffp2 = FieldFromPointer (obj ._C_field_size , obj ._C_symbol )
@@ -183,14 +198,21 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
183198
184199 # Allocate the shape array
185200 memptr = VOID (Byref (ffp1 ), '**' )
186- nbytes = SizeOf (obj ._C_size_type )* obj . ndim
201+ nbytes = SizeOf (obj ._C_size_type )* ndims_param
187202 alloc1 = self .langbb ['host-alloc' ](memptr , alignment , nbytes )
188203
189- # Initialize the Array struct
190- init = [* [DummyExpr (IndexedPointer (ffp1 , i ), s )
191- for i , s in enumerate (obj .c0 .symbolic_shape )],
192- DummyExpr (ffp2 , obj .size ),
193- DummyExpr (ffp3 , ffp2 * arity_param )]
204+ # Initialize the Array metadata
205+ dim , = shape_param .dimensions
206+ init = [DummyExpr (ffp2 , 1 )]
207+ init .append (Iteration (
208+ List (body = (
209+ DummyExpr (IndexedPointer (ffp1 , dim ), shape_param [dim ]),
210+ DummyExpr (ffp2 , ffp2 * shape_param [dim ])
211+ )),
212+ dim ,
213+ ndims_param - 1 ,
214+ ))
215+ init .append (DummyExpr (ffp3 , ffp2 * arity_param ))
194216
195217 # Allocate the underlying host data
196218 memptr = VOID (Byref (ffp0 ), '**' )
@@ -213,6 +235,8 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
213235 efunc0 = make_callable (name , body , retval = obj )
214236 args = list (efunc0 .parameters )
215237 args [args .index (arity_param )] = arity_arg
238+ args [args .index (ndims_param )] = ndims_arg
239+ args [args .index (shape_param )] = shape_arg
216240 alloc = Call (name , args , retobj = obj )
217241
218242 # Same story for the frees
@@ -222,6 +246,7 @@ def _alloc_mapped_array_on_high_bw_mem(self, site, obj, storage, *args):
222246 free = Call (name , efunc1 .parameters )
223247
224248 storage .update (obj , site , allocs = alloc , frees = free , efuncs = (efunc0 , efunc1 ))
249+ storage .include (self .langbb ['header-array' ])
225250
226251 def _alloc_bundle_struct_on_high_bw_mem (self , site , obj , storage ):
227252 """
@@ -436,18 +461,15 @@ def place_definitions(self, iet, globs=None, **kwargs):
436461 self ._alloc_pointed_array_on_high_bw_mem (iet , i , storage )
437462
438463 # Handle postponed global objects
439- includes = set ()
440464 if isinstance (iet , EntryFunction ) and globs :
441465 for i in sorted (globs , key = lambda f : f .name ):
442- v = self ._alloc_array_on_global_mem (iet , i , storage )
443- if v :
444- includes .add (v )
466+ self ._alloc_array_on_global_mem (iet , i , storage )
445467
446468 iet , efuncs = self ._inject_definitions (iet , storage )
447469
448470 return iet , {'efuncs' : efuncs ,
449471 'globals' : as_tuple (globs ),
450- 'includes' : as_tuple (includes )}
472+ 'includes' : as_tuple (sorted ( storage . includes ) )}
451473
452474 @iet_pass
453475 def place_casts (self , iet , ** kwargs ):
0 commit comments