Skip to content

Commit b21e8eb

Browse files
committed
compiler: Move and enhance FunctionMap
1 parent 22d82c3 commit b21e8eb

3 files changed

Lines changed: 29 additions & 22 deletions

File tree

devito/types/misc.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,30 @@ def closing(self):
381381
"""
382382

383383

384+
class FunctionMap(LocalObject):
385+
386+
"""
387+
Wrap a Function in a LocalObject.
388+
"""
389+
390+
__rargs__ = ('name', 'tensor')
391+
392+
def __init__(self, name, tensor, **kwargs):
393+
super().__init__(name, **kwargs)
394+
self.tensor = tensor
395+
396+
def _hashable_content(self):
397+
return super()._hashable_content() + (self.tensor,)
398+
399+
@property
400+
def free_symbols(self):
401+
"""
402+
The free symbols of a FunctionMap are the free symbols of the
403+
underlying Function.
404+
"""
405+
return super().free_symbols | {self.tensor}
406+
407+
384408
# *** C/CXX support types
385409

386410
size_t = CustomDtype('size_t')

devito/types/parallel.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,11 @@
2020
from devito.types.basic import Scalar, Symbol
2121
from devito.types.dimension import CustomDimension
2222
from devito.types.misc import Fence, VolatileInt
23-
from devito.types.object import LocalObject
2423

2524
__all__ = [
2625
'Barrier',
2726
'DeviceID',
2827
'DeviceRM',
29-
'FunctionMap',
3028
'Lock',
3129
'NPThreads',
3230
'NThreads',
@@ -386,19 +384,3 @@ def __init_finalize__(self, *args, **kwargs):
386384
kwargs['liveness'] = 'eager'
387385

388386
super().__init_finalize__(*args, **kwargs)
389-
390-
391-
class FunctionMap(LocalObject):
392-
393-
"""
394-
Wrap a Function in a LocalObject.
395-
"""
396-
397-
__rargs__ = ('name', 'tensor')
398-
399-
def __init__(self, name, tensor, **kwargs):
400-
super().__init__(name, **kwargs)
401-
self.tensor = tensor
402-
403-
def _hashable_content(self):
404-
return super()._hashable_content() + (self.tensor,)

tests/test_iet.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
)
2323
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
2424
from devito.types import (
25-
Array, CustomDimension, FunctionMap, LocalObject, Pointer, Symbol
25+
Array, CustomDimension, LocalObject, Pointer, Symbol
2626
)
27+
from devito.types.misc import FunctionMap
2728

2829

2930
@pytest.fixture
@@ -322,7 +323,7 @@ def _C_init(self):
322323
Byref(self),
323324
Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'),
324325
4, self.tensor.dmap, sizes, strides,
325-
]
326+
]
326327
call = Call('cuTensorMapEncodeTiled', arguments)
327328

328329
return call
@@ -342,10 +343,10 @@ def _C_init(self):
342343
static void foo()
343344
{
344345
CUtensorMap tmap;
345-
cuTensorMapEncodeTiled(&(tmap),CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]});
346+
cuTensorMapEncodeTiled(&tmap,CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]});
346347
347348
foo(tmap);
348-
}"""
349+
}""" # noqa
349350

350351

351352
def test_cpp_local_object():

0 commit comments

Comments
 (0)