Skip to content

Commit d8ce9c8

Browse files
committed
compiler: Add FunctionMap type
1 parent 74c6f8e commit d8ce9c8

2 files changed

Lines changed: 68 additions & 3 deletions

File tree

devito/types/parallel.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,12 @@
1919
from devito.types.basic import Scalar, Symbol
2020
from devito.types.dimension import CustomDimension
2121
from devito.types.misc import Fence, VolatileInt
22+
from devito.types.object import LocalObject
2223

2324
__all__ = ['NThreads', 'NThreadsNested', 'NThreadsNonaffine', 'NThreadsBase',
2425
'DeviceID', 'ThreadID', 'Lock', 'ThreadArray', 'PThreadArray',
2526
'SharedData', 'NPThreads', 'DeviceRM', 'QueueID', 'Barrier', 'TBArray',
26-
'ThreadPoolSync', 'ThreadCommit', 'ThreadWait']
27+
'ThreadPoolSync', 'ThreadCommit', 'ThreadWait', 'FunctionMap']
2728

2829

2930
class NThreadsAbstract(Scalar):
@@ -364,3 +365,19 @@ def __init_finalize__(self, *args, **kwargs):
364365
kwargs['liveness'] = 'eager'
365366

366367
super().__init_finalize__(*args, **kwargs)
368+
369+
370+
class FunctionMap(LocalObject):
371+
372+
"""
373+
Wrap a Function in a LocalObject.
374+
"""
375+
376+
__rargs__ = ('name', 'tensor')
377+
378+
def __init__(self, name, tensor, **kwargs):
379+
super().__init__(name, **kwargs)
380+
self.tensor = tensor
381+
382+
def _hashable_content(self):
383+
return super()._hashable_content() + (self.tensor,)

tests/test_iet.py

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
from devito.passes.iet.engine import Graph
1818
from devito.passes.iet.languages.C import CDataManager
1919
from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class,
20-
String, FLOAT)
20+
String, ListInitializer, SizeOf, FLOAT)
2121
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
22-
from devito.types import CustomDimension, Array, LocalObject, Symbol, Pointer
22+
from devito.types import (
23+
CustomDimension, Array, LocalObject, Symbol, Pointer, FunctionMap
24+
)
2325

2426

2527
@pytest.fixture
@@ -298,6 +300,52 @@ def _C_free(self):
298300
}"""
299301

300302

303+
def test_make_cuda_tensor_map():
304+
305+
class CUTensorMap(FunctionMap):
306+
307+
dtype = CustomDtype('CUtensorMap')
308+
309+
@property
310+
def _C_init(self):
311+
symsizes = list(reversed(self.tensor.symbolic_shape))
312+
sizeof_dtype = SizeOf(self.tensor.dmap._C_typedata)
313+
314+
sizes = ListInitializer(symsizes)
315+
strides = ListInitializer([
316+
np.prod(symsizes[:i])*sizeof_dtype for i in range(1, len(symsizes))
317+
])
318+
319+
arguments = [
320+
Byref(self),
321+
Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'),
322+
4, self.tensor.dmap, sizes, strides,
323+
]
324+
call = Call('cuTensorMapEncodeTiled', arguments)
325+
326+
return call
327+
328+
grid = Grid(shape=(10, 10, 10))
329+
330+
u = TimeFunction(name='u', grid=grid)
331+
332+
tmap = CUTensorMap('tmap', u)
333+
334+
iet = Call('foo', tmap)
335+
iet = ElementalFunction('foo', iet, parameters=())
336+
dm = CDataManager(sregistry=None)
337+
iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0]
338+
339+
assert str(iet) == """\
340+
static void foo()
341+
{
342+
CUtensorMap tmap;
343+
cuTensorMapEncodeTiled(&(tmap),CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]});
344+
345+
foo(tmap);
346+
}"""
347+
348+
301349
def test_cpp_local_object():
302350
"""
303351
Test C++ support for LocalObjects.

0 commit comments

Comments
 (0)