|
17 | 17 | from devito.passes.iet.engine import Graph |
18 | 18 | from devito.passes.iet.languages.C import CDataManager |
19 | 19 | from devito.symbolics import ( |
20 | | - FLOAT, Byref, Class, FieldFromComposite, InlineIf, Macro, String |
| 20 | + FLOAT, Byref, Class, FieldFromComposite, InlineIf, ListInitializer, |
| 21 | + Macro, SizeOf, String |
21 | 22 | ) |
22 | 23 | from devito.tools import CustomDtype, as_tuple, dtype_to_ctype |
23 | | -from devito.types import Array, CustomDimension, LocalObject, Pointer, Symbol |
| 24 | +from devito.types import ( |
| 25 | + Array, CustomDimension, FunctionMap, LocalObject, Pointer, Symbol |
| 26 | +) |
24 | 27 |
|
25 | 28 |
|
26 | 29 | @pytest.fixture |
@@ -299,6 +302,52 @@ def _C_free(self): |
299 | 302 | }""" |
300 | 303 |
|
301 | 304 |
|
| 305 | +def test_make_cuda_tensor_map(): |
| 306 | + |
| 307 | + class CUTensorMap(FunctionMap): |
| 308 | + |
| 309 | + dtype = CustomDtype('CUtensorMap') |
| 310 | + |
| 311 | + @property |
| 312 | + def _C_init(self): |
| 313 | + symsizes = list(reversed(self.tensor.symbolic_shape)) |
| 314 | + sizeof_dtype = SizeOf(self.tensor.dmap._C_typedata) |
| 315 | + |
| 316 | + sizes = ListInitializer(symsizes) |
| 317 | + strides = ListInitializer([ |
| 318 | + np.prod(symsizes[:i])*sizeof_dtype for i in range(1, len(symsizes)) |
| 319 | + ]) |
| 320 | + |
| 321 | + arguments = [ |
| 322 | + Byref(self), |
| 323 | + Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'), |
| 324 | + 4, self.tensor.dmap, sizes, strides, |
| 325 | + ] |
| 326 | + call = Call('cuTensorMapEncodeTiled', arguments) |
| 327 | + |
| 328 | + return call |
| 329 | + |
| 330 | + grid = Grid(shape=(10, 10, 10)) |
| 331 | + |
| 332 | + u = TimeFunction(name='u', grid=grid) |
| 333 | + |
| 334 | + tmap = CUTensorMap('tmap', u) |
| 335 | + |
| 336 | + iet = Call('foo', tmap) |
| 337 | + iet = ElementalFunction('foo', iet, parameters=()) |
| 338 | + dm = CDataManager(sregistry=None) |
| 339 | + iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0] |
| 340 | + |
| 341 | + assert str(iet) == """\ |
| 342 | +static void foo() |
| 343 | +{ |
| 344 | + CUtensorMap tmap; |
| 345 | + cuTensorMapEncodeTiled(&(tmap),CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]}); |
| 346 | +
|
| 347 | + foo(tmap); |
| 348 | +}""" |
| 349 | + |
| 350 | + |
302 | 351 | def test_cpp_local_object(): |
303 | 352 | """ |
304 | 353 | Test C++ support for LocalObjects. |
|
0 commit comments