@@ -563,7 +563,7 @@ class DeviceAwareDataManager(DataManager):
def __init__(self, options=None, **kwargs):
    """
    Store the GPU-related entries of ``options`` on the instance, then
    delegate the remaining keyword arguments to the parent initializer.

    Parameters
    ----------
    options : mapping
        Must provide the keys ``'gpu-fit'`` and ``'gpu-create'``; the key
        ``'place-transfers'`` is optional and is read via ``.get``, so the
        attribute is None when the key is absent.
    **kwargs
        Forwarded untouched to ``super().__init__``.
    """
    # NOTE(review): the signature defaults `options` to None, yet it is
    # subscripted straight away, so omitting it raises a TypeError --
    # confirm that every caller supplies it.
    self.gpu_fit = options['gpu-fit']
    self.gpu_create = options['gpu-create']

    # Consulted by `place_transfers`, which becomes a no-op when this is falsy
    self.gpu_place_transfers = options.get('place-transfers')

    # The attributes above are set *before* the parent initializer runs
    super().__init__(**kwargs)
569569
@@ -596,7 +596,8 @@ def _map_array_on_high_bw_mem(self, site, obj, storage):
596596
597597 storage .update (obj , site , maps = mmap , unmaps = unmap )
598598
599- def _map_function_on_high_bw_mem (self , site , obj , storage , devicerm , read_only = False ):
599+ def _map_function_on_high_bw_mem (self , site , obj , storage , devicerm ,
600+ read_only = False , ** kwargs ):
600601 """
601602 Map a Function already defined in the host memory in to the device high
602603 bandwidth memory.
@@ -629,42 +630,41 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F
629630 storage .update (obj , site , maps = mmap , unmaps = unmap , efuncs = efuncs )
630631
@iet_pass
def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs):
    """
    Create a new IET with host-device data transfers. This requires mapping
    symbols to the suitable memory spaces.

    Parameters
    ----------
    iet
        The IET to process. Anything that is not an EntryFunction is
        returned untouched.
    data_movs
        A pair ``(reads, writes)``. It is unpacked only once both guards
        below have passed; ``reads - writes`` must be supported.
    ctx
        Forwarded as-is (keyword ``ctx``) to `_map_function_on_high_bw_mem`.
    **kwargs
        Accepted but not used by this method.

    Returns
    -------
    tuple
        ``(iet, {})`` if the pass is disabled or `iet` is not an
        EntryFunction; otherwise ``(new_iet, {'efuncs': efuncs})`` as
        produced by `_inject_definitions`.
    """
    # Nothing to do if the pass is switched off or this isn't the entry point
    if not self.gpu_place_transfers or not isinstance(iet, EntryFunction):
        return iet, {}

    reads, writes = data_movs

    # Special symbol which gives user code control over data deallocations
    devicerm = DeviceRM()

    storage = Storage()

    def _map_all(symbols, **mode):
        # Arrays and Functions are mapped through different routines; only
        # the latter receives `devicerm`, `ctx` and the optional `read_only`
        for sym in filter_sorted(symbols):
            if sym.is_Array:
                self._map_array_on_high_bw_mem(iet, sym, storage)
            else:
                self._map_function_on_high_bw_mem(
                    iet, sym, storage, devicerm, **mode, ctx=ctx
                )

    # Written symbols first, then the symbols that are only ever read
    _map_all(writes)
    _map_all(reads - writes, read_only=True)

    new_iet, efuncs = self._inject_definitions(iet, storage)

    return new_iet, {'efuncs': efuncs}
668668
669669 @iet_pass
670670 def place_devptr (self , iet , ** kwargs ):
0 commit comments