Skip to content

Commit abf5a2c

Browse files
committed
compiler: Move and enhance FunctionMap
1 parent e407393 commit abf5a2c

3 files changed

Lines changed: 30 additions & 22 deletions

File tree

devito/types/misc.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,30 @@ def closing(self):
358358
"""
359359

360360

361+
class FunctionMap(LocalObject):
362+
363+
"""
364+
Wrap a Function in a LocalObject.
365+
"""
366+
367+
__rargs__ = ('name', 'tensor')
368+
369+
def __init__(self, name, tensor, **kwargs):
370+
super().__init__(name, **kwargs)
371+
self.tensor = tensor
372+
373+
def _hashable_content(self):
374+
return super()._hashable_content() + (self.tensor,)
375+
376+
@property
377+
def free_symbols(self):
378+
"""
379+
The free symbols of a FunctionMap are the free symbols of the
380+
underlying Function.
381+
"""
382+
return super().free_symbols | {self.tensor}
383+
384+
361385
# *** C/CXX support types
362386

363387
size_t = CustomDtype('size_t')

devito/types/parallel.py

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919
from devito.types.basic import Scalar, Symbol
2020
from devito.types.dimension import CustomDimension
2121
from devito.types.misc import Fence, VolatileInt
22-
from devito.types.object import LocalObject
2322

2423
__all__ = ['NThreads', 'NThreadsNested', 'NThreadsNonaffine', 'NThreadsBase',
2524
'DeviceID', 'ThreadID', 'Lock', 'ThreadArray', 'PThreadArray',
2625
'SharedData', 'NPThreads', 'DeviceRM', 'QueueID', 'Barrier', 'TBArray',
27-
'ThreadPoolSync', 'ThreadCommit', 'ThreadWait', 'FunctionMap']
26+
'ThreadPoolSync', 'ThreadCommit', 'ThreadWait']
2827

2928

3029
class NThreadsAbstract(Scalar):
@@ -365,19 +364,3 @@ def __init_finalize__(self, *args, **kwargs):
365364
kwargs['liveness'] = 'eager'
366365

367366
super().__init_finalize__(*args, **kwargs)
368-
369-
370-
class FunctionMap(LocalObject):
371-
372-
"""
373-
Wrap a Function in a LocalObject.
374-
"""
375-
376-
__rargs__ = ('name', 'tensor')
377-
378-
def __init__(self, name, tensor, **kwargs):
379-
super().__init__(name, **kwargs)
380-
self.tensor = tensor
381-
382-
def _hashable_content(self):
383-
return super()._hashable_content() + (self.tensor,)

tests/test_iet.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
String, ListInitializer, SizeOf, FLOAT)
2121
from devito.tools import CustomDtype, as_tuple, dtype_to_ctype
2222
from devito.types import (
23-
CustomDimension, Array, LocalObject, Symbol, Pointer, FunctionMap
23+
CustomDimension, Array, LocalObject, Symbol, Pointer
2424
)
25+
from devito.types.misc import FunctionMap
2526

2627

2728
@pytest.fixture
@@ -320,7 +321,7 @@ def _C_init(self):
320321
Byref(self),
321322
Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'),
322323
4, self.tensor.dmap, sizes, strides,
323-
]
324+
]
324325
call = Call('cuTensorMapEncodeTiled', arguments)
325326

326327
return call
@@ -340,10 +341,10 @@ def _C_init(self):
340341
static void foo()
341342
{
342343
CUtensorMap tmap;
343-
cuTensorMapEncodeTiled(&(tmap),CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]});
344+
cuTensorMapEncodeTiled(&tmap,CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]});
344345
345346
foo(tmap);
346-
}"""
347+
}""" # noqa
347348

348349

349350
def test_cpp_local_object():

0 commit comments

Comments
 (0)