Skip to content

Commit 0eeff32

Browse files
committed
compiler: Pass ctx down to _map_function_on_high_bw_mem
1 parent 8107646 commit 0eeff32

1 file changed

Lines changed: 26 additions & 26 deletions

File tree

devito/passes/iet/definitions.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,7 @@ class DeviceAwareDataManager(DataManager):
563563
def __init__(self, options=None, **kwargs):
564564
self.gpu_fit = options['gpu-fit']
565565
self.gpu_create = options['gpu-create']
566-
self.pmode = options.get('place-transfers')
566+
self.gpu_place_transfers = options.get('place-transfers')
567567

568568
super().__init__(**kwargs)
569569

@@ -596,7 +596,8 @@ def _map_array_on_high_bw_mem(self, site, obj, storage):
596596

597597
storage.update(obj, site, maps=mmap, unmaps=unmap)
598598

599-
def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=False):
599+
def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm,
600+
read_only=False, **kwargs):
600601
"""
601602
Map a Function already defined in the host memory in to the device high
602603
bandwidth memory.
@@ -629,42 +630,41 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F
629630
storage.update(obj, site, maps=mmap, unmaps=unmap, efuncs=efuncs)
630631

631632
@iet_pass
632-
def place_transfers(self, iet, data_movs=None, **kwargs):
633+
def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs):
633634
"""
634635
Create a new IET with host-device data transfers. This requires mapping
635636
symbols to the suitable memory spaces.
636637
"""
637-
if not self.pmode:
638+
if not self.gpu_place_transfers:
638639
return iet, {}
639640

640-
@singledispatch
641-
def _place_transfers(iet, data_movs):
641+
if not isinstance(iet, EntryFunction):
642642
return iet, {}
643643

644-
@_place_transfers.register(EntryFunction)
645-
def _(iet, data_movs):
646-
reads, writes = data_movs
644+
reads, writes = data_movs
647645

648-
# Special symbol which gives user code control over data deallocations
649-
devicerm = DeviceRM()
646+
# Special symbol which gives user code control over data deallocations
647+
devicerm = DeviceRM()
650648

651-
storage = Storage()
652-
for i in filter_sorted(writes):
653-
if i.is_Array:
654-
self._map_array_on_high_bw_mem(iet, i, storage)
655-
else:
656-
self._map_function_on_high_bw_mem(iet, i, storage, devicerm)
657-
for i in filter_sorted(reads - writes):
658-
if i.is_Array:
659-
self._map_array_on_high_bw_mem(iet, i, storage)
660-
else:
661-
self._map_function_on_high_bw_mem(iet, i, storage, devicerm, True)
662-
663-
iet, efuncs = self._inject_definitions(iet, storage)
649+
storage = Storage()
650+
for i in filter_sorted(writes):
651+
if i.is_Array:
652+
self._map_array_on_high_bw_mem(iet, i, storage)
653+
else:
654+
self._map_function_on_high_bw_mem(
655+
iet, i, storage, devicerm, ctx=ctx
656+
)
657+
for i in filter_sorted(reads - writes):
658+
if i.is_Array:
659+
self._map_array_on_high_bw_mem(iet, i, storage)
660+
else:
661+
self._map_function_on_high_bw_mem(
662+
iet, i, storage, devicerm, read_only=True, ctx=ctx
663+
)
664664

665-
return iet, {'efuncs': efuncs}
665+
iet, efuncs = self._inject_definitions(iet, storage)
666666

667-
return _place_transfers(iet, data_movs=data_movs)
667+
return iet, {'efuncs': efuncs}
668668

669669
@iet_pass
670670
def place_devptr(self, iet, **kwargs):

0 commit comments

Comments
 (0)