|
17 | 17 | from devito.passes.iet.engine import Graph |
18 | 18 | from devito.passes.iet.languages.C import CDataManager |
19 | 19 | from devito.symbolics import (Byref, FieldFromComposite, InlineIf, Macro, Class, |
20 | | - String, FLOAT) |
| 20 | + String, ListInitializer, SizeOf, FLOAT) |
21 | 21 | from devito.tools import CustomDtype, as_tuple, dtype_to_ctype |
22 | | -from devito.types import CustomDimension, Array, LocalObject, Symbol, Pointer |
| 22 | +from devito.types import ( |
| 23 | + CustomDimension, Array, LocalObject, Symbol, Pointer, FunctionMap |
| 24 | +) |
23 | 25 |
|
24 | 26 |
|
25 | 27 | @pytest.fixture |
@@ -298,6 +300,52 @@ def _C_free(self): |
298 | 300 | }""" |
299 | 301 |
|
300 | 302 |
|
| 303 | +def test_make_cuda_tensor_map(): |
| 304 | + |
| 305 | + class CUTensorMap(FunctionMap): |
| 306 | + |
| 307 | + dtype = CustomDtype('CUtensorMap') |
| 308 | + |
| 309 | + @property |
| 310 | + def _C_init(self): |
| 311 | + symsizes = list(reversed(self.tensor.symbolic_shape)) |
| 312 | + sizeof_dtype = SizeOf(self.tensor.dmap._C_typedata) |
| 313 | + |
| 314 | + sizes = ListInitializer(symsizes) |
| 315 | + strides = ListInitializer([ |
| 316 | + np.prod(symsizes[:i])*sizeof_dtype for i in range(1, len(symsizes)) |
| 317 | + ]) |
| 318 | + |
| 319 | + arguments = [ |
| 320 | + Byref(self), |
| 321 | + Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'), |
| 322 | + 4, self.tensor.dmap, sizes, strides, |
| 323 | + ] |
| 324 | + call = Call('cuTensorMapEncodeTiled', arguments) |
| 325 | + |
| 326 | + return call |
| 327 | + |
| 328 | + grid = Grid(shape=(10, 10, 10)) |
| 329 | + |
| 330 | + u = TimeFunction(name='u', grid=grid) |
| 331 | + |
| 332 | + tmap = CUTensorMap('tmap', u) |
| 333 | + |
| 334 | + iet = Call('foo', tmap) |
| 335 | + iet = ElementalFunction('foo', iet, parameters=()) |
| 336 | + dm = CDataManager(sregistry=None) |
| 337 | + iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0] |
| 338 | + |
| 339 | + assert str(iet) == """\ |
| 340 | +static void foo() |
| 341 | +{ |
| 342 | + CUtensorMap tmap; |
| 343 | + cuTensorMapEncodeTiled(&(tmap),CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]}); |
| 344 | +
|
| 345 | + foo(tmap); |
| 346 | +}""" |
| 347 | + |
| 348 | + |
301 | 349 | def test_cpp_local_object(): |
302 | 350 | """ |
303 | 351 | Test C++ support for LocalObjects. |
|
0 commit comments