Skip to content

Commit 37f2c14

Browse files
committed
compiler: Add FunctionMap type
1 parent 1c65e48 commit 37f2c14

2 files changed

Lines changed: 69 additions & 2 deletions

File tree

devito/types/parallel.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
from devito.types.basic import Scalar, Symbol
2121
from devito.types.dimension import CustomDimension
2222
from devito.types.misc import Fence, VolatileInt
23+
from devito.types.object import LocalObject
2324

2425
__all__ = [
2526
'Barrier',
2627
'DeviceID',
2728
'DeviceRM',
29+
'FunctionMap',
2830
'Lock',
2931
'NPThreads',
3032
'NThreads',
@@ -384,3 +386,19 @@ def __init_finalize__(self, *args, **kwargs):
384386
kwargs['liveness'] = 'eager'
385387

386388
super().__init_finalize__(*args, **kwargs)
389+
390+
391+
class FunctionMap(LocalObject):
392+
393+
"""
394+
Wrap a Function in a LocalObject.
395+
"""
396+
397+
__rargs__ = ('name', 'tensor')
398+
399+
def __init__(self, name, tensor, **kwargs):
400+
super().__init__(name, **kwargs)
401+
self.tensor = tensor
402+
403+
def _hashable_content(self):
404+
return super()._hashable_content() + (self.tensor,)

tests/test_iet.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,13 @@
1717
from devito.passes.iet.engine import Graph
1818
from devito.passes.iet.languages.C import CDataManager
1919
from devito.symbolics import (
20-
FLOAT, Byref, Class, FieldFromComposite, InlineIf, Macro, String
20+
FLOAT, Byref, Class, FieldFromComposite, InlineIf, ListInitializer,
21+
Macro, SizeOf, String
2122
)
2223
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
23-
from devito.types import Array, CustomDimension, LocalObject, Pointer, Symbol
24+
from devito.types import (
25+
Array, CustomDimension, FunctionMap, LocalObject, Pointer, Symbol
26+
)
2427

2528

2629
@pytest.fixture
@@ -299,6 +302,52 @@ def _C_free(self):
299302
}"""
300303

301304

305+
def test_make_cuda_tensor_map():
306+
307+
class CUTensorMap(FunctionMap):
308+
309+
dtype = CustomDtype('CUtensorMap')
310+
311+
@property
312+
def _C_init(self):
313+
symsizes = list(reversed(self.tensor.symbolic_shape))
314+
sizeof_dtype = SizeOf(self.tensor.dmap._C_typedata)
315+
316+
sizes = ListInitializer(symsizes)
317+
strides = ListInitializer([
318+
np.prod(symsizes[:i])*sizeof_dtype for i in range(1, len(symsizes))
319+
])
320+
321+
arguments = [
322+
Byref(self),
323+
Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'),
324+
4, self.tensor.dmap, sizes, strides,
325+
]
326+
call = Call('cuTensorMapEncodeTiled', arguments)
327+
328+
return call
329+
330+
grid = Grid(shape=(10, 10, 10))
331+
332+
u = TimeFunction(name='u', grid=grid)
333+
334+
tmap = CUTensorMap('tmap', u)
335+
336+
iet = Call('foo', tmap)
337+
iet = ElementalFunction('foo', iet, parameters=())
338+
dm = CDataManager(sregistry=None)
339+
iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0]
340+
341+
assert str(iet) == """\
342+
static void foo()
343+
{
344+
CUtensorMap tmap;
345+
cuTensorMapEncodeTiled(&(tmap),CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]});
346+
347+
foo(tmap);
348+
}"""
349+
350+
302351
def test_cpp_local_object():
303352
"""
304353
Test C++ support for LocalObjects.

0 commit comments

Comments
 (0)