Skip to content

Commit c5891d2

Browse files
committed
compiler: Move and enhance FunctionMap
1 parent 59ff1bf commit c5891d2

3 files changed

Lines changed: 30 additions & 23 deletions

File tree

devito/types/misc.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,30 @@ def closing(self):
345345
"""
346346

347347

348+
class FunctionMap(LocalObject):
349+
350+
"""
351+
Wrap a Function in a LocalObject.
352+
"""
353+
354+
__rargs__ = ('name', 'tensor')
355+
356+
def __init__(self, name, tensor, **kwargs):
357+
super().__init__(name, **kwargs)
358+
self.tensor = tensor
359+
360+
def _hashable_content(self):
361+
return super()._hashable_content() + (self.tensor,)
362+
363+
@property
364+
def free_symbols(self):
365+
"""
366+
The free symbols of a FunctionMap are the free symbols of the
367+
underlying Function.
368+
"""
369+
return super().free_symbols | {self.tensor}
370+
371+
348372
# *** C/CXX support types
349373

350374
size_t = CustomDtype('size_t')

devito/types/parallel.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@
1919
from devito.types.basic import Scalar, Symbol
2020
from devito.types.dimension import CustomDimension
2121
from devito.types.misc import Fence, VolatileInt
22-
from devito.types.object import LocalObject
2322

2423
__all__ = ['NThreads', 'NThreadsNested', 'NThreadsNonaffine', 'NThreadsBase',
2524
'DeviceID', 'ThreadID', 'Lock', 'ThreadArray', 'PThreadArray',
26-
'SharedData', 'NPThreads', 'DeviceRM', 'QueueID', 'Barrier', 'TBArray',
27-
'FunctionMap']
25+
'SharedData', 'NPThreads', 'DeviceRM', 'QueueID', 'Barrier', 'TBArray']
2826

2927

3028
class NThreadsAbstract(Scalar):
@@ -333,19 +331,3 @@ def __init_finalize__(self, *args, **kwargs):
333331
kwargs['liveness'] = 'eager'
334332

335333
super().__init_finalize__(*args, **kwargs)
336-
337-
338-
class FunctionMap(LocalObject):
339-
340-
"""
341-
Wrap a Function in a LocalObject.
342-
"""
343-
344-
__rargs__ = ('name', 'tensor')
345-
346-
def __init__(self, name, tensor, **kwargs):
347-
super().__init__(name, **kwargs)
348-
self.tensor = tensor
349-
350-
def _hashable_content(self):
351-
return super()._hashable_content() + (self.tensor,)

tests/test_iet.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class,
1818
FLOAT, ListInitializer, SizeOf)
1919
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
20-
from devito.types import Array, LocalObject, Symbol, FunctionMap
20+
from devito.types import Array, LocalObject, Symbol
21+
from devito.types.misc import FunctionMap
2122

2223

2324
@pytest.fixture
@@ -316,7 +317,7 @@ def _C_init(self):
316317
Byref(self),
317318
Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'),
318319
4, self.tensor.dmap, sizes, strides,
319-
]
320+
]
320321
call = Call('cuTensorMapEncodeTiled', arguments)
321322

322323
return call
@@ -336,10 +337,10 @@ def _C_init(self):
336337
static void foo()
337338
{
338339
CUtensorMap tmap;
339-
cuTensorMapEncodeTiled(&(tmap),CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]});
340+
cuTensorMapEncodeTiled(&tmap,CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]});
340341
341342
foo(tmap);
342-
}"""
343+
}""" # noqa
343344

344345

345346
def test_cpp_local_object():

0 commit comments

Comments
 (0)