From 72a211111388a9dd8477fa9e887b9fc46a13a62e Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 13 Jun 2026 11:30:28 -0400 Subject: [PATCH] [TIRx] Post-bringup follow-ups: op-dispatch, namespaces, launch bounds, gemm-async, backend reorg Batch of post-bringup TIRx follow-ups, rebased onto current main: - Per-call exec scope via Tx..op; remove ExecScopeStmt - Split TIRx op namespaces; remove tile-primitive kind attrs - Support explicit CUDA launch bounds - gemm-async: support contiguous-axis (K-major) operand slicing - Move in-tree GPU backends out of core into src/backend// and python/tvm/backend// --- CMakeLists.txt | 22 +- cmake/modules/CUDA.cmake | 4 +- cmake/modules/Hexagon.cmake | 63 +- cmake/modules/LLVM.cmake | 6 +- cmake/modules/Metal.cmake | 4 +- cmake/modules/OpenCL.cmake | 4 +- cmake/modules/ROCM.cmake | 2 +- cmake/modules/Vulkan.cmake | 14 +- docs/reference/api/python/tirx/backend.rst | 6 - include/tvm/tirx/builtin.h | 270 - include/tvm/tirx/op.h | 2 - include/tvm/tirx/target_builtin/cuda.h | 745 --- include/tvm/tirx/target_builtin/trn.h | 156 - python/tvm/__init__.py | 8 +- python/tvm/backend/__init__.py | 209 + python/tvm/backend/adreno/__init__.py | 39 + .../adreno/target_tags.py} | 2 +- python/tvm/backend/cuda/__init__.py | 71 + python/tvm/backend/cuda/lang/__init__.py | 70 + python/tvm/backend/cuda/lang/alloc_pool.py | 529 ++ python/tvm/backend/cuda/lang/pipeline.py | 244 + python/tvm/backend/cuda/lang/smem_desc.py | 55 + .../tvm/backend/cuda/lang/tile_scheduler.py | 816 +++ python/tvm/backend/cuda/lang/warp_role.py | 144 + python/tvm/backend/cuda/op.py | 4283 +++++++++++++ .../cuda/operator}/__init__.py | 4 +- .../cuda/operator/intrinsics}/__init__.py | 0 .../cuda}/operator/intrinsics/_schema.py | 4 +- .../cuda/operator/intrinsics}/cp_async.py | 7 +- .../cuda/operator/intrinsics}/header.py | 0 .../cuda/operator/intrinsics}/math.py | 4 +- .../cuda/operator/intrinsics}/memory.py | 4 +- .../cuda/operator/intrinsics}/misc.py | 4 +- .../cuda/operator/intrinsics}/mma.py | 2 +- .../cuda/operator/intrinsics}/nvshmem.py | 2 +- .../cuda/operator/intrinsics}/registry.py | 0 .../cuda/operator/intrinsics}/sync.py | 10 +- .../cuda/operator/intrinsics}/tcgen05.py | 2 +- .../cuda/operator/intrinsics}/types.py | 0 .../cuda/operator/intrinsics}/utils.py | 0 .../cuda/operator/intrinsics}/wgmma.py | 2 +- .../cuda/operator/tile_primitive}/__init__.py | 4 + .../cuda/operator/tile_primitive}/common.py | 0 .../operator/tile_primitive}/copy/__init__.py | 0 .../operator/tile_primitive}/copy/_common.py | 0 .../tile_primitive}/copy/_swizzle_iter.py | 0 .../operator/tile_primitive}/copy/fallback.py | 0 .../tile_primitive}/copy/gmem_smem.py | 2 +- .../tile_primitive}/copy/ld_stmatrix.py | 2 +- .../cuda/operator/tile_primitive}/copy/reg.py | 2 +- .../operator/tile_primitive}/copy/utils.py | 0 .../tile_primitive}/copy_async/__init__.py | 0 .../tile_primitive}/copy_async/dsmem.py | 0 .../tile_primitive}/copy_async/ldgsts.py | 2 +- .../tile_primitive}/copy_async/tcgen05_cp.py | 2 +- .../copy_async/tcgen05_ldst.py | 0 .../tile_primitive}/copy_async/tma.py | 0 .../tile_primitive}/copy_async/utils.py | 0 .../tile_primitive}/elementwise/__init__.py | 2 +- .../tile_primitive}/elementwise/_common.py | 0 .../elementwise/ops/__init__.py | 0 .../tile_primitive}/elementwise/ops/binary.py | 0 .../tile_primitive}/elementwise/ops/cast.py | 0 .../tile_primitive}/elementwise/ops/fma.py | 0 .../tile_primitive}/elementwise/ops/unary.py | 0 .../tile_primitive}/elementwise/reg.py | 2 +- .../tile_primitive}/elementwise/register.py | 0 .../tile_primitive}/elementwise/smem.py | 2 +- .../elementwise/vec_emit/__init__.py | 0 .../elementwise/vec_emit/binary_f32x2.py | 0 .../elementwise/vec_emit/cast_vec2.py | 0 .../elementwise/vec_emit/fma_f32x2.py | 0 .../tile_primitive}/exec_scope_utils.py | 0 .../operator/tile_primitive}/gemm/__init__.py | 0 .../tile_primitive}/gemm/mma_m16n8k_.py | 0 .../tile_primitive}/gemm_async/__init__.py | 0 .../tile_primitive}/gemm_async/tcgen05.py | 211 +- .../operator/tile_primitive}/gemm_utils.py | 0 .../operator/tile_primitive}/layout_utils.py | 0 .../permute_layout/__init__.py | 0 .../permute_layout/warp_xor_swizzle.py | 0 .../tile_primitive}/reduction/__init__.py | 0 .../tile_primitive}/reduction/local.py | 4 +- .../tile_primitive}/reduction/shared.py | 2 +- .../tile_primitive}/reduction/sm100_packed.py | 2 +- .../tile_primitive}/reduction/utils.py | 9 +- .../operator/tile_primitive}/tma_utils.py | 0 python/tvm/backend/cuda/script.py | 571 ++ .../cuda.py => backend/cuda/target_tags.py} | 2 +- python/tvm/backend/hexagon/__init__.py | 35 + .../hexagon/target_tags.py} | 2 +- python/tvm/backend/metal/__init__.py | 58 + python/tvm/backend/metal/op.py | 84 + python/tvm/backend/metal/script.py | 55 + .../metal.py => backend/metal/target_tags.py} | 4 +- python/tvm/backend/opencl/__init__.py | 25 + python/tvm/backend/rocm/__init__.py | 25 + python/tvm/backend/trn/__init__.py | 68 + python/tvm/backend/trn/layout.py | 123 + python/tvm/backend/trn/op.py | 153 + python/tvm/backend/trn/operator/__init__.py | 22 + .../trn/operator/tile_primitive}/__init__.py | 0 .../tile_primitive}/binary/__init__.py | 0 .../tile_primitive}/binary/default.py | 4 +- .../operator/tile_primitive}/binary/utils.py | 7 +- .../trn/operator/tile_primitive}/common.py | 0 .../tile_primitive}/compose_op/__init__.py | 0 .../compose_op/binary_chain.py | 0 .../compose_op/binary_reduce.py | 0 .../tile_primitive}/compose_op/compose_op.py | 0 .../compose_op/reduce_negate.py | 0 .../compose_op/unary_reduce.py | 0 .../tile_primitive}/compose_op/utils.py | 3 +- .../operator/tile_primitive}/copy/__init__.py | 0 .../operator/tile_primitive}/copy/default.py | 11 +- .../trn/operator/tile_primitive}/dim_utils.py | 0 .../operator/tile_primitive}/gemm/__init__.py | 0 .../operator/tile_primitive}/gemm/default.py | 11 +- .../tile_primitive}/instruction_generator.py | 11 +- .../operator/tile_primitive}/private_alloc.py | 6 +- .../tile_primitive}/reduction/__init__.py | 0 .../tile_primitive}/reduction/default.py | 2 +- .../tile_primitive}/reduction/utils.py | 11 +- .../tile_primitive}/select/__init__.py | 0 .../tile_primitive}/select/default.py | 5 +- .../tile_primitive}/unary/__init__.py | 0 .../operator/tile_primitive}/unary/default.py | 4 +- .../operator/tile_primitive}/unary/utils.py | 9 +- .../tile_primitive}/unary/with_bias_scale.py | 4 +- .../tile_primitive}/workspace_utils.py | 0 python/tvm/backend/trn/pipeline.py | 58 + python/tvm/backend/trn/script.py | 58 + python/tvm/backend/trn/target_tags.py | 36 + .../trn => backend/trn/transform}/__init__.py | 27 +- .../trn/transform}/naive_allocator.py | 3 +- .../trn/transform}/private_buffer_alloc.py | 0 python/tvm/backend/vulkan/__init__.py | 25 + python/tvm/backend/webgpu/__init__.py | 25 + python/tvm/s_tir/tensor_intrin/cuda.py | 4 +- python/tvm/s_tir/tensor_intrin/metal.py | 8 +- python/tvm/target/tag_registry/__init__.py | 4 - python/tvm/tirx/__init__.py | 17 +- python/tvm/tirx/backend/__init__.py | 4 +- python/tvm/tirx/bench.py | 8 +- python/tvm/tirx/compilation_pipeline.py | 39 +- python/tvm/tirx/lang/alloc_pool.py | 527 +- python/tvm/tirx/lang/pipeline.py | 242 +- python/tvm/tirx/lang/smem_desc.py | 53 +- python/tvm/tirx/lang/tile_scheduler.py | 814 +-- python/tvm/tirx/lang/warp_role.py | 142 +- python/tvm/tirx/op.py | 5285 +---------------- .../tvm/tirx/operator/intrinsics/_common.py | 2 +- .../tirx/operator/tile_primitive/__init__.py | 11 +- .../tile_primitive/dispatch_context.py | 4 + python/tvm/tirx/script/builder/ir.py | 656 +- python/tvm/tirx/transform/__init__.py | 1 - python/tvm/tirx/transform/transform.py | 11 + src/arith/const_int_bound.cc | 7 +- src/arith/ir_mutator_with_analyzer.cc | 9 +- src/arith/ir_visitor_with_analyzer.cc | 5 +- src/arith/rewrite_simplify.cc | 11 +- .../cuda/codegen}/codegen_cuda.cc | 74 +- .../cuda/codegen}/codegen_cuda.h | 2 +- .../cuda/codegen}/cuda_fallback_module.cc | 2 +- .../cuda/codegen}/cuda_fallback_module.h | 4 +- .../cuda/codegen}/intrin_rule_cuda.cc | 49 +- .../cuda/codegen}/literal/cuda_half_t.h | 0 .../cuda/codegen}/literal/cuda_int8_t.h | 0 .../cuda/codegen}/llvm/codegen_nvptx.cc | 14 +- .../cuda => backend/cuda/codegen}/ptx.cc | 28 +- .../cuda => backend/cuda/codegen}/ptx.h | 0 src/backend/cuda/codegen/register.cc | 143 + src/backend/cuda/op/register.cc | 35 + .../cuda/op/target_builtin.cc} | 69 +- .../cuda/runtime}/cuda_common.h | 2 +- .../cuda/runtime}/cuda_device_api.cc | 0 .../cuda/runtime}/cuda_module.cc | 8 +- .../cuda/runtime}/l2_cache_flush.cc | 2 +- .../cuda/runtime/vm}/cuda_graph_builtin.cc | 4 +- .../codegen}/hexagon_fallback_module.cc | 4 +- .../codegen}/hexagon_fallback_module.h | 4 +- .../hexagon/codegen}/llvm/codegen_hexagon.cc | 16 +- .../codegen}/llvm/intrin_rule_hexagon.cc | 35 +- src/backend/hexagon/codegen/register.cc | 64 + .../hexagon/runtime}/README.md | 0 .../hexagon/runtime}/hexagon_buffer.cc | 0 .../hexagon/runtime}/hexagon_buffer.h | 0 .../hexagon/runtime}/hexagon_buffer_manager.h | 0 .../hexagon/runtime}/hexagon_common.cc | 0 .../hexagon/runtime}/hexagon_common.h | 0 .../hexagon/runtime}/hexagon_device_api.cc | 2 +- .../hexagon/runtime}/hexagon_device_api.h | 0 .../hexagon/runtime}/hexagon_htp.cc | 0 .../hexagon/runtime}/hexagon_htp.h | 0 .../hexagon/runtime}/hexagon_hvx.cc | 0 .../hexagon/runtime}/hexagon_hvx.h | 0 .../hexagon/runtime}/hexagon_module.cc | 10 +- .../hexagon/runtime}/hexagon_power_manager.cc | 0 .../hexagon/runtime}/hexagon_power_manager.h | 0 .../runtime}/hexagon_thread_manager.cc | 0 .../hexagon/runtime}/hexagon_thread_manager.h | 0 .../hexagon/runtime}/hexagon_user_dma.cc | 0 .../hexagon/runtime}/hexagon_user_dma.h | 0 .../runtime}/hexagon_user_dma_descriptors.h | 0 .../runtime}/hexagon_user_dma_instructions.h | 0 .../runtime}/hexagon_user_dma_registers.h | 0 .../hexagon/runtime}/hexagon_vtcm_pool.cc | 0 .../hexagon/runtime}/hexagon_vtcm_pool.h | 0 .../hexagon/runtime}/ops/conv2d.h | 0 .../hexagon/runtime}/ops/conv2d_fp16_hvx.cc | 0 .../hexagon/runtime}/ops/conv2d_quant_hvx.cc | 0 .../hexagon/runtime}/ops/conv_utils.cc | 0 .../hexagon/runtime}/profiler/README.md | 0 .../hexagon/runtime}/profiler/lwp_handler.S | 0 .../hexagon/runtime}/profiler/prof_utils.cc | 0 .../hexagon/runtime}/profiler/prof_utils.h | 0 .../hexagon/runtime}/qhl/qhl_wrapper.cc | 0 .../hexagon/runtime}/ring_buffer.h | 0 .../hexagon/runtime}/rpc/android/session.cc | 6 +- .../runtime}/rpc/android_bash.sh.template | 0 .../runtime}/rpc/hexagon/rpc_server.cc | 6 +- .../hexagon/runtime}/rpc/hexagon_rpc.idl | 0 .../rpc/simulator/hexagon_sim_proto.h | 0 .../runtime}/rpc/simulator/rpc_server.cc | 2 +- .../hexagon/runtime}/rpc/simulator/session.cc | 6 +- .../metal/codegen}/codegen_metal.cc | 24 +- .../metal/codegen}/codegen_metal.h | 2 +- .../metal/codegen}/intrin_rule_metal.cc | 22 +- .../metal/codegen}/metal_fallback_module.cc | 2 +- .../metal/codegen}/metal_fallback_module.h | 4 +- src/backend/metal/codegen/register.cc | 63 + src/backend/metal/op/register.cc | 35 + src/backend/metal/op/target_builtin.cc | 62 + .../metal/runtime}/metal_common.h | 2 +- .../metal/runtime}/metal_device_api.mm | 0 .../metal/runtime}/metal_module.mm | 10 +- .../opencl/codegen}/codegen_opencl.cc | 18 +- .../opencl/codegen}/codegen_opencl.h | 2 +- .../opencl/codegen}/intrin_rule_opencl.cc | 38 +- .../opencl/codegen}/opencl_fallback_module.cc | 4 +- .../opencl/codegen}/opencl_fallback_module.h | 4 +- src/backend/opencl/codegen/register.cc | 67 + .../opencl/runtime}/opencl_common.h | 8 +- .../opencl/runtime}/opencl_device_api.cc | 4 +- .../opencl/runtime}/opencl_module.cc | 13 +- .../opencl/runtime}/opencl_wrapper/README.md | 0 .../runtime}/opencl_wrapper/opencl_wrapper.cc | 0 .../opencl/runtime}/source_utils.h | 0 .../opencl/runtime}/texture.h | 0 .../rocm/codegen}/llvm/codegen_amdgpu.cc | 14 +- .../rocm/codegen}/llvm/intrin_rule_rocm.cc | 8 +- src/backend/rocm/codegen/register.cc | 148 + .../rocm/codegen}/rocm_fallback_module.cc | 2 +- .../rocm/codegen}/rocm_fallback_module.h | 4 +- .../rocm/runtime}/rocm_common.h | 2 +- .../rocm/runtime}/rocm_device_api.cc | 0 .../rocm/runtime}/rocm_module.cc | 8 +- .../trn/codegen}/codegen_trn.cc | 59 +- .../trn/codegen}/codegen_trn.h | 2 +- src/backend/trn/codegen/register.cc | 61 + src/backend/trn/op/register.cc | 35 + .../trn/op/target_builtin.cc} | 34 +- .../trn/transform/lower_trainium_layout.cc | 361 ++ .../vulkan/codegen}/build_vulkan.cc | 14 +- .../vulkan/codegen}/codegen_spirv.cc | 21 +- .../vulkan/codegen}/codegen_spirv.h | 4 +- .../vulkan/codegen}/intrin_rule_spirv.cc | 26 +- .../vulkan/codegen}/ir_builder.cc | 0 .../vulkan/codegen}/ir_builder.h | 0 src/backend/vulkan/codegen/register.cc | 93 + .../vulkan/codegen}/spirv_support.cc | 0 .../vulkan/codegen}/spirv_support.h | 0 .../vulkan/codegen}/spirv_utils.cc | 4 +- .../vulkan/codegen}/spirv_utils.h | 2 +- .../vulkan/codegen}/vulkan_fallback_module.cc | 8 +- .../vulkan/codegen}/vulkan_fallback_module.h | 4 +- .../vulkan/runtime}/README.md | 0 .../vulkan/runtime}/spirv_shader.h | 0 .../vulkan/runtime}/thread_map.h | 0 .../vulkan/runtime}/vulkan_amdrgp.cc | 0 .../vulkan/runtime}/vulkan_amdrgp.h | 0 .../vulkan/runtime}/vulkan_buffer.cc | 0 .../vulkan/runtime}/vulkan_buffer.h | 0 .../vulkan/runtime}/vulkan_common.cc | 0 .../vulkan/runtime}/vulkan_common.h | 0 .../vulkan/runtime}/vulkan_device.cc | 2 +- .../vulkan/runtime}/vulkan_device.h | 0 .../vulkan/runtime}/vulkan_device_api.cc | 0 .../vulkan/runtime}/vulkan_device_api.h | 2 +- .../vulkan/runtime}/vulkan_instance.cc | 2 +- .../vulkan/runtime}/vulkan_instance.h | 0 .../vulkan/runtime}/vulkan_module.cc | 8 +- .../vulkan/runtime}/vulkan_stream.cc | 2 +- .../vulkan/runtime}/vulkan_stream.h | 0 .../vulkan/runtime}/vulkan_wrapped_func.cc | 4 +- .../vulkan/runtime}/vulkan_wrapped_func.h | 6 +- .../webgpu/codegen}/codegen_webgpu.cc | 20 +- .../webgpu/codegen}/codegen_webgpu.h | 2 +- .../webgpu/codegen}/intrin_rule_webgpu.cc | 20 +- src/backend/webgpu/codegen/register.cc | 80 + .../webgpu/codegen}/webgpu_fallback_module.cc | 2 +- .../webgpu/codegen}/webgpu_fallback_module.h | 4 +- .../transform/static_plan_block_memory.cc | 2 +- src/runtime/extra/contrib/clml/clml_runtime.h | 2 +- .../contrib/cublas/cublas_json_runtime.cc | 2 +- .../extra/contrib/cublas/cublas_utils.cc | 2 +- .../contrib/cudnn/cudnn_frontend/attention.cc | 2 +- src/runtime/extra/contrib/cudnn/cudnn_utils.h | 2 +- src/runtime/extra/contrib/curand/curand.cc | 2 +- .../cutlass/fp16_group_gemm_runner_sm100.cuh | 2 +- .../cutlass/fp16_group_gemm_runner_sm90.cuh | 2 +- ...fp8_groupwise_scaled_gemm_runner_sm100.cuh | 2 +- .../fp8_groupwise_scaled_gemm_runner_sm90.cuh | 2 +- ...oupwise_scaled_group_gemm_runner_sm100.cuh | 2 +- .../extra/contrib/cutlass/gemm_runner.cuh | 2 +- .../contrib/hipblas/hipblas_json_runtime.cc | 2 +- .../extra/contrib/hipblas/hipblas_utils.cc | 2 +- .../extra/contrib/nvshmem/dist_gemm.cu | 2 +- src/runtime/extra/contrib/nvshmem/init.cc | 2 +- .../extra/contrib/nvshmem/memory_allocator.cc | 2 +- .../contrib/tensorrt/tensorrt_calibrator.h | 2 +- src/runtime/extra/contrib/thrust/thrust.cu | 2 +- .../extra/disco/cuda_ipc/cuda_ipc_memory.cc | 2 +- src/runtime/extra/disco/nccl/nccl_context.h | 4 +- src/runtime/vm/attn_utils.h | 2 +- src/runtime/vm/hexagon/builtin.cc | 2 +- .../backend/adreno/inject_texture_alloc.cc | 2 +- src/s_tir/backend/adreno/texture_flatten.cc | 2 +- .../postproc/rewrite_cooperative_fetch.cc | 7 +- src/s_tir/schedule/analysis/analysis.cc | 5 +- src/s_tir/transform/inject_permuted_layout.cc | 29 +- src/s_tir/transform/inject_ptx_async_copy.cc | 10 +- src/s_tir/transform/inject_ptx_ldg32.cc | 4 +- .../transform/inject_software_pipeline.cc | 24 +- .../transform/memhammer_lower_auto_copy.cc | 6 +- .../transform/memhammer_tensorcore_rewrite.cc | 7 +- .../merge_shared_memory_allocations.cc | 4 +- .../transform/tensorcore_infer_fragment.cc | 13 +- src/target/llvm/intrin_rule_nvptx.cc | 3 +- src/target/tag.cc | 13 - src/target/target_kind.cc | 318 - src/tirx/analysis/filter_canonical.cc | 5 +- src/tirx/ir/data_type_rewriter.cc | 9 +- src/tirx/ir/exec_scope.cc | 4 +- src/tirx/ir/stmt.cc | 5 +- src/tirx/op/builtin.cc | 153 - src/tirx/op/op.cc | 47 +- src/tirx/script/builder/frame.cc | 3 +- src/tirx/transform/lower_intrin.cc | 4 +- src/tirx/transform/lower_tvm_builtin.cc | 24 +- src/tirx/transform/lower_warp_memory.cc | 51 +- src/tirx/transform/remove_no_op.cc | 11 +- src/tirx/transform/tile_primitive_dispatch.cc | 5 +- .../s_tir/dlight/test_gpu_matmul_tensorize.py | 20 +- tests/python/tirx-base/test_tir_op_types.py | 15 +- .../tile_primitive/cuda/copy/test_fallback.py | 4 +- .../cuda/copy/test_gmem_smem.py | 2 +- .../cuda/copy/test_swizzle_iter.py | 10 +- .../cuda/copy_async/test_dsmem.py | 2 +- .../cuda/copy_async/test_smem_tmem.py | 2 +- .../cuda/copy_async/test_tma.py | 8 +- .../cuda/elementwise/test_unary.py | 10 +- .../cuda/gemm_async/test_gemm_async.py | 92 +- .../permute_layout/test_permute_layout.py | 4 +- .../tile_primitive/trn/test_compose_op_trn.py | 14 +- .../tile_primitive/trn/test_copy_trn.py | 8 +- .../tile_primitive/trn/test_gemm_trn.py | 4 +- .../trn/test_private_alloc_trn.py | 2 +- .../tile_primitive/trn/test_reduction_trn.py | 4 +- .../tile_primitive/trn/test_unary_trn.py | 2 +- tests/python/tirx/test_alloc_pool.py | 6 +- tests/python/tirx/test_layout.py | 10 +- .../python/tirx/test_op_namespace_cleanup.py | 121 + .../tirx/test_printer_tir_namespaces.py | 198 +- .../test_transform_naive_allocator.py | 2 +- 375 files changed, 10774 insertions(+), 10099 deletions(-) delete mode 100644 include/tvm/tirx/target_builtin/cuda.h delete mode 100644 include/tvm/tirx/target_builtin/trn.h create mode 100644 python/tvm/backend/__init__.py create mode 100644 python/tvm/backend/adreno/__init__.py rename python/tvm/{target/tag_registry/adreno.py => backend/adreno/target_tags.py} (97%) create mode 100644 python/tvm/backend/cuda/__init__.py create mode 100644 python/tvm/backend/cuda/lang/__init__.py create mode 100644 python/tvm/backend/cuda/lang/alloc_pool.py create mode 100644 python/tvm/backend/cuda/lang/pipeline.py create mode 100644 python/tvm/backend/cuda/lang/smem_desc.py create mode 100644 python/tvm/backend/cuda/lang/tile_scheduler.py create mode 100644 python/tvm/backend/cuda/lang/warp_role.py create mode 100644 python/tvm/backend/cuda/op.py rename python/tvm/{tirx/backend/adreno => backend/cuda/operator}/__init__.py (88%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/__init__.py (100%) rename python/tvm/{tirx => backend/cuda}/operator/intrinsics/_schema.py (97%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/cp_async.py (99%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/header.py (100%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/math.py (99%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/memory.py (99%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/misc.py (99%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/mma.py (99%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/nvshmem.py (99%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/registry.py (100%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/sync.py (99%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/tcgen05.py (99%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/types.py (100%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/utils.py (100%) rename python/tvm/{tirx/operator/intrinsics/cuda => backend/cuda/operator/intrinsics}/wgmma.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/__init__.py (89%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/common.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy/_common.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy/_swizzle_iter.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy/fallback.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy/gmem_smem.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy/ld_stmatrix.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy/reg.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy/utils.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy_async/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy_async/dsmem.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy_async/ldgsts.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy_async/tcgen05_cp.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy_async/tcgen05_ldst.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy_async/tma.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/copy_async/utils.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/__init__.py (95%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/_common.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/ops/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/ops/binary.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/ops/cast.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/ops/fma.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/ops/unary.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/reg.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/register.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/smem.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/vec_emit/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/vec_emit/binary_f32x2.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/vec_emit/cast_vec2.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/elementwise/vec_emit/fma_f32x2.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/exec_scope_utils.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/gemm/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/gemm/mma_m16n8k_.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/gemm_async/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/gemm_async/tcgen05.py (83%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/gemm_utils.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/layout_utils.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/permute_layout/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/permute_layout/warp_xor_swizzle.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/reduction/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/reduction/local.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/reduction/shared.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/reduction/sm100_packed.py (99%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/reduction/utils.py (97%) rename python/tvm/{tirx/operator/tile_primitive/cuda => backend/cuda/operator/tile_primitive}/tma_utils.py (100%) create mode 100644 python/tvm/backend/cuda/script.py rename python/tvm/{target/tag_registry/cuda.py => backend/cuda/target_tags.py} (99%) create mode 100644 python/tvm/backend/hexagon/__init__.py rename python/tvm/{target/tag_registry/hexagon.py => backend/hexagon/target_tags.py} (98%) create mode 100644 python/tvm/backend/metal/__init__.py create mode 100644 python/tvm/backend/metal/op.py create mode 100644 python/tvm/backend/metal/script.py rename python/tvm/{target/tag_registry/metal.py => backend/metal/target_tags.py} (94%) create mode 100644 python/tvm/backend/opencl/__init__.py create mode 100644 python/tvm/backend/rocm/__init__.py create mode 100644 python/tvm/backend/trn/__init__.py create mode 100644 python/tvm/backend/trn/layout.py create mode 100644 python/tvm/backend/trn/op.py create mode 100644 python/tvm/backend/trn/operator/__init__.py rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/binary/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/binary/default.py (97%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/binary/utils.py (97%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/common.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/compose_op/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/compose_op/binary_chain.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/compose_op/binary_reduce.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/compose_op/compose_op.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/compose_op/reduce_negate.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/compose_op/unary_reduce.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/compose_op/utils.py (95%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/copy/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/copy/default.py (96%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/dim_utils.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/gemm/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/gemm/default.py (97%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/instruction_generator.py (98%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/private_alloc.py (96%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/reduction/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/reduction/default.py (94%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/reduction/utils.py (95%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/select/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/select/default.py (97%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/unary/__init__.py (100%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/unary/default.py (96%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/unary/utils.py (96%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/unary/with_bias_scale.py (96%) rename python/tvm/{tirx/operator/tile_primitive/trn => backend/trn/operator/tile_primitive}/workspace_utils.py (100%) create mode 100644 python/tvm/backend/trn/pipeline.py create mode 100644 python/tvm/backend/trn/script.py create mode 100644 python/tvm/backend/trn/target_tags.py rename python/tvm/{tirx/transform/trn => backend/trn/transform}/__init__.py (68%) rename python/tvm/{tirx/transform/trn => backend/trn/transform}/naive_allocator.py (98%) rename python/tvm/{tirx/transform/trn => backend/trn/transform}/private_buffer_alloc.py (100%) create mode 100644 python/tvm/backend/vulkan/__init__.py create mode 100644 python/tvm/backend/webgpu/__init__.py rename src/{target/cuda => backend/cuda/codegen}/codegen_cuda.cc (96%) rename src/{target/cuda => backend/cuda/codegen}/codegen_cuda.h (99%) rename src/{target/cuda => backend/cuda/codegen}/cuda_fallback_module.cc (99%) rename src/{target/cuda => backend/cuda/codegen}/cuda_fallback_module.h (98%) rename src/{target/cuda => backend/cuda/codegen}/intrin_rule_cuda.cc (89%) rename src/{target/cuda => backend/cuda/codegen}/literal/cuda_half_t.h (100%) rename src/{target/cuda => backend/cuda/codegen}/literal/cuda_int8_t.h (100%) rename src/{target/cuda => backend/cuda/codegen}/llvm/codegen_nvptx.cc (97%) rename src/{target/cuda => backend/cuda/codegen}/ptx.cc (97%) rename src/{target/cuda => backend/cuda/codegen}/ptx.h (100%) create mode 100644 src/backend/cuda/codegen/register.cc create mode 100644 src/backend/cuda/op/register.cc rename src/{tirx/op/target_builtin/cuda.cc => backend/cuda/op/target_builtin.cc} (94%) rename src/{runtime/cuda => backend/cuda/runtime}/cuda_common.h (98%) rename src/{runtime/cuda => backend/cuda/runtime}/cuda_device_api.cc (100%) rename src/{runtime/cuda => backend/cuda/runtime}/cuda_module.cc (98%) rename src/{runtime/cuda => backend/cuda/runtime}/l2_cache_flush.cc (96%) rename src/{runtime/vm/cuda => backend/cuda/runtime/vm}/cuda_graph_builtin.cc (99%) rename src/{target/hexagon => backend/hexagon/codegen}/hexagon_fallback_module.cc (97%) rename src/{target/hexagon => backend/hexagon/codegen}/hexagon_fallback_module.h (98%) rename src/{target/hexagon => backend/hexagon/codegen}/llvm/codegen_hexagon.cc (98%) rename src/{target/hexagon => backend/hexagon/codegen}/llvm/intrin_rule_hexagon.cc (89%) create mode 100644 src/backend/hexagon/codegen/register.cc rename src/{runtime/hexagon => backend/hexagon/runtime}/README.md (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_buffer.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_buffer.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_buffer_manager.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_common.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_common.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_device_api.cc (99%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_device_api.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_htp.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_htp.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_hvx.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_hvx.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_module.cc (94%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_power_manager.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_power_manager.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_thread_manager.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_thread_manager.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_user_dma.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_user_dma.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_user_dma_descriptors.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_user_dma_instructions.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_user_dma_registers.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_vtcm_pool.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/hexagon_vtcm_pool.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/ops/conv2d.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/ops/conv2d_fp16_hvx.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/ops/conv2d_quant_hvx.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/ops/conv_utils.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/profiler/README.md (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/profiler/lwp_handler.S (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/profiler/prof_utils.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/profiler/prof_utils.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/qhl/qhl_wrapper.cc (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/ring_buffer.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/rpc/android/session.cc (96%) rename src/{runtime/hexagon => backend/hexagon/runtime}/rpc/android_bash.sh.template (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/rpc/hexagon/rpc_server.cc (98%) rename src/{runtime/hexagon => backend/hexagon/runtime}/rpc/hexagon_rpc.idl (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/rpc/simulator/hexagon_sim_proto.h (100%) rename src/{runtime/hexagon => backend/hexagon/runtime}/rpc/simulator/rpc_server.cc (99%) rename src/{runtime/hexagon => backend/hexagon/runtime}/rpc/simulator/session.cc (99%) rename src/{target/metal => backend/metal/codegen}/codegen_metal.cc (95%) rename src/{target/metal => backend/metal/codegen}/codegen_metal.h (98%) rename src/{target/metal => backend/metal/codegen}/intrin_rule_metal.cc (91%) rename src/{target/metal => backend/metal/codegen}/metal_fallback_module.cc (99%) rename src/{target/metal => backend/metal/codegen}/metal_fallback_module.h (98%) create mode 100644 src/backend/metal/codegen/register.cc create mode 100644 src/backend/metal/op/register.cc create mode 100644 src/backend/metal/op/target_builtin.cc rename src/{runtime/metal => backend/metal/runtime}/metal_common.h (99%) rename src/{runtime/metal => backend/metal/runtime}/metal_device_api.mm (100%) rename src/{runtime/metal => backend/metal/runtime}/metal_module.mm (98%) rename src/{target/opencl => backend/opencl/codegen}/codegen_opencl.cc (98%) rename src/{target/opencl => backend/opencl/codegen}/codegen_opencl.h (98%) rename src/{target/opencl => backend/opencl/codegen}/intrin_rule_opencl.cc (93%) rename src/{target/opencl => backend/opencl/codegen}/opencl_fallback_module.cc (97%) rename src/{target/opencl => backend/opencl/codegen}/opencl_fallback_module.h (98%) create mode 100644 src/backend/opencl/codegen/register.cc rename src/{runtime/opencl => backend/opencl/runtime}/opencl_common.h (99%) rename src/{runtime/opencl => backend/opencl/runtime}/opencl_device_api.cc (99%) rename src/{runtime/opencl => backend/opencl/runtime}/opencl_module.cc (97%) rename src/{runtime/opencl => backend/opencl/runtime}/opencl_wrapper/README.md (100%) rename src/{runtime/opencl => backend/opencl/runtime}/opencl_wrapper/opencl_wrapper.cc (100%) rename src/{runtime/opencl => backend/opencl/runtime}/source_utils.h (100%) rename src/{runtime/opencl => backend/opencl/runtime}/texture.h (100%) rename src/{target/rocm => backend/rocm/codegen}/llvm/codegen_amdgpu.cc (97%) rename src/{target/rocm => backend/rocm/codegen}/llvm/intrin_rule_rocm.cc (97%) create mode 100644 src/backend/rocm/codegen/register.cc rename src/{target/rocm => backend/rocm/codegen}/rocm_fallback_module.cc (99%) rename src/{target/rocm => backend/rocm/codegen}/rocm_fallback_module.h (98%) rename src/{runtime/rocm => backend/rocm/runtime}/rocm_common.h (98%) rename src/{runtime/rocm => backend/rocm/runtime}/rocm_device_api.cc (100%) rename src/{runtime/rocm => backend/rocm/runtime}/rocm_module.cc (98%) rename src/{target/source => backend/trn/codegen}/codegen_trn.cc (91%) rename src/{target/source => backend/trn/codegen}/codegen_trn.h (98%) create mode 100644 src/backend/trn/codegen/register.cc create mode 100644 src/backend/trn/op/register.cc rename src/{tirx/op/target_builtin/trn.cc => backend/trn/op/target_builtin.cc} (86%) create mode 100644 src/backend/trn/transform/lower_trainium_layout.cc rename src/{target/vulkan => backend/vulkan/codegen}/build_vulkan.cc (88%) rename src/{target/vulkan => backend/vulkan/codegen}/codegen_spirv.cc (98%) rename src/{target/vulkan => backend/vulkan/codegen}/codegen_spirv.h (98%) rename src/{target/vulkan => backend/vulkan/codegen}/intrin_rule_spirv.cc (92%) rename src/{target/vulkan => backend/vulkan/codegen}/ir_builder.cc (100%) rename src/{target/vulkan => backend/vulkan/codegen}/ir_builder.h (100%) create mode 100644 src/backend/vulkan/codegen/register.cc rename src/{target/vulkan => backend/vulkan/codegen}/spirv_support.cc (100%) rename src/{target/vulkan => backend/vulkan/codegen}/spirv_support.h (100%) rename src/{target/vulkan => backend/vulkan/codegen}/spirv_utils.cc (98%) rename src/{target/vulkan => backend/vulkan/codegen}/spirv_utils.h (97%) rename src/{target/vulkan => backend/vulkan/codegen}/vulkan_fallback_module.cc (95%) rename src/{target/vulkan => backend/vulkan/codegen}/vulkan_fallback_module.h (98%) rename src/{runtime/vulkan => backend/vulkan/runtime}/README.md (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/spirv_shader.h (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/thread_map.h (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_amdrgp.cc (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_amdrgp.h (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_buffer.cc (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_buffer.h (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_common.cc (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_common.h (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_device.cc (99%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_device.h (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_device_api.cc (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_device_api.h (99%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_instance.cc (99%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_instance.h (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_module.cc (92%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_stream.cc (99%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_stream.h (100%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_wrapped_func.cc (99%) rename src/{runtime/vulkan => backend/vulkan/runtime}/vulkan_wrapped_func.h (97%) rename src/{target/webgpu => backend/webgpu/codegen}/codegen_webgpu.cc (98%) rename src/{target/webgpu => backend/webgpu/codegen}/codegen_webgpu.h (98%) rename src/{target/webgpu => backend/webgpu/codegen}/intrin_rule_webgpu.cc (92%) create mode 100644 src/backend/webgpu/codegen/register.cc rename src/{target/webgpu => backend/webgpu/codegen}/webgpu_fallback_module.cc (99%) rename src/{target/webgpu => backend/webgpu/codegen}/webgpu_fallback_module.h (98%) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e25f10e7f13..ad99c4c6acba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -339,13 +339,19 @@ tvm_file_glob(GLOB CODEGEN_SRCS src/target/source/*.cc src/target/canonicalizer/*.cc src/target/canonicalizer/llvm/*.cc - src/target/cuda/*.cc - src/target/rocm/*.cc - src/target/hexagon/*.cc - src/target/metal/*.cc - src/target/opencl/*.cc - src/target/vulkan/vulkan_fallback_module.cc - src/target/webgpu/*.cc + src/backend/cuda/codegen/*.cc + src/backend/cuda/op/*.cc + src/backend/hexagon/codegen/*.cc + src/backend/metal/codegen/*.cc + src/backend/metal/op/*.cc + src/backend/opencl/codegen/*.cc + src/backend/rocm/codegen/*.cc + src/backend/trn/codegen/*.cc + src/backend/trn/op/*.cc + src/backend/trn/transform/*.cc + src/backend/vulkan/codegen/register.cc + src/backend/vulkan/codegen/vulkan_fallback_module.cc + src/backend/webgpu/codegen/*.cc ) list(APPEND COMPILER_SRCS ${CODEGEN_SRCS}) @@ -361,7 +367,7 @@ tvm_file_glob(GLOB RUNTIME_SRCS src/runtime/rpc/minrpc/*.cc ) # Note: src/runtime/extra/disco/** moves to libtvm_runtime_extra. -# Note: src/runtime/{cuda,vulkan,opencl,metal,rocm,hexagon}/* move to per-backend DSOs. +# Note: src/backend/*/runtime sources move to per-backend DSOs. set(TVM_RUNTIME_EXT_OBJS "") if(BUILD_FOR_HEXAGON) diff --git a/cmake/modules/CUDA.cmake b/cmake/modules/CUDA.cmake index 0028e04fcc5d..53f4bf8d868c 100644 --- a/cmake/modules/CUDA.cmake +++ b/cmake/modules/CUDA.cmake @@ -62,8 +62,8 @@ endif(USE_CUDA) if(USE_CUDA) message(STATUS "Build cuda device runtime") - tvm_file_glob(GLOB RUNTIME_CUDA_SRCS src/runtime/cuda/*.cc) - tvm_file_glob(GLOB VM_CUDA_BUILTIN_SRC_CC src/runtime/vm/cuda/*.cc) + tvm_file_glob(GLOB RUNTIME_CUDA_SRCS src/backend/cuda/runtime/*.cc) + tvm_file_glob(GLOB VM_CUDA_BUILTIN_SRC_CC src/backend/cuda/runtime/vm/*.cc) add_library(tvm_runtime_cuda_objs OBJECT ${RUNTIME_CUDA_SRCS} ${VM_CUDA_BUILTIN_SRC_CC}) target_link_libraries(tvm_runtime_cuda_objs PUBLIC tvm_ffi_header) diff --git a/cmake/modules/Hexagon.cmake b/cmake/modules/Hexagon.cmake index c92fc7079949..512a893c73ad 100644 --- a/cmake/modules/Hexagon.cmake +++ b/cmake/modules/Hexagon.cmake @@ -50,7 +50,8 @@ macro(file_glob_append _output_list) set(${_output_list} ${_tmp1}) endmacro() -set(TVMRT_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src/runtime") +set(TVMRT_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src/backend/hexagon/runtime") +set(TVM_CORE_RUNTIME_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/src/runtime") if(DEFINED USE_HEXAGON_DEVICE) message(WARNING "USE_HEXAGON_DEVICE is deprecated, use USE_HEXAGON instead") @@ -79,7 +80,7 @@ endif() if(NOT USE_HEXAGON) # USE_HEXAGON=OFF: codegen still works through the per-backend fallback - # module (src/target/hexagon/hexagon_fallback_module.cc), which is always + # module (src/backend/hexagon/codegen/hexagon_fallback_module.cc), which is always # compiled into libtvm via CODEGEN_SRCS. No opt-stub registration is # needed. return() @@ -122,14 +123,14 @@ if(BUILD_FOR_HEXAGON) # When building FOR Hexagon (the DSP itself), all runtime sources go into # the single libtvm_runtime (static or shared). No per-backend DSO split. file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/*.cc" + "${TVMRT_SOURCE_DIR}/*.cc" ) # Add builtins to RelaxVM tvm_file_glob(GLOB VM_BUILTIN_SRC_CC src/runtime/vm/hexagon/*.cc) list(APPEND RUNTIME_SRCS ${VM_BUILTIN_SRC_CC}) else() file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/hexagon_module.cc" + "${TVMRT_SOURCE_DIR}/hexagon_module.cc" ) endif() @@ -137,7 +138,7 @@ set(htp_supported_archs "v68" "v69" "v73" "v75") list(FIND htp_supported_archs "${USE_HEXAGON_ARCH}" supported_arch_index) if(${supported_arch_index} EQUAL -1) # Exclude User DMA files when building for archs below v68 - list(REMOVE_ITEM RUNTIME_HEXAGON_SRCS "${TVMRT_SOURCE_DIR}/hexagon/hexagon_user_dma.cc") + list(REMOVE_ITEM RUNTIME_HEXAGON_SRCS "${TVMRT_SOURCE_DIR}/hexagon_user_dma.cc") endif() if(BUILD_FOR_HEXAGON) @@ -160,7 +161,7 @@ if(BUILD_FOR_HEXAGON) # QHL support. if(USE_HEXAGON_QHL) file_glob_append(TVM_QHL_WRAPPER_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/qhl/*.cc" + "${TVMRT_SOURCE_DIR}/qhl/*.cc" ) include_directories( @@ -181,20 +182,20 @@ if(BUILD_FOR_HEXAGON) if(${supported_arch_index} GREATER -1) # Hand-written ops file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/ops/*.cc" + "${TVMRT_SOURCE_DIR}/ops/*.cc" ) include_directories( - "${TVMRT_SOURCE_DIR}/hexagon/ops" + "${TVMRT_SOURCE_DIR}/ops" ) set_source_files_properties( - "${TVMRT_SOURCE_DIR}/hexagon/ops/conv2d_quant_hvx.cc" + "${TVMRT_SOURCE_DIR}/ops/conv2d_quant_hvx.cc" PROPERTIES COMPILE_FLAGS "-mhvx" ) set_source_files_properties( - "${TVMRT_SOURCE_DIR}/hexagon/ops/conv2d_fp16_hvx.cc" + "${TVMRT_SOURCE_DIR}/ops/conv2d_fp16_hvx.cc" PROPERTIES COMPILE_FLAGS "-mhvx" ) endif() @@ -249,21 +250,21 @@ if(USE_HEXAGON_RPC) add_custom_command( OUTPUT - "${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon_rpc.h" - "${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon_rpc_skel.c" - "${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon_rpc_stub.c" + "${TVMRT_SOURCE_DIR}/rpc/hexagon_rpc.h" + "${TVMRT_SOURCE_DIR}/rpc/hexagon_rpc_skel.c" + "${TVMRT_SOURCE_DIR}/rpc/hexagon_rpc_stub.c" COMMAND ${QAIC_EXE_PATH} ${QAIC_FLAGS} - "${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon_rpc.idl" - -o "${TVMRT_SOURCE_DIR}/hexagon/rpc" - MAIN_DEPENDENCY "${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon_rpc.idl" + "${TVMRT_SOURCE_DIR}/rpc/hexagon_rpc.idl" + -o "${TVMRT_SOURCE_DIR}/rpc" + MAIN_DEPENDENCY "${TVMRT_SOURCE_DIR}/rpc/hexagon_rpc.idl" ) if("${CMAKE_C_COMPILER_ID}" STREQUAL "GNU" OR "${CMAKE_C_COMPILER_ID}" STREQUAL "Clang") # We can't easily fix this at the source-code level, because the .c file is generated # by the qaic program. But it should be safe to ignore the warning: # https://stackoverflow.com/questions/13905200/is-it-wise-to-ignore-gcc-clangs-wmissing-braces-warning - set_source_files_properties("${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon_rpc_stub.c" + set_source_files_properties("${TVMRT_SOURCE_DIR}/rpc/hexagon_rpc_stub.c" PROPERTY COMPILE_FLAGS "-Wno-missing-braces") endif() endfunction() @@ -273,12 +274,12 @@ if(USE_HEXAGON_RPC) add_android_paths() build_rpc_idl() file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/rpc/android/*.cc" + "${TVMRT_SOURCE_DIR}/rpc/android/*.cc" ) # Add this file separately, because it's auto-generated, and glob won't # find it during cmake-time. list(APPEND RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon_rpc_stub.c" + "${TVMRT_SOURCE_DIR}/rpc/hexagon_rpc_stub.c" ) list(APPEND TVM_RUNTIME_LINKER_LIBS cdsprpc) @@ -289,30 +290,30 @@ if(USE_HEXAGON_RPC) # Include the generic RPC code into the TVM runtime. list(APPEND RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/rpc/minrpc/minrpc_server.h" - "${TVMRT_SOURCE_DIR}/rpc/minrpc/rpc_reference.h" - "${TVMRT_SOURCE_DIR}/rpc/rpc_module.cc" - "${TVMRT_SOURCE_DIR}/rpc/rpc_endpoint.cc" - "${TVMRT_SOURCE_DIR}/rpc/rpc_session.cc" + "${TVM_CORE_RUNTIME_SOURCE_DIR}/rpc/minrpc/minrpc_server.h" + "${TVM_CORE_RUNTIME_SOURCE_DIR}/rpc/minrpc/rpc_reference.h" + "${TVM_CORE_RUNTIME_SOURCE_DIR}/rpc/rpc_module.cc" + "${TVM_CORE_RUNTIME_SOURCE_DIR}/rpc/rpc_endpoint.cc" + "${TVM_CORE_RUNTIME_SOURCE_DIR}/rpc/rpc_session.cc" # TODO(masahi): Remove rpc_local_session.cc after verifying that things work without it - "${TVMRT_SOURCE_DIR}/rpc/rpc_local_session.cc" + "${TVM_CORE_RUNTIME_SOURCE_DIR}/rpc/rpc_local_session.cc" ) - set(HEXAGON_PROFILER_DIR "${TVMRT_SOURCE_DIR}/hexagon/profiler") + set(HEXAGON_PROFILER_DIR "${TVMRT_SOURCE_DIR}/profiler") # Add the hardware-specific RPC code into the skel library. set_property(SOURCE ${HEXAGON_PROFILER_DIR}/lwp_handler.S PROPERTY LANGUAGE C) add_library(hexagon_rpc_skel SHARED - "${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon/rpc_server.cc" - "${TVMRT_SOURCE_DIR}/hexagon/rpc/hexagon_rpc_skel.c" + "${TVMRT_SOURCE_DIR}/rpc/hexagon/rpc_server.cc" + "${TVMRT_SOURCE_DIR}/rpc/hexagon_rpc_skel.c" "${HEXAGON_PROFILER_DIR}/prof_utils.cc" "${HEXAGON_PROFILER_DIR}/lwp_handler.S" ) target_include_directories(hexagon_rpc_skel - SYSTEM PRIVATE "${TVMRT_SOURCE_DIR}/hexagon/rpc" + SYSTEM PRIVATE "${TVMRT_SOURCE_DIR}/rpc" ) # Add the simulator-specific RPC code into a shared library to be # executed via run_main_on_sim. add_library(hexagon_rpc_sim SHARED - "${TVMRT_SOURCE_DIR}/hexagon/rpc/simulator/rpc_server.cc" + "${TVMRT_SOURCE_DIR}/rpc/simulator/rpc_server.cc" "${HEXAGON_PROFILER_DIR}/prof_utils.cc" "${HEXAGON_PROFILER_DIR}/lwp_handler.S" ) @@ -324,7 +325,7 @@ if(USE_HEXAGON_RPC) find_hexagon_toolchain() add_hexagon_wrapper_paths() file_glob_append(RUNTIME_HEXAGON_SRCS - "${TVMRT_SOURCE_DIR}/hexagon/rpc/simulator/session.cc" + "${TVMRT_SOURCE_DIR}/rpc/simulator/session.cc" ) list(APPEND TVM_RUNTIME_LINKER_LIBS "-lwrapper") endif() diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index f944b4130415..8ad8bfc53c56 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -44,9 +44,9 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) add_definitions(-DTVM_LLVM_HAS_AARCH64_TARGET=${TVM_LLVM_HAS_AARCH64_TARGET}) tvm_file_glob(GLOB COMPILER_LLVM_SRCS src/target/llvm/*.cc - src/target/cuda/llvm/*.cc - src/target/rocm/llvm/*.cc - src/target/hexagon/llvm/*.cc + src/backend/cuda/codegen/llvm/*.cc + src/backend/rocm/codegen/llvm/*.cc + src/backend/hexagon/codegen/llvm/*.cc ) list(APPEND TVM_LINKER_LIBS ${LLVM_LIBS}) list(APPEND COMPILER_SRCS ${COMPILER_LLVM_SRCS}) diff --git a/cmake/modules/Metal.cmake b/cmake/modules/Metal.cmake index c593d0d420cb..02526b4c9c29 100644 --- a/cmake/modules/Metal.cmake +++ b/cmake/modules/Metal.cmake @@ -19,7 +19,7 @@ if(USE_METAL) message(STATUS "Build metal device runtime") find_library(METAL_LIB Metal) find_library(FOUNDATION_LIB Foundation) - tvm_file_glob(GLOB RUNTIME_METAL_SRCS src/runtime/metal/*.mm) + tvm_file_glob(GLOB RUNTIME_METAL_SRCS src/backend/metal/runtime/*.mm) add_library(tvm_runtime_metal_objs OBJECT ${RUNTIME_METAL_SRCS}) target_link_libraries(tvm_runtime_metal_objs PUBLIC tvm_ffi_header) @@ -33,5 +33,5 @@ if(USE_METAL) tvm_configure_target_library(tvm_runtime_metal RUNTIME_MODULE) endif(USE_METAL) # When USE_METAL=OFF the codegen-side fallback in -# src/target/metal/metal_fallback_module.cc handles construction; no opt +# src/backend/metal/codegen/metal_fallback_module.cc handles construction; no opt # stub is needed (it is always compiled via CODEGEN_SRCS in CMakeLists.txt). diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index f833832d4cde..392457da7aa4 100644 --- a/cmake/modules/OpenCL.cmake +++ b/cmake/modules/OpenCL.cmake @@ -16,13 +16,13 @@ # under the License. if(USE_OPENCL) - tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc) + tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/backend/opencl/runtime/*.cc) set(_opencl_libs "") if(${USE_OPENCL} MATCHES ${IS_TRUE_PATTERN}) message(STATUS "Enabled runtime search for OpenCL library location") file_glob_append(RUNTIME_OPENCL_SRCS - "src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc" + "src/backend/opencl/runtime/opencl_wrapper/opencl_wrapper.cc" ) include_directories(SYSTEM "3rdparty/OpenCL-Headers") else() diff --git a/cmake/modules/ROCM.cmake b/cmake/modules/ROCM.cmake index a2d1516558ba..ce9ad1414f90 100644 --- a/cmake/modules/ROCM.cmake +++ b/cmake/modules/ROCM.cmake @@ -32,7 +32,7 @@ if(USE_ROCM) endif() message(STATUS "Build rocm device runtime") - tvm_file_glob(GLOB RUNTIME_ROCM_SRCS src/runtime/rocm/*.cc) + tvm_file_glob(GLOB RUNTIME_ROCM_SRCS src/backend/rocm/runtime/*.cc) set(_rocm_libs ${ROCM_HIPHCC_LIBRARY}) if(ROCM_HSA_LIBRARY) diff --git a/cmake/modules/Vulkan.cmake b/cmake/modules/Vulkan.cmake index ba51e4b84206..2f34b6324055 100644 --- a/cmake/modules/Vulkan.cmake +++ b/cmake/modules/Vulkan.cmake @@ -29,12 +29,12 @@ if(USE_VULKAN) include_directories(SYSTEM ${Vulkan_INCLUDE_DIRS}) message(STATUS "Build with Vulkan support") tvm_file_glob(GLOB COMPILER_VULKAN_SRCS - src/target/vulkan/build_vulkan.cc - src/target/vulkan/codegen_spirv.cc - src/target/vulkan/intrin_rule_spirv.cc - src/target/vulkan/ir_builder.cc - src/target/vulkan/spirv_support.cc - src/target/vulkan/spirv_utils.cc + src/backend/vulkan/codegen/build_vulkan.cc + src/backend/vulkan/codegen/codegen_spirv.cc + src/backend/vulkan/codegen/intrin_rule_spirv.cc + src/backend/vulkan/codegen/ir_builder.cc + src/backend/vulkan/codegen/spirv_support.cc + src/backend/vulkan/codegen/spirv_utils.cc ) list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS}) list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY}) @@ -44,7 +44,7 @@ endif(USE_VULKAN) if(USE_VULKAN) message(STATUS "Build vulkan device runtime") - tvm_file_glob(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc) + tvm_file_glob(GLOB RUNTIME_VULKAN_SRCS src/backend/vulkan/runtime/*.cc) add_library(tvm_runtime_vulkan_objs OBJECT ${RUNTIME_VULKAN_SRCS}) target_link_libraries(tvm_runtime_vulkan_objs PUBLIC tvm_ffi_header) diff --git a/docs/reference/api/python/tirx/backend.rst b/docs/reference/api/python/tirx/backend.rst index c4e6689bfe21..35d643380241 100644 --- a/docs/reference/api/python/tirx/backend.rst +++ b/docs/reference/api/python/tirx/backend.rst @@ -23,9 +23,3 @@ tvm.tirx.backend .. automodule:: tvm.tirx.backend :members: :imported-members: - -tvm.tirx.backend.adreno -*********************** -.. automodule:: tvm.tirx.backend.adreno - :members: - :imported-members: diff --git a/include/tvm/tirx/builtin.h b/include/tvm/tirx/builtin.h index f25e48daf330..953ee8daa1d7 100644 --- a/include/tvm/tirx/builtin.h +++ b/include/tvm/tirx/builtin.h @@ -560,49 +560,6 @@ TVM_DLL const Op& tvm_global_barrier_kinit(); * } */ TVM_DLL const Op& tvm_thread_allreduce(); - -// Metal SimdGroup matrix intrinsics - -/*! - * \brief tvm intrinsic for initializing and simdgroup with given value. - * \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params, - * keeping the similar interface with Metal Spec. - * - * void make_filled_simdgroup_matrix(Var d, PrimExpr index, PrimExpr value, - * int col = 8, int row = 8); - */ -TVM_DLL const Op& make_filled_simdgroup_matrix(); - -/*! - * \brief tvm intrinsic for loading data from device memory or threadgroup memory to simdgroup. - * \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params, - * keeping the similar interface with Metal Spec. - * - * void simdgroup_load(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride, - int col = 8, int row = 8, bool transpose_matrix = false); - */ -TVM_DLL const Op& simdgroup_load(); - -/*! - * \brief tvm intrinsic for storing data from simdgroup to device memory or threadgroup memory. - * \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params, - * keeping the similar interface with Metal Spec. - * - * void simdgroup_store(Var d, PrimExpr index, PrimExpr ptr, PrimExpr stride, - * int col = 8, int row = 8, bool transpose_matrix = false); - */ -TVM_DLL const Op& simdgroup_store(); - -/*! - * \brief tvm intrinsic for multiply and accumulate two matrices in simdgroup - * \note only 8x8 shape is supported by Metal Spec and TVM, but we still keep shape as params, - * keeping the similar interface with Metal Spec. - * - * void simdgroup_mma(Var d, PrimExpr index_d, Var a, PrimExpr index_a, - * Var b, PrimExpr index_b, Var c, PrimExpr index_c); - */ -TVM_DLL const Op& simdgroup_multiply_accumulate(); - // Metal cooperative_tensor intrinsics (MetalPerformancePrimitives / Metal 4) /*! @@ -856,233 +813,6 @@ enum TVMStructFieldKind : int { * \brief Print the content of a buffer during runtime. */ TVM_DLL const Op& print_buffer(); - -/*! - * \brief tvm intrinsic for initializing the CUDA profiler, and store profiling result in a buffer. - * - * void timer_init_cuda(Var profiler_buffer, Var profiler_tag, Var profiler_write_offset, int - * num_groups, Expr group_id) { - * // initialize the tag and write to pos 0 in the buffer - * // initialize write offset for every leader thread in warp group across all blocks - * } - */ -TVM_DLL const Op& timer_init_cuda(); - -/*! - * \brief tvm intrinsic for starting the timer for profiling a specific event, - * and storing profiling result in a buffer. - * - * void timer_start_cuda(IntImm event_type, Var profiler_buffer, Var profiler_tag, - * Var profiler_write_offset, IntImm profiler_write_stride, Expr leader_cond) - * { - * // each leader thread in warp group gets the time stamp and event type, combine with the tag - * // and write to corresponding offset in buffer - * // each leader thread advance offset by stride - * } - */ -TVM_DLL const Op& timer_start_cuda(); - -/*! - * \brief tvm intrinsic for ending the timer for profiling a specific event, - * and storing profiling result in a buffer. - * - * void timer_end_cuda(IntImm event_type, Var profiler_buffer, Var profiler_tag, - * Var profiler_write_offset, IntImm profiler_write_stride, Expr leader_cond) { - * // each leader thread in warp group gets the time stamp and event type, combine with the tag - * // and write to corresponding offset in buffer - * // each leader thread advance offset by stride - * } - */ -TVM_DLL const Op& timer_end_cuda(); - -/*! - * \brief tvm intrinsic for finalize the timer for profiling, - * and storing profiling result in a buffer. - * - * void timer_finalize_cuda(Var profiler_buffer, Var profiler_tag, Var profiler_write_offset, - * IntImm profiler_write_stride, Expr leader_cond) { - * // each leader thread in warp group gets the time stamp and end signal, combine with the tag - * // and write to corresponding offset in buffer - * // each leader thread advance offset by stride - * } - */ -TVM_DLL const Op& timer_finalize_cuda(); - -/*! - * \brief tvm intrinsic for cuda atomic add instruction - */ -TVM_DLL const Op& cuda_atomic_add(); - -/*! - * \brief tvm intrinsic for cuda thread fence instruction - */ -TVM_DLL const Op& cuda_thread_fence(); - -/*! - * \brief tvm intrinsic for cuda warpgroup sync instruction - */ -TVM_DLL const Op& cuda_warpgroup_sync(); - -/*! - * \brief Warp-level butterfly shuffle-XOR reduction. - * - * cuda_warp_reduce(value, op, width) reduces value across width adjacent - * lanes using the specified operation ("sum", "max", "min"). - */ -TVM_DLL const Op& cuda_warp_reduce(); - -/*! - * \brief CTA-wide reduction via warp shuffle + shared memory. - * - * cuda_cta_reduce(value, op, num_warps, scratch) reduces value across - * the entire CTA using the specified operation ("sum", "max", "min"). - */ -TVM_DLL const Op& cuda_cta_reduce(); - -/*! - * \brief Typed load/store copy of num_bytes bytes. - * - * cuda_copy_bytes(dst, src, num_bytes) copies num_bytes bytes from src to dst - * using a single typed load/store (uint4, uint2, unsigned int, etc.). - * num_bytes must be one of {1, 2, 4, 8, 16}. - */ -TVM_DLL const Op& cuda_copy_bytes(); - -/*! - * \brief tvm intrinsic for cuda warp sync instruction - */ -TVM_DLL const Op& cuda_warp_sync(); - -/*! - * \brief tvm intrinsic for cuda block-wide sync (syncthreads) - */ -TVM_DLL const Op& cuda_cta_sync(); - -/*! - * \brief tvm intrinsic for cuda grid-wide sync (cooperative groups) - */ -TVM_DLL const Op& cuda_grid_sync(); - -/*! - * \brief tvm intrinsic for cuda cluster-wide sync instruction - */ -TVM_DLL const Op& cuda_cluster_sync(); - -/*! - * \brief tvm intrinsic that returns ``cooperative_groups::thread_rank()`` - * for the enclosing CTA (linear thread index within the block). - */ -TVM_DLL const Op& cuda_thread_rank(); - -/*! - * \brief tvm intrinsic for cuda half to float conversion - */ -TVM_DLL const Op& cuda_half2float(); - -/*! - * \brief tvm intrinsic for cuda bfloat16 to float conversion - */ -TVM_DLL const Op& cuda_bfloat162float(); - -/*! - * \brief tvm intrinsic for a helper converting float2 to half2 with rounding - */ -TVM_DLL const Op& cuda_float22half2(); - -/*! - * \brief tvm intrinsic to trap when an assertion failed (cond == false) - */ -TVM_DLL const Op& cuda_trap_when_assert_failed(); - -/*! - * \brief tvm intrinsic to modify runtime instruction descriptor - */ -TVM_DLL const Op& cuda_runtime_instr_desc(); - -/*! - * \brief tvm intrinsic to convert 8 half2 lanes to 8 float2 lanes - */ -TVM_DLL const Op& cuda_half8tofloat8(); - -/*! - * \brief tvm intrinsic to convert 8 float2 lanes to 8 half2 lanes with rounding - */ -TVM_DLL const Op& cuda_float8tohalf8(); - -/*! - * \brief tvm intrinsic for cuda syncthreads_and instruction - */ -TVM_DLL const Op& cuda_syncthreads_and(); - -/*! - * \brief tvm intrinsic for cuda syncthreads_or instruction - */ -TVM_DLL const Op& cuda_syncthreads_or(); - -/*! - * \brief tvm intrinsic for cuda nano sleep instruction - */ -TVM_DLL const Op& cuda_nano_sleep(); - -/*! - * \brief tvm intrinsic for cuda atomic compare and swap instruction - */ -TVM_DLL const Op& cuda_atomic_cas(); - -/*! - * \brief tvm intrinsic for cuda printf instruction - */ -TVM_DLL const Op& cuda_printf(); - -/*! - * \brief tvm intrinsic for cuda ldg instruction - */ -TVM_DLL const Op& cuda_ldg(); - -/*! - * \brief tvm intrinsic for cuda tmem address calculation - */ -TVM_DLL const Op& cuda_get_tmem_addr(); - -/*! - * \brief tvm intrinsic for PTX fast exp2 approximation (ex2.approx.ftz.f32) - */ -TVM_DLL const Op& ptx_exp2(); - -/*! - * \brief tvm intrinsic for PTX fast reciprocal approximation (rcp.approx.ftz.f32) - */ -TVM_DLL const Op& ptx_rcp(); - -/*! - * \brief tvm intrinsic for PTX warp-wide any predicate (__any_sync) - */ -TVM_DLL const Op& ptx_any_sync(); - -/*! - * \brief tvm intrinsic for PTX 3-input max instruction (sm_100a+) - */ -TVM_DLL const Op& ptx_reduce3_max_f32(); - -/*! - * \brief tvm intrinsic for PTX 3-input min instruction (sm_100a+) - */ -TVM_DLL const Op& ptx_reduce3_min_f32(); - -TVM_DLL const Op& ptx_add_f32(); -TVM_DLL const Op& ptx_add_f32x2(); -TVM_DLL const Op& ptx_add_f64(); -TVM_DLL const Op& ptx_sub_f32(); -TVM_DLL const Op& ptx_sub_f32x2(); -TVM_DLL const Op& ptx_sub_f64(); -TVM_DLL const Op& ptx_mul_f32(); -TVM_DLL const Op& ptx_mul_f32x2(); -TVM_DLL const Op& ptx_mul_f64(); -TVM_DLL const Op& ptx_fma_f32(); -TVM_DLL const Op& ptx_fma_f32x2(); -TVM_DLL const Op& ptx_fma_f64(); -TVM_DLL const Op& ptx_max_f32(); - } // namespace builtin } // namespace tirx } // namespace tvm diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index 60b292bbb265..202766571278 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -35,8 +35,6 @@ #include #include #include -#include -#include #include #include diff --git a/include/tvm/tirx/target_builtin/cuda.h b/include/tvm/tirx/target_builtin/cuda.h deleted file mode 100644 index ff10ee0b43e6..000000000000 --- a/include/tvm/tirx/target_builtin/cuda.h +++ /dev/null @@ -1,745 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/tir/target_builtin/cuda.h - * \brief TIR builtin intrinsics specific to CUDA target. - */ -#ifndef TVM_TIRX_TARGET_BUILTIN_CUDA_H_ -#define TVM_TIRX_TARGET_BUILTIN_CUDA_H_ - -#include -#include - -namespace tvm { -namespace tirx { -namespace builtin { - -// TODO(tvm-team) TensorCore specific intrinsics should be directly registered under -// cuda. namespace and used through op. -/*! - * \brief tvm intrinsic for tensor core load operators. - * - * void tvm_load_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, - * Expr index, Expr buffer_ptr, Expr stride, - * StringImm layout) { - * // m, n, k are the shape of wmma fragment. - * // Determine fragment layout(column-major or row major) by layout. - * // fragments must be in 'wmma.matrix_a' or 'wmma.matrix_b' scope. - * nvcuda::wmma::load_matrix_sync(fragment[index], buffer_ptr, stride); - * } - */ -TVM_DLL const Op& tvm_load_matrix_sync(); - -/*! - * \brief tvm intrinsic for tensor core mma_sync operators. - * - * void tvm_mma_sync(Var fragment_d, Expr index_d, - * Var fragment_a, Expr index_a, - * Var fragment_b, Expr index_b, - * Var fragment_c, Expr index_c) { - * nvcuda::wmma::mma_sync(fragment_d[index_d], fragment_a[index_a], - * fragment_b[index_b], fragment_c[index_c]); - * } - */ -TVM_DLL const Op& tvm_mma_sync(); - -/*! - * \brief tvm intrinsic for tensor core bmma_sync operators. - * - * void tvm_bmma_sync(Var fragment_d, Expr index_d, - * Var fragment_a, Expr index_a, - * Var fragment_b, Expr index_b, - * Var fragment_c, Expr index_c) { - * nvcuda::wmma::bmma_sync(fragment_d[index_d], fragment_a[index_a], - * fragment_b[index_b], fragment_c[index_c]); - * } - */ -TVM_DLL const Op& tvm_bmma_sync(); - -/*! - * \brief tvm intrinsic for tensor core fill_fragment operators. - * - * void tvm_fill_fragment(Var fragment, UIntImm m, UIntImm, n, UIntImm k, - * Expr index, Expr value) { - * // m, n, k are the shape of wmma fragment - * // fragments must be in 'wmma.accumulator' scope. - * nvcuda::wmma::fill_fragment(fragment[index], value); - * } - */ -TVM_DLL const Op& tvm_fill_fragment(); - -/*! - * \brief tvm intrinsic for tensor core store operators. - * - * void tvm_store_matrix_sync(Var fragment, UIntImm m, UIntImm, n, UIntImm k, - * Expr index, Expr buffer_ptr, Expr stride, - * StringImm layout) { - * // m, n, k are the shape of wmma fragment - * // fragments must be in 'wmma.accumulator' scope. - * nvcuda::wmma::store_matrix_sync(fragment[index], buffer_ptr, stride, layout); - * } - */ -TVM_DLL const Op& tvm_store_matrix_sync(); - -/*! - * \brief tvm intrinsic for ptx tensor core mma instructions. - * - * void ptx_mma(StringImm shape, StringImm A_layout, StringImm B_layout, - * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype, - * Var multiplicand_a, Expr a_index, - * Var multiplicand_b, Expr b_index, - * Var accumulator, Expr c_index, bool saturate); - */ -TVM_DLL const Op& ptx_mma(); - -/*! - * \brief ptx mma / ldmatrix / mma_store / mma_fill variants that take - * ``(ptr_var, offset)`` pairs (not a folded access_ptr Call). Codegen - * emits ``ptr + offset`` C pointer arithmetic; ``lower_warp_memory`` - * rewrites the offset's group component to its thread-local index. - */ -TVM_DLL const Op& ptx_mma_legacy(); -TVM_DLL const Op& ptx_ldmatrix_legacy(); -TVM_DLL const Op& mma_store_legacy(); -TVM_DLL const Op& mma_fill_legacy(); - -/*! - * \brief tvm intrinsic for ptx predicate load with 32-bit data type. - * - */ -TVM_DLL const Op& ptx_ldg32(); - -/*! - * \brief tvm intrinsic for sparse tensor core ptx instructions. - * - * void ptx_mma_sp(StringImm shape, StringImm A_layout, StringImm B_layout, - * StringImm A_dtype, StringImm B_dtype, StringImm C_dtype, - * Var multiplicand_a, Expr a_index, - * Var multiplicand_b, Expr b_index, - * Var accumulator, Expr c_index, - * Var metadata, Expr meta_index, - * Var sparse_selector, bool saturate); - */ -TVM_DLL const Op& ptx_mma_sp(); - -/*! - * \brief tvm intrinsic for ptx load matrix from shared memory. - * - * void ptx_ldmatrix(Bool trans, IntImm num, StringImm type, - * Var local_ptr, Expr local_offset, - * Var smem_ptr, Expr smem_offset); - */ -TVM_DLL const Op& ptx_ldmatrix(); - -/*! - * \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async - * - * void ptx_cp_async(Var shared_ptr, - * Expr shared_offset, - * Var global_ptr, - * Expr global_offset, - * size_t bytes); - */ -TVM_DLL const Op& ptx_cp_async(); - -/*! - * \brief tvm intrinsics for ptx async copy from global to shared memory using cp.async.bulk - * - * void ptx_cp_async_bulk(Var shared_ptr, - * Expr shared_offset, - * Var global_ptr, - * Expr global_offset, - * size_t bytes, - * int barrier_arr_id, - * int barrier_id); - */ -TVM_DLL const Op& ptx_cp_async_bulk(); - -/*! - * \brief tvm intrinsics for ptx async bulk copy from shared::cta to shared::cluster - * - * void ptx_cp_async_bulk_shared_to_cluster(Expr dst_ptr, - * Expr src_ptr, - * Expr size, - * Expr mbar); - */ -TVM_DLL const Op& ptx_cp_async_bulk_shared_to_cluster(); - -/*! - * \brief tvm intrinsics for ptx async copy commit and wait. - * - * void ptx_cp_async_commit_group(); - * void ptx_cp_async_wait_group(int num); - * - */ -TVM_DLL const Op& ptx_cp_async_commit_group(); -TVM_DLL const Op& ptx_cp_async_wait_group(); - -/*! - * \brief tvm intrinsics for ptx async copy barrier using cp.async.mbarrier.arrive - * - * ptx_cp_async_mbarrier_arrive(int barrier_arr_id, int barrier_id) - * - */ -TVM_DLL const Op& ptx_cp_async_mbarrier_arrive(); - -/*! - * \brief PTX fence instruction: fence.{sem}.{scope} - * - * ptx_fence(StringImm sem, StringImm scope) - */ -TVM_DLL const Op& ptx_fence(); - -/*! - * \brief PTX fence.proxy.async instruction: fence.proxy.async[.{space}] - * - * ptx_fence_proxy_async(StringImm space) - */ -TVM_DLL const Op& ptx_fence_proxy_async(); - -/*! - * \brief tvm instrinsics to call mbarrier.init.shared::cta.b64 - * - * ptx_mbarrier_init(uint64_t* bar_ptr, int thread_count) - */ -TVM_DLL const Op& ptx_mbarrier_init(); - -/*! - * \brief tvm instrinsics to call - * mbarrier.arrive.shared::cta.b64 - * or - * @p mapa.shared::cluster.u32 - * @p mbarrier.arrive.shared::cluster.b64 - */ -TVM_DLL const Op& ptx_mbarrier_arrive(); - -/*! - * \brief tvm instrinsics to call - * mbarrier.arrive.expect_tx.shared.b64 - * or - * @p mapa.shared::cluster.u32 - * @p mbarrier.arrive.expect_tx.shared.b64 - * - * ptx_mbarrier_arrive_expect_tx(uint64_t* bar_ptr, int byte_count) - */ -TVM_DLL const Op& ptx_mbarrier_arrive_expect_tx(); - -/*! - * \brief tvm instrinsics to call mbarrier.try_wait.parity repeatedly until it returns true - * - * ptx_mbarrier_try_wait(uint64_t* bar_ptr, int phase) - */ -TVM_DLL const Op& ptx_mbarrier_try_wait(); - -/*! - * \brief tvm instrinsics to call bar.arrive a, b - * - * bar_arrive(int name_bar_id, int thread_count) - */ -TVM_DLL const Op& ptx_bar_arrive(); - -/*! - * \brief tvm instrinsics to call bar.sync a, {b} - * - * bar_sync(int name_bar_id, int thread_count) - */ -TVM_DLL const Op& ptx_bar_sync(); - -/*! - * \brief tvm instrinsics to call - * cp.async.bulk.tensor.dim.shared::cluster.global.tile.mbarrier::complete_tx::bytes - * - * TMA alignment requirement: - * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#table-alignment-multi-dim-tma - * - * ptx_cp_async_bulk_tensor_global_to_cluster(int dim, PrimExpr dst_ptr, PrimExpr bar_ptr, - * PrimExpr tensormap_addr, int...coords, int cta_mask, int cta_group, string cache_hint) - */ -TVM_DLL const Op& ptx_cp_async_bulk_tensor_global_to_cluster(); - -/*! - * \brief tvm intrinsic to call - * cp.async.bulk.tensor.dim.shared::cluster.global.tile::gather4.mbarrier::complete_tx::bytes - * - * TMA alignment requirement: - * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#table-alignment-multi-dim-tma - * - * ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster(int dim, PrimExpr dst_ptr, PrimExpr - * bar_ptr, PrimExpr tensormap_addr, int...coords, int cta_mask, int cta_group, string cache_hint) - */ -TVM_DLL const Op& ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster(); - -/*! - * \brief tvm instrinsics to call - * cp.async.bulk.tensor.dim.global.shared::cta.tile。bulk_group - * - * TMA alignment requirement: - * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#table-alignment-multi-dim-tma - * - * ptx_cp_async_bulk_tensor_shared_to_global(int dim, PrimExpr src_ptr, PrimExpr tensormap_addr, - * int...coords, string cache_hint) - */ -TVM_DLL const Op& ptx_cp_async_bulk_tensor_shared_to_global(); - -/*! - * \brief tvm instrinsics to call - * cp.async.bulk.prefetch.tensor.dim.L2.global.tile - * - * TMA alignment requirement: - * https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#table-alignment-multi-dim-tma - * - * ptx_cp_async_bulk_tensor_global_to_cluster_prefetch(int dim, PrimExpr tensormap_addr, - * int...coords, string cache_hint) - */ -TVM_DLL const Op& ptx_cp_async_bulk_tensor_global_to_cluster_prefetch(); - -/*! - * \brief tvm instrinsics to call - * cp.reduce.async.bulk.tensor.dim.dst.src.redOp - * - * ptx_cp_async_bulk_tensor_shared_to_global_reduce(int dim, PrimExpr src_ptr, PrimExpr - * tensormap_addr, int...coords, string cache_hint) - */ -TVM_DLL const Op& ptx_cp_async_bulk_tensor_shared_to_global_reduce(); - -/*! - * \brief tvm instrinsics to call cp.async.bulk.commit_group - * - * ptx_cp_async_bulk_commit_group() - */ -TVM_DLL const Op& ptx_cp_async_bulk_commit_group(); - -/*! - * \brief tvm instrinsics to call cp.async.bulk.wait_group{.read} N - * - * ptx_cp_async_bulk_wait_group(int N, bool read) - */ -TVM_DLL const Op& ptx_cp_async_bulk_wait_group(); - -/*! - * \brief tvm instrinsics to call barrier.cluster.arrive{.sem}{.aligned} - * - * ptx_barrier_cluster_arrive(string sem, bool aligned) - */ -TVM_DLL const Op& ptx_barrier_cluster_arrive(); - -/*! - * \brief tvm instrinsics to call barrier.cluster.wait.{acquire}{.aligned} - * - * ptx_barrier_cluster_wait(bool acquire, bool aligned) - */ -TVM_DLL const Op& ptx_barrier_cluster_wait(); - -/*! - * \brief tvm instrinsics to call elect.sync _|p, membermask and return the predicate - * - * elect_sync(membermask) - */ -TVM_DLL const Op& ptx_elect_sync(); - -/*! - * \brief PTX fence.mbarrier_init.release.cluster instruction - * - * ptx_fence_mbarrier_init() - */ -TVM_DLL const Op& ptx_fence_mbarrier_init(); - -/*! - * \brief tvm instrinsics to fetch PTX pre-defined registers - * - * ptx_fetch_register(int bits, string reg_name) - */ -TVM_DLL const Op& ptx_fetch_register(); - -/*! - * \brief PTX programmatic dependent launch synchronization. - */ -TVM_DLL const Op& ptx_griddepcontrol_wait(); -TVM_DLL const Op& ptx_griddepcontrol_launch_dependents(); - -/*! - * \brief tvm intrinsic for storing the result of PTX MMA into a destination pointer. - * For example, if each thread in a warp of size 32 has 4 elements from the result of - * m16xn8xk16 MMA in its registers, this intrinsic can be used to store the result in a - * 16x8 region in shared or global memory. - * - * There is no real PTX instruction that does that, but we want to hide details of - * complex index manipulation behind this intrinsic to simplify TIR lowering passes (e.g. - * LowerWarpMemory). - * - * void mma_store(IntImm m, IntImm n, Var dst_ptr, Var src_ptr, Expr src_offset, Var dst_stride); - */ -TVM_DLL const Op& mma_store(); - -/*! - * \brief tvm intrinsic for zero-initializing an MMA accumulation register. - * For example, if each thread in a warp of size 32 has 8 elements from the A matrix in - * m16xn8xk16 MMA in its registers, this intrinsic can be used to zero-initialize its - * 4 accumulation registers. - * - * There is no real PTX instruction that does that, but we introduce this intrinsic for the - * same reason as mma_store above. - * - * void mma_fill(IntImm local_size, Var local_ptr, Expr offset); - */ -TVM_DLL const Op& mma_fill(); - -/*! - * \brief tvm intrinsic to encode matrix descriptor for wgmma instructions. - * - * ptx_wgmma_encode_matrix_descriptor(PrimExpr ptr, PrimExpr ldo, PrimExpr sdo, int swizzle) - */ -TVM_DLL const Op& ptx_wgmma_encode_matrix_descriptor(); - -/*! - * \brief tvm intrinsic to call "" : "+r"(reg) :: "memory" - * - * ptx_wgmma_noop_barrier() - */ -TVM_DLL const Op& ptx_wgmma_noop_barrier(); - -/*! - * \brief tvm intrinsic to call wgmma.mma_async.sync.aligned.shape.dtype.atype.btype - * where both A and B are in shared memory. - * - * ptx_wgmma_mma_async_ss() - */ -TVM_DLL const Op& ptx_wgmma_mma_async_ss(); - -/*! - * \brief tvm intrinsic to call wgmma.mma_async.sync.aligned.shape.dtype.atype.btype - * where A is in register and B is in shared memory. - * - * ptx_wgmma_mma_async_rs() - */ -TVM_DLL const Op& ptx_wgmma_mma_async_rs(); - -/*! - * \brief tvm intrinsic to call wgmma.fence.sync.aligned; - * - * ptx_wgmma_fence() - */ -TVM_DLL const Op& ptx_wgmma_fence(); - -/*! - * \brief tvm intrinsic to call wgmma.commit_group.sync.aligned; - * - * ptx_wgmma_commit_group() - */ -TVM_DLL const Op& ptx_wgmma_commit_group(); - -/*! - * \brief tvm intrinsic to call wgmma.wait_group.sync.aligned; - * - * ptx_wgmma_wait_group(int N) - */ -TVM_DLL const Op& ptx_wgmma_wait_group(); - -/*! - * \brief tvm intrinsic to call stmatrix.sync.aligned.m8n8.num{.trans}.shared.b16 [p], r; - * - * ptx_stmatrix(int num, bool trans, PrimExpr ptr, PrimExpr... vars) - */ -TVM_DLL const Op& ptx_stmatrix(); - -/*! - * \brief tvm intrinsic to call setmaxnreg.action.sync.aligned.u32 imm-reg-count - */ -TVM_DLL const Op& ptx_setmaxnreg(); - -/*! - * \brief tvm intrinsic to call ld.global.acquire.gpu.b32 - * - * ptx_ld_global_acquire() - */ -TVM_DLL const Op& ptx_ld_global_acquire(); - -/*! - * \brief tvm instrinsics to call tcgen05.alloc.cta_group.sync.aligned; - * - * ptx_tcgen05_alloc(Var dst_ptr, int n_cols, int cta_group) - */ -TVM_DLL const Op& ptx_tcgen05_alloc(); - -/*! - * \brief tvm instrinsics to call tcgen05.dealloc.cta_group.sync.aligned; - * - * ptx_tcgen05_dealloc(uint32_t taddr, int n_cols, int cta_group) - */ -TVM_DLL const Op& ptx_tcgen05_dealloc(); - -/*! - * \brief tvm instrinsics to call tcgen05.relinquish_alloc_permit.cta_group.sync.aligned; - * - * ptx_tcgen05_relinquish_alloc_permit(int cta_group) - */ -TVM_DLL const Op& ptx_tcgen05_relinquish_alloc_permit(); - -/*! - * \brief tvm instrinsics to call tcgen05.fence::before_thread_sync; - * - * ptx_tcgen05_fence_before_thread_sync() - */ -TVM_DLL const Op& ptx_tcgen05_fence_before_thread_sync(); - -/*! - * \brief tvm instrinsics to call tcgen05.fence::after_thread_sync; - * - * ptx_tcgen05_fence_after_thread_sync() - */ -TVM_DLL const Op& ptx_tcgen05_fence_after_thread_sync(); - -/*! - * \brief tvm instrinsics to call tcgen05.ld.sync.aligned; - * - * ptx_tcgen05_ld() - */ -TVM_DLL const Op& ptx_tcgen05_ld(); - -/*! - * \brief tvm instrinsics to call tcgen05.st.sync.aligned; - * - * ptx_tcgen05_st() - */ -TVM_DLL const Op& ptx_tcgen05_st(); - -/*! - * \brief tvm instrinsics to call tcgen05.wait::ld.sync.aligned; - * - * ptx_tcgen05_wait_ld() - */ -TVM_DLL const Op& ptx_tcgen05_wait_ld(); - -/*! - * \brief tvm instrinsics to call tcgen05.wait::st.sync.aligned; - * - * ptx_tcgen05_wait_st() - */ -TVM_DLL const Op& ptx_tcgen05_wait_st(); - -/*! - * \brief tvm intrinsic to encode matrix descriptor for tcgen05 instructions. - * - * ptx_tcgen05_encode_matrix_descriptor(PrimExpr ptr, PrimExpr ldo, PrimExpr sdo, int swizzle) - */ -TVM_DLL const Op& ptx_tcgen05_encode_matrix_descriptor(); - -/*! - * \brief tvm intrinsic to encode instruction descriptor for tcgen05 MMA. - * - * ptx_tcgen05_encode_instr_descriptor(PrimExpr desc, string d_dtype, string a_dtype, string - * b_dtype, int M, int N, int K, bool trans_a, bool trans_b, int n_cta_groups, bool neg_a, bool - * neg_b, bool sat_d, bool is_sparse) - */ -TVM_DLL const Op& ptx_tcgen05_encode_instr_descriptor(); - -/*! - * \brief tvm intrinsic to encode instruction descriptor for tcgen05 MMA block scaled. - * - * ptx_tcgen05_encode_instr_descriptor_block_scaled(PrimExpr desc, string d_dtype, - * string a_dtype, string b_dtype, string sfa_dtype, string stb_dtype, - * int M, int N, int K, bool trans_a, bool trans_b, - * int n_cta_groups, bool neg_a, bool neg_b, bool is_sparse) - */ -TVM_DLL const Op& ptx_tcgen05_encode_instr_descriptor_block_scaled(); - -/*! - * \brief tvm intrinsic to call tcgen05.mma.cta_group.kind without block scaling. - * - * ptx_tcgen05_mma() - */ -TVM_DLL const Op& ptx_tcgen05_mma(); - -/*! - * \brief tvm intrinsic to call tcgen05.mma.cta_group.kind.block_scale{.scale_vec_size} - * - * ptx_tcgen05_mma_block_scale() - */ -TVM_DLL const Op& ptx_tcgen05_mma_block_scale(); - -/*! - * \brief tvm intrinsic to call tcgen05.mma.sp.cta_group.kind without block scaling. - * - * ptx_tcgen05_mma_sp() - */ -TVM_DLL const Op& ptx_tcgen05_mma_sp(); - -/*! - * \brief tvm intrinsic to call tcgen05.mma.sp.cta_group.kind.block_scale{.scale_vec_size} - * - * ptx_tcgen05_mma_sp_block_scale() - */ -TVM_DLL const Op& ptx_tcgen05_mma_sp_block_scale(); - -/*! - * \brief tvm instrinsics to call tcgen05.commit.cta_group - * - * ptx_tcgen05_commit() - */ -TVM_DLL const Op& ptx_tcgen05_commit(); - -/*! - * \brief tvm instrinsics to call tcgen05.cp.cta_group - * - * ptx_tcgen05_cp() - */ -TVM_DLL const Op& ptx_tcgen05_cp(); - -/*! - * \brief tvm instrinsics to call tcgen05.shift.cta_group.down - * - * ptx_tcgen05_shift() - */ -TVM_DLL const Op& ptx_tcgen05_shift(); - -/*! - * \brief tvm instrinsics to call map_shared_rank - * - * ptx_map_shared_rank(PrimExpr ptr, int rank) - */ -TVM_DLL const Op& ptx_map_shared_rank(); - -/*! - * \brief tvm instrinsics to call a CUDA function. Source code is provided as a string. - * - * cuda_func_call(String func_name, PrimExpr... args, String source_code) - */ -TVM_DLL const Op& cuda_func_call(); - -/*! - * \brief nvshmem intrinsics for nvshmem_my_pe() operation. - * - * int nvshmem_my_pe() - */ -TVM_DLL const Op& nvshmem_my_pe(); - -/*! - * \brief nvshmem intrinsics for nvshmem_n_pes() operation. - * - * int nvshmem_n_pes() - */ -TVM_DLL const Op& nvshmem_n_pes(); - -/*! - * \brief nvshmem intrinsics for nvshmem_getmem_nbi() operation. - * - * void nvshmem_getmem_nbi(void *dest, const void *source, size_t nelems, int pe) - */ -TVM_DLL const Op& nvshmem_getmem_nbi(); - -/*! - * \brief nvshmem intrinsics for nvshmem_putmem_nbi() operation. - * - * void nvshmem_putmem_nbi(void *dest, const void *source, size_t nelems, int pe) - */ -TVM_DLL const Op& nvshmem_putmem_nbi(); - -/*! - * \brief nvshmem intrinsics for nvshmemx_getmem_nbi_warp() operation. - * - * void nvshmemx_getmem_nbi_warp(void *dest, const void *source, size_t nelems, int pe) - */ -TVM_DLL const Op& nvshmem_getmem_nbi_warp(); - -/*! - * \brief nvshmem intrinsics for nvshmemx_putmem_nbi_warp() operation. - * - * void nvshmemx_putmem_nbi_warp(void *dest, const void *source, size_t nelems, int pe) - */ -TVM_DLL const Op& nvshmem_putmem_nbi_warp(); - -/*! - * \brief nvshmem intrinsics for nvshmemx_getmem_nbi_block() operation. - * - * void nvshmemx_getmem_nbi_block(void *dest, const void *source, size_t nelems, int pe) - */ -TVM_DLL const Op& nvshmem_getmem_nbi_block(); - -/*! - * \brief nvshmem intrinsics for nvshmemx_putmem_nbi_block() operation. - * - * void nvshmemx_putmem_nbi_block(void *dest, const void *source, size_t nelems, int pe) - */ -TVM_DLL const Op& nvshmem_putmem_nbi_block(); - -/*! - * \brief nvshmem intrinsics for nvshmemx_signal_op() operation. - * - * void nvshmemx_signal_op(uint64_t *sig_addr, uint64_t signal, int sig_op, int pe) - */ -TVM_DLL const Op& nvshmem_signal_op(); - -/*! - * \brief nvshmem intrinsics for nvshmem_FuncParam{TYPENAME}_wait_until() operation. - * - * void nvshmem_FuncParam{TYPENAME}_wait_until(TYPE *ivar, int cmp, TYPE cmp_value) - */ -TVM_DLL const Op& nvshmem_wait_until(); - -/*! - * \brief nvshmem intrinsics for nvshmem_quiet() operation. - * - * void nvshmem_quiet() - */ -TVM_DLL const Op& nvshmem_quiet(); - -/*! - * \brief nvshmem intrinsics for nvshmemx_putmem_signal_nbi() operation. - * - * void nvshmemx_putmem_signal_nbi(void *dest, const void *source, size_t nelems, uint64_t - * *sig_addr, uint64_t signal, int sig_op, int pe) - */ -TVM_DLL const Op& nvshmem_putmem_signal_nbi(); - -/*! - * \brief nvshmem intrinsics for nvshmemx_putmem_signal_nbi_warp() operation. - * - * void nvshmemx_putmem_signal_nbi_warp(void *dest, const void *source, size_t nelems, uint64_t - * *sig_addr, uint64_t signal, int sig_op, int pe) - */ -TVM_DLL const Op& nvshmem_putmem_signal_nbi_warp(); - -/*! - * \brief nvshmem intrinsics for nvshmemx_putmem_signal_nbi_block() operation. - * - * void nvshmemx_putmem_signal_nbi_block(void *dest, const void *source, size_t nelems, - * uint64_t *sig_addr, uint64_t signal, int sig_op, int pe) - */ -TVM_DLL const Op& nvshmem_putmem_signal_nbi_block(); - -/*! - * \brief nvshmem intrinsics for nvshmem_fence() operation. - * - * void nvshmem_fence() - */ -TVM_DLL const Op& nvshmem_fence(); - -/*! - * \brief nvshmem intrinsics for nvshmem_barrier_all() operation. - * - * void nvshmem_barrier_all() - */ -TVM_DLL const Op& nvshmem_barrier_all(); - -} // namespace builtin -} // namespace tirx -} // namespace tvm - -#endif // TVM_TIRX_TARGET_BUILTIN_CUDA_H_ diff --git a/include/tvm/tirx/target_builtin/trn.h b/include/tvm/tirx/target_builtin/trn.h deleted file mode 100644 index 556156bc13a9..000000000000 --- a/include/tvm/tirx/target_builtin/trn.h +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tvm/tir/target_builtin/trn.h - * \brief TIR builtin intrinsics specific to Trainium target. - */ -#ifndef TVM_TIRX_TARGET_BUILTIN_TRN_H_ -#define TVM_TIRX_TARGET_BUILTIN_TRN_H_ - -#include -#include - -namespace tvm { -namespace tirx { -namespace builtin { - -/*! - * \brief nki intrinsics for load operation. - * - * nki_load(result, data) - */ -TVM_DLL const Op& nki_load(); -/*! - * \brief nki intrinsics for store operation. - * - * nki_store(result, data) - */ -TVM_DLL const Op& nki_store(); -/*! - * \brief nki intrinsics for tensor_copy operation. - * - * nki_tensor_copy(result, data) - */ -TVM_DLL const Op& nki_tensor_copy(); -/*! - * \brief nki intrinsics for matmul operation. - * - * nki_matmul(C, A, B, accum) - * - * equivalent to C += A.T @ B (if accum is true), or C = A.T @ B (if accum is false) - */ -TVM_DLL const Op& nki_matmul(); - -/*! - * \brief nki intrinsics for activation operation. - * - * nki_activation(result, data, opcode, bias, scale) - */ -TVM_DLL const Op& nki_activation(); - -/*! - * \brief nki intrinsics for reciprocal operation. - * - * nki_reciprocal(result, data) - */ -TVM_DLL const Op& nki_reciprocal(); - -/*! - * \brief nki intrinsics for tensortensor operation. - * - * nki_tensortensor(result, operand0, operand1, opcode) - */ -TVM_DLL const Op& nki_tensortensor(); - -/*! - * \brief nki intrinsics for tensorscalar operation. - * - * nki_tensorscalar(result, operand0, operand1, opcode, reverse) - */ -TVM_DLL const Op& nki_tensorscalar(); - -/*! - * \brief nki intrinsics for tensorreduce operation. - * - * nki_tensorreduce(result, data, opcode, negate, axes) - */ -TVM_DLL const Op& nki_tensorreduce(); - -/*! - * \brief nki intrinsics for memset operation. - * - * nki_memset(result, value) - */ -TVM_DLL const Op& nki_memset(); - -/*! - * \brief nki intrinsics for activation reduce operation. - * - * nki_activation_reduce(reduce_res, act_res, data, opcode, reduce_opcode, bias, scale) - */ -TVM_DLL const Op& nki_activation_reduce(); - -/*! - * \brief nki intrinsics for tensorscalar reduce operation. - * - * nki_tensorscalar_reduce(reduce_res, tensorscalar_res, operand0, operand1, opcode, reduce_opcode, - * reverse) - */ -TVM_DLL const Op& nki_tensorscalar_reduce(); - -/*! - * \brief nki intrinsics for initializing identity tensor. - * - * nki_identity(result, size) - */ -TVM_DLL const Op& nki_identity(); - -/*! - * \brief nki intrinsics for scalar tensor tensor operation. - * - * (data op1 operand1) op2 (operand2) where op1 is tensor-scalar and op2 is tensor-tensor - * - * nki_scalar_tensor_tensor(result, data, operand0, operand1, opcode0, opcode1, reverse0, reverse1) - * - */ -TVM_DLL const Op& nki_scalar_tensor_tensor(); - -/*! - * \brief nki intrinsics for scalar tensor scalar operation. - * - * (data op1 operand1) op2 (operand2) where op1 and op2 are tensor-scalar - * - * nki_scalar_tensor_scalar(result, data, operand0, operand1, opcode0, opcode1, reverse0, reverse1) - * - */ -TVM_DLL const Op& nki_scalar_tensor_scalar(); - -/*! - * \brief nki intrinsics for affine_select operation. - * - * nki_affine_select(result, pred, true_value, false_value) - */ -TVM_DLL const Op& nki_affine_select(); - -} // namespace builtin -} // namespace tirx -} // namespace tvm - -#endif // TVM_TIRX_TARGET_BUILTIN_TRN_H_ diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index bb171170fae3..b38e86a0ae69 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -51,6 +51,9 @@ # tvm.tirx — registers itself via tvm.script.register_dialect in its __init__ from . import tirx +# tvm.backend — owns backend Python load hooks +from . import backend + # tvm.target from . import target @@ -72,10 +75,7 @@ # Relax contain modules that are only available in compiler package # Do not import them if TVM is built with runtime only if not _RUNTIME_ONLY: - # tile_primitive imports both Python Op class declarations (Zero, Add, ...) - # and per-target dispatch schedule registrations. Must run before relax so - # any relax pass that looks up a schedule sees them. - from .tirx.operator import tile_primitive + backend.load_all() # tvm.relax — registers itself via tvm.script.register_dialect in its __init__ from . import relax diff --git a/python/tvm/backend/__init__.py b/python/tvm/backend/__init__.py new file mode 100644 index 000000000000..2243f5e03410 --- /dev/null +++ b/python/tvm/backend/__init__.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Backend-owned Python modules and load hooks.""" + +from __future__ import annotations + +import importlib +import importlib.util +import sys +import types +from pkgutil import extend_path +from typing import Any + +__path__ = extend_path(__path__, __name__) # type: ignore[name-defined] + +_BUILTIN_BACKENDS = ( + "cuda", + "metal", + "rocm", + "trn", + "opencl", + "vulkan", + "webgpu", + "hexagon", + "adreno", +) +_LOADED_BACKENDS: dict[str, Any] = {} + + +class _AliasModule(types.ModuleType): + """Module object that exposes a backend module under a public alias.""" + + def __init__(self, fullname: str, module): + super().__init__(fullname, getattr(module, "__doc__", None)) + self.__dict__["__tvm_backend_module__"] = module + self.__dict__["__package__"] = fullname.rpartition(".")[0] + if hasattr(module, "__all__"): + self.__dict__["__all__"] = module.__all__ + if hasattr(module, "__path__"): + self.__dict__["__path__"] = [] + + def __getattr__(self, name: str): + return getattr(self.__dict__["__tvm_backend_module__"], name) + + def __setattr__(self, name: str, value): + setattr(self.__dict__["__tvm_backend_module__"], name, value) + + def __delattr__(self, name: str): + delattr(self.__dict__["__tvm_backend_module__"], name) + + def __dir__(self): + return sorted(set(super().__dir__()) | set(dir(self.__dict__["__tvm_backend_module__"]))) + + +class _AliasLoader: + """Loader that returns an already-resolved module for an alias spec.""" + + def __init__(self, fullname: str, module): + self._fullname = fullname + self._module = module + + def create_module(self, spec): + return _get_alias_module(self._fullname, self._module) + + def exec_module(self, module): + _set_module_alias(self._fullname, self._module) + return None + + def is_package(self, fullname): + return hasattr(self._module, "__path__") + + +def _redirect_tirx_backend_alias(fullname: str) -> str | None: + prefix = "tvm.tirx." + if not fullname.startswith(prefix): + return None + rest = fullname[len(prefix) :] + backend_name, sep, tail = rest.partition(".") + if not sep or backend_name not in _LOADED_BACKENDS: + return None + return f"tvm.backend.{backend_name}.{tail}" + + +class _BackendAliasFinder: + """Redirect ``tvm.tirx..*`` imports to ``tvm.backend..*``.""" + + @classmethod + def find_spec(cls, fullname, path, target=None): + redirected = _redirect_tirx_backend_alias(fullname) + if redirected is None: + return None + module = importlib.import_module(redirected) + _set_module_alias(fullname, module) + loader = _AliasLoader(fullname, module) + spec = importlib.util.spec_from_loader( + fullname, loader, is_package=hasattr(module, "__path__") + ) + if spec is not None and hasattr(module, "__path__"): + spec.submodule_search_locations = [] + return spec + + +if not any(isinstance(finder, _BackendAliasFinder) for finder in sys.meta_path): + sys.meta_path.insert(0, _BackendAliasFinder()) + + +def _get_alias_module(alias: str, module): + existing = sys.modules.get(alias) + if ( + isinstance(existing, _AliasModule) + and existing.__dict__.get("__tvm_backend_module__") is module + ): + return existing + return _AliasModule(alias, module) + + +def _set_module_alias(alias: str, module, *, direct: bool = False) -> None: + alias_module = module if direct else _get_alias_module(alias, module) + sys.modules[alias] = alias_module + parent_name, _, child_name = alias.rpartition(".") + parent = sys.modules.get(parent_name) + if parent is not None: + setattr(parent, child_name, alias_module) + + +def _alias_loaded_backend_modules(name: str) -> None: + backend_prefix = f"tvm.backend.{name}" + public_prefix = f"tvm.tirx.{name}" + for module_name, module in sorted(list(sys.modules.items())): + if module_name == backend_prefix or module_name.startswith(f"{backend_prefix}."): + public_name = f"{public_prefix}{module_name[len(backend_prefix) :]}" + _set_module_alias(public_name, module, direct=module_name == backend_prefix) + + +def _import_backend(name: str): + module_name = f"tvm.backend.{name}" + try: + return importlib.import_module(module_name) + except ModuleNotFoundError as err: + if err.name == module_name: + raise ImportError( + f"Cannot load TVM backend {name!r}: expected Python package {module_name!r}. " + "Install the backend package or check the backend name." + ) from err + raise + + +def load(name: str) -> None: + """Load a backend's Python registration hooks. + + Loading is idempotent. A backend package must live at ``tvm.backend.`` + and expose ``register_backend()``. + """ + + if name in _LOADED_BACKENDS: + return None + + module = _import_backend(name) + register_backend = getattr(module, "register_backend", None) + if register_backend is None: + raise AttributeError(f"Backend package 'tvm.backend.{name}' has no register_backend()") + + import tvm.tirx as tirx # pylint: disable=import-outside-toplevel + + setattr(tirx, name, module) + sys.modules[f"tvm.tirx.{name}"] = module + _LOADED_BACKENDS[name] = module + try: + register_backend() + _alias_loaded_backend_modules(name) + except Exception: + _LOADED_BACKENDS.pop(name, None) + if getattr(tirx, name, None) is module: + delattr(tirx, name) + if sys.modules.get(f"tvm.tirx.{name}") is module: + sys.modules.pop(f"tvm.tirx.{name}", None) + raise + return None + + +def load_all() -> None: + """Load all in-tree backend Python hooks.""" + + for name in _BUILTIN_BACKENDS: + load(name) + return None + + +def is_loaded(name: str) -> bool: + """Return whether a backend has been loaded.""" + + return name in _LOADED_BACKENDS + + +__all__ = ["is_loaded", "load", "load_all"] diff --git a/python/tvm/backend/adreno/__init__.py b/python/tvm/backend/adreno/__init__.py new file mode 100644 index 000000000000..8f79af80d183 --- /dev/null +++ b/python/tvm/backend/adreno/__init__.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Adreno-owned backend hooks.""" + +from importlib import import_module + +_LAZY_SUBMODULES = {"target_tags"} + + +def register_backend(): + """Register Adreno-owned Python semantics.""" + import tvm.backend as backend # pylint: disable=import-outside-toplevel + + backend.load("opencl") + backend.load("vulkan") + import_module(f"{__name__}.target_tags") + + +def __getattr__(name: str): + if name in _LAZY_SUBMODULES: + return import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = ["register_backend", "target_tags"] diff --git a/python/tvm/target/tag_registry/adreno.py b/python/tvm/backend/adreno/target_tags.py similarity index 97% rename from python/tvm/target/tag_registry/adreno.py rename to python/tvm/backend/adreno/target_tags.py index 388253c3bd7c..daba9acf9d80 100644 --- a/python/tvm/target/tag_registry/adreno.py +++ b/python/tvm/backend/adreno/target_tags.py @@ -16,7 +16,7 @@ # under the License. """Qualcomm Adreno GPU target tags.""" -from .registry import register_tag +from tvm.target import register_tag register_tag( "qcom/adreno-opencl", diff --git a/python/tvm/backend/cuda/__init__.py b/python/tvm/backend/cuda/__init__.py new file mode 100644 index 000000000000..34592d2de3e7 --- /dev/null +++ b/python/tvm/backend/cuda/__init__.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""CUDA-owned TIRx modules.""" + +from importlib import import_module + +_LAZY_SUBMODULES = {"lang", "op", "operator", "script", "target_tags"} + + +def register_backend(): + """Register CUDA-owned Python semantics.""" + from tvm.tirx.script.builder import ir as builder_ir # pylint: disable=import-outside-toplevel + + for name, namespace in script_namespaces().items(): + builder_ir.register_script_namespace(name, namespace) + + import_module(f"{__name__}.operator.intrinsics") + import_module(f"{__name__}.operator.tile_primitive") + import_module(f"{__name__}.target_tags") + + +def script_namespaces(**_): + """Return CUDA-owned TVMScript namespaces.""" + from .script import ( # pylint: disable=import-outside-toplevel + CUDANamespace, + NVSHMEMNamespace, + PTXNamespace, + ) + + return { + "cuda": CUDANamespace(), + "nvshmem": NVSHMEMNamespace(), + "ptx": PTXNamespace(), + } + + +def script_namespace(**kwargs): + """Return the CUDA TVMScript namespace object.""" + return script_namespaces(**kwargs)["cuda"] + + +def __getattr__(name: str): + if name in _LAZY_SUBMODULES: + return import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "lang", + "op", + "operator", + "register_backend", + "script", + "script_namespace", + "script_namespaces", + "target_tags", +] diff --git a/python/tvm/backend/cuda/lang/__init__.py b/python/tvm/backend/cuda/lang/__init__.py new file mode 100644 index 000000000000..a70ce50196c6 --- /dev/null +++ b/python/tvm/backend/cuda/lang/__init__.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""CUDA-specific TIRx language helpers.""" + +from importlib import import_module + +__all__ = [ + "BaseTileScheduler", + "ClusterPersistentScheduler2D", + "FlashAttentionLPTScheduler", + "FlashAttentionLinearScheduler", + "GroupMajor3D", + "IndexedTripleTileScheduler", + "MBarrier", + "Pipeline", + "PipelineState", + "RankAwareGroupMajorTileScheduler", + "SMEMPool", + "SmemDescriptor", + "TCGen05Bar", + "TMABar", + "TMEMPool", + "TMEMStages", + "WarpRole", + "WarpgroupRole", +] + +_HELPER_MODULES = { + "MBarrier": ".pipeline", + "Pipeline": ".pipeline", + "PipelineState": ".pipeline", + "BaseTileScheduler": ".tile_scheduler", + "ClusterPersistentScheduler2D": ".tile_scheduler", + "FlashAttentionLPTScheduler": ".tile_scheduler", + "FlashAttentionLinearScheduler": ".tile_scheduler", + "GroupMajor3D": ".tile_scheduler", + "IndexedTripleTileScheduler": ".tile_scheduler", + "RankAwareGroupMajorTileScheduler": ".tile_scheduler", + "SMEMPool": ".alloc_pool", + "TCGen05Bar": ".pipeline", + "TMABar": ".pipeline", + "TMEMPool": ".alloc_pool", + "TMEMStages": ".alloc_pool", + "WarpRole": ".warp_role", + "WarpgroupRole": ".warp_role", + "SmemDescriptor": ".smem_desc", +} + + +def __getattr__(name: str): + module_name = _HELPER_MODULES.get(name) + if module_name is None: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + value = getattr(import_module(module_name, __name__), name) + globals()[name] = value + return value diff --git a/python/tvm/backend/cuda/lang/alloc_pool.py b/python/tvm/backend/cuda/lang/alloc_pool.py new file mode 100644 index 000000000000..725264a35b16 --- /dev/null +++ b/python/tvm/backend/cuda/lang/alloc_pool.py @@ -0,0 +1,529 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""SMEM and TMEM bump-allocator pools for TIRX kernels.""" + +from __future__ import annotations + +import functools +import operator + +from tvm import DataType +from tvm.tirx.layout import S, TCol, TileLayout, TLane + +# --------------------------------------------------------------------------- +# ir_builder helpers — imported lazily to avoid circular deps at module level +# --------------------------------------------------------------------------- + +_ir = None + + +def _get_ir(): + global _ir + if _ir is None: + from tvm.tirx.script.builder import ir as _mod + + _ir = _mod + return _ir + + +def _get_frame(): + from tvm.tirx.script.builder import frame + + return frame + + +# --------------------------------------------------------------------------- +# Shared utilities +# --------------------------------------------------------------------------- + +_POOL_UNSET = object() + + +def _default_tmem_layout(rows, cols): + return TileLayout(S[(rows, cols) : (1 @ TLane, 1 @ TCol)]) + + +def _emit_stmt(expr): + ir = _get_ir() + ir.add_to_parent(ir.evaluate(expr)) + + +def _shape_product(shape): + return functools.reduce(operator.mul, shape, 1) + + +def _auto_swizzle_mode(dtype): + """Select the default MMA swizzle mode for a shared-memory allocation.""" + from tvm.backend.cuda.operator.tile_primitive.tma_utils import SwizzleMode + + del dtype + return SwizzleMode.SWIZZLE_128B_ATOM + + +def _swizzle_atom_bytes(swizzle_mode): + """Return the row width (in bytes) of one swizzle atom for *swizzle_mode*.""" + from tvm.backend.cuda.operator.tile_primitive.tma_utils import SwizzleMode + + return { + SwizzleMode.SWIZZLE_NONE: 0, + SwizzleMode.SWIZZLE_32B_ATOM: 32, + SwizzleMode.SWIZZLE_64B_ATOM: 64, + SwizzleMode.SWIZZLE_128B_ATOM: 128, + }[swizzle_mode] + + +def _suggest_swizzle_for_row_bytes(row_bytes): + """Pick the largest valid swizzle mode whose atom row fits within *row_bytes*.""" + + for atom_bytes, mode in ( + (128, "SWIZZLE_128B_ATOM"), + (64, "SWIZZLE_64B_ATOM"), + (32, "SWIZZLE_32B_ATOM"), + ): + if row_bytes >= atom_bytes and row_bytes % atom_bytes == 0: + return mode + return "SWIZZLE_NONE" + + +def _validate_mma_alloc_shape(shape, dtype, swizzle_mode): + """Validate that *shape* / *dtype* / *swizzle_mode* are mutually compatible. + + ``mma_shared_layout`` tiles a swizzle atom of shape ``[8, swizzle_bytes / dtype_bytes]`` + over the last two logical dimensions of *shape*. If the row width or row count of + the request is smaller than (or not a multiple of) the atom, the underlying + ``Layout.tile_to`` lowers to a ``floordiv``/``floormod`` by zero and raises an + opaque internal "Divide by zero" diagnostic from ``tile_tile_ops.cc``. Catch the + misconfiguration here so callers see *what* is wrong and *how* to fix it. + + Validation skipped when *swizzle_mode* is ``SWIZZLE_NONE`` (no atom). + """ + from tvm.backend.cuda.operator.tile_primitive.tma_utils import SwizzleMode + + if swizzle_mode == SwizzleMode.SWIZZLE_NONE: + return + + if len(shape) < 2: + raise ValueError( + f"alloc_mma shape={tuple(shape)} has fewer than 2 dimensions; " + f"swizzled MMA layouts tile over the last two dims (rows, cols). " + f"Use swizzle_mode='none' for 1-D allocations." + ) + + # Only validate concrete int dims; symbolic dims fall through (the analyzer + # in C++ will still ICHECK on them, but at least we don't false-positive). + rows = shape[-2] + cols = shape[-1] + if not (isinstance(rows, int) and isinstance(cols, int)): + return + + dtype_bytes = DataType(dtype).bits // 8 + if dtype_bytes == 0: + # Sub-byte dtype (e.g. float4); ``cols`` is already in element units, so + # use a fractional check expressed via bits. + col_bits = cols * DataType(dtype).bits + atom_bits = _swizzle_atom_bytes(swizzle_mode) * 8 + if col_bits < atom_bits or col_bits % atom_bits != 0: + row_bytes = col_bits // 8 if col_bits % 8 == 0 else col_bits / 8 + atom_bytes = _swizzle_atom_bytes(swizzle_mode) + suggestion = _suggest_swizzle_for_row_bytes(col_bits // 8 if col_bits >= 8 else 0) + raise ValueError( + f"alloc_mma shape={tuple(shape)} with dtype={dtype!r} produces " + f"{row_bytes}B rows, which is incompatible with the {atom_bytes}B " + f"swizzle atom selected by {swizzle_mode.name}. " + f"Use swizzle_mode=SwizzleMode.{suggestion}, or widen shape[-1] " + f"to a multiple of " + f"{(atom_bits + DataType(dtype).bits - 1) // DataType(dtype).bits} elements." + ) + else: + row_bytes = cols * dtype_bytes + atom_bytes = _swizzle_atom_bytes(swizzle_mode) + if row_bytes < atom_bytes or row_bytes % atom_bytes != 0: + suggestion = _suggest_swizzle_for_row_bytes(row_bytes) + min_cols = atom_bytes // dtype_bytes + raise ValueError( + f"alloc_mma shape={tuple(shape)} with dtype={dtype!r} produces " + f"{row_bytes}B rows, which is incompatible with the {atom_bytes}B " + f"swizzle atom selected by {swizzle_mode.name}. " + f"Use swizzle_mode=SwizzleMode.{suggestion}, or widen shape[-1] " + f"to a multiple of {min_cols} elements (>= {atom_bytes}B at {dtype})." + ) + + # Atom rows is always 8 (see ``mma_atom_shape`` in tma_utils.py). + atom_rows = 8 + if rows < atom_rows or rows % atom_rows != 0: + raise ValueError( + f"alloc_mma shape={tuple(shape)} has shape[-2]={rows}, but the " + f"{swizzle_mode.name} atom requires shape[-2] to be a positive " + f"multiple of {atom_rows}. Use swizzle_mode='none', or widen shape[-2] " + f"to a multiple of {atom_rows}." + ) + + +# --------------------------------------------------------------------------- +# TMEMStages +# --------------------------------------------------------------------------- + + +def _meta_class(cls): + """Apply @meta_class decorator from ir_builder.""" + return _get_ir().meta_class(cls) + + +@_meta_class +class TMEMStages: + """Parse-time staged view over a TMEM buffer. + + Parameters + ---------- + buf : Buffer + The underlying TMEM buffer (e.g. f32 or f16 view). + col_start : int + First column of stage 0 in *buf*'s column space. + width : int + Number of columns per stage. + stages : int + Number of pipeline stages (default 1). + stride : int or None + Column distance between consecutive stages. When *None* (default), + equals *width* (stages are packed back-to-back). + """ + + def __init__(self, buf, col_start, width, stages=1, stride=None): + self.buf = buf + self.col_start = col_start + self.width = width + self.stages = stages + self.stride = width if stride is None else stride + + def _stage_base(self, stage): + return self.col_start + stage * self.stride + + def __getitem__(self, item): + if isinstance(item, tuple): + assert len(item) == 2, "TMEMStages expects region[stage] or region[stage, start:stop]" + stage, col_slice = item + assert isinstance(col_slice, slice), "TMEMStages tuple indexing requires a slice" + base = self._stage_base(stage) + start = 0 if col_slice.start is None else col_slice.start + stop = self.width if col_slice.stop is None else col_slice.stop + return self.buf[:, base + start : base + stop : col_slice.step] + base = self._stage_base(item) + return self.buf[:, base : base + self.width] + + +# --------------------------------------------------------------------------- +# TMEMPool +# --------------------------------------------------------------------------- + + +@_meta_class +class TMEMPool: + """Bump allocator over TMEM columns.""" + + def __init__( + self, + pool, + total_cols=512, + *, + cta_group=1, + alloc_warp=0, + dealloc_warp=None, + tmem_addr=None, + sync_after_alloc=True, + ): + # tcgen05 alloc/dealloc are warp-uniform PTX instructions: every lane + # in the chosen warp must participate, and exactly one warp in the + # CTA must execute them. The pool emits its own + # ``if warp_id() == target_warp: tcgen05.alloc(...)`` + # guard, using the cta->warp scope id ``T.warp_id()``. + # NOTE: synccheck currently false-deadlocks on kernels that declare a + # second warp-scope id (cpusim binds only one warp var); the generated + # CUDA is equivalent to ``thread_rank() // 32 == target_warp``. + self.pool = pool + self.total_cols = total_cols + self.cta_group = cta_group + self.alloc_warp = alloc_warp + self.dealloc_warp = alloc_warp if dealloc_warp is None else dealloc_warp + self.sync_after_alloc = sync_after_alloc + self.offset = 0 + self.max_offset = 0 + self._committed = False + self._deallocated = False + self._addr_buf = pool.alloc([1], "uint32", align=4) if tmem_addr is None else tmem_addr + + def _addr_slot(self): + try: + return self._addr_buf[0] + except TypeError: + return self._addr_buf + + @property + def addr(self): + return self._addr_slot() + + def _emit_warp_guard(self, target_warp, emit): + from tvm.script import tirx as T + + warp_id = T.warp_id() + with T.If(warp_id == target_warp): + with T.Then(): + emit() + + def _resolve_cols(self, shape, dtype, cols, layout=None): + if cols is not None: + return cols + bits = DataType(dtype).bits + if layout is not None: + # span("TCol") is in *element* (buffer dtype) units; one TMEM cell + # holds 32 bits regardless of the element type. + tcol_elems = int(layout.span("TCol")) + tcol_bits = tcol_elems * bits + assert tcol_bits % 32 == 0, ( + f"layout TCol span={tcol_elems} elems x {bits}b is not 32-bit aligned" + ) + return tcol_bits // 32 + assert len(shape) == 2, "TMEMPool.alloc() requires cols= for non-2D TMEM buffers" + total_bits = _shape_product(shape) * bits + rows = shape[0] + assert total_bits % (32 * rows) == 0, ( + f"Cannot infer TMEM columns from shape={shape}, dtype={dtype!r}; " + "please pass cols= explicitly" + ) + return total_bits // (32 * rows) + + def alloc(self, shape, dtype="float32", *, layout=None, cols=None, datapath=None): + """Allocate a TMEM buffer. + + Parameters + ---------- + shape, dtype, cols + Standard buffer shape / dtype / column count. + layout + Explicit ``TileLayout``. Mutually exclusive with ``datapath``. + datapath : str | None + Optional tcgen05 datapath letter (``"D"`` for M=128 full datapath, + ``"F"`` for M=64 non-``.ws`` scattered). When provided, the buffer's + layout is derived from ``tmem_datapath_layout(datapath, *shape)`` + so the row index reflects the *physical* TMEM lane occupation + (PTX ISA §9.7.16.10.5). The downstream ``.16x*b`` / ``.32x32b`` + dispatches structurally check this layout to catch mismatched + atoms (e.g. a ``.16x*b`` M=128 read against a Layout F buffer). + Defaults to ``None``, which means Layout D's identity row→lane + mapping — keep this for shape ``(128, X)`` buffers that hold + an M=128 MMA accumulator. + """ + from tvm.tirx.layout import tmem_datapath_layout + + if layout is not None and datapath is not None: + raise ValueError("TMEMPool.alloc: pass at most one of layout= and datapath=") + if datapath is not None: + assert len(shape) == 2, "TMEMPool.alloc: datapath= requires a 2-D shape" + layout = tmem_datapath_layout(datapath, shape[0], shape[1]) + + ir = _get_ir() + cols = self._resolve_cols(shape, dtype, cols, layout) + col_start = self.offset + col_end = col_start + cols + assert col_end <= self.total_cols, f"TMEM overflow: {col_end} > {self.total_cols}" + if layout is None: + assert len(shape) == 2, "TMEMPool.alloc() requires layout= for non-2D TMEM buffers" + layout = _default_tmem_layout(shape[0], shape[1]) + res = ir.decl_buffer(shape, dtype, scope="tmem", allocated_addr=col_start, layout=layout) + self.offset = col_end + self.max_offset = max(self.max_offset, self.offset) + return res + + def alloc_sf(self, shape, dtype, *, sf_per_mma, sf_reuse=1): + """Allocate a tcgen05 block-scaled SF TMEM buffer with an inferred layout. + + ``shape`` last two dims are ``(rows, SF_K * sf_reuse)`` (the last dim is + what gemm dispatch iterates over). When ``shape`` has 3 dims, the first + is treated as a pipe-depth outer. + """ + from tvm.backend.cuda.operator.tile_primitive.gemm_async.tcgen05 import sf_tmem_layout + + if len(shape) == 2: + pipe_depth, rows, last = None, shape[0], shape[1] + elif len(shape) == 3: + pipe_depth, rows, last = shape[0], shape[1], shape[2] + else: + raise ValueError( + f"alloc_sf expects 2D (rows, SF_K*sf_reuse) or 3D " + f"(pipe_depth, rows, SF_K*sf_reuse); got shape={shape}" + ) + assert last % sf_reuse == 0, ( + f"alloc_sf: shape last dim {last} must be divisible by sf_reuse={sf_reuse}" + ) + SF_K = last // sf_reuse + layout = sf_tmem_layout( + rows=rows, SF_K=SF_K, sf_per_mma=sf_per_mma, sf_reuse=sf_reuse, pipe_depth=pipe_depth + ) + return self.alloc(shape, dtype, layout=layout) + + def move_base_to(self, col): + self.offset = col + self.max_offset = max(self.max_offset, self.offset) + + def commit(self): + assert not self._committed, "TMEMPool.commit() can only be called once" + from tvm.script import tirx as T + + def emit_alloc(): + _emit_stmt( + T.ptx.tcgen05.alloc( + T.address_of(self.addr), n_cols=self.total_cols, cta_group=self.cta_group + ) + ) + if self.sync_after_alloc: + _emit_stmt(T.cuda.warp_sync()) + + self._emit_warp_guard(self.alloc_warp, emit_alloc) + self._committed = True + + def dealloc(self): + assert self._committed, "TMEMPool.dealloc() called before commit()" + assert not self._deallocated, "TMEMPool.dealloc() can only be called once" + self._deallocated = True + from tvm.script import tirx as T + + def emit_dealloc(): + _emit_stmt(T.ptx.tcgen05.relinquish_alloc_permit(cta_group=self.cta_group)) + _emit_stmt( + T.ptx.tcgen05.dealloc(self.addr, n_cols=self.total_cols, cta_group=self.cta_group) + ) + + self._emit_warp_guard(self.dealloc_warp, emit_dealloc) + + +# --------------------------------------------------------------------------- +# SMEMPool +# --------------------------------------------------------------------------- + + +@_meta_class +class SMEMPool: + """Bump allocator over a contiguous shared memory region. + + Parameters + ---------- + ptr : Var or None, optional + If omitted, an ``alloc_buffer([0], "uint8", scope="shared.dyn")`` is + created automatically and ``commit()`` must be called after all + allocations to emit the size annotation. + If a ``Var`` is provided, the caller manages the backing buffer and + ``commit()`` is a no-op. + """ + + def __init__(self, ptr=_POOL_UNSET): + ir = _get_ir() + if ptr is _POOL_UNSET: + self.buf = ir.alloc_buffer([0], "uint8", scope="shared.dyn") + self.ptr = self.buf.data + self._owns_buffer = True + else: + self.buf = None + self.ptr = ptr + self._owns_buffer = False + self.offset = 0 + self.max_offset = 0 + + def alloc( + self, + shape, + dtype="float32", + strides=None, + scope="shared.dyn", + align=0, + buffer_type="", + axis_separators=None, + layout="default", + ): + ir = _get_ir() + if align > 0: + self.offset = (self.offset + align - 1) // align * align + res = ir.decl_buffer( + shape, + dtype, + data=self.ptr, + strides=strides, + byte_offset=self.offset, + scope=scope, + align=align, + buffer_type=buffer_type, + axis_separators=axis_separators, + layout=layout, + ) + # Advance in bits then round up to bytes so sub-byte dtypes (e.g. + # float4_e2m1fn = 4 bits) still bump the cursor instead of leaving it + # at 0 (bits // 8) and silently overlapping the next allocation. + self.offset += (_shape_product(shape) * DataType(dtype).bits + 7) // 8 + if self._owns_buffer: + self.max_offset = max(self.max_offset, self.offset) + return res + + def alloc_mma(self, shape, dtype="float16", swizzle_mode="auto", align=1024): + """Allocate MMA-compatible shared memory with an inferred swizzle layout.""" + from tvm.backend.cuda.operator.tile_primitive.tma_utils import ( + SwizzleMode, + mma_shared_layout, + ) + + if isinstance(swizzle_mode, str): + if swizzle_mode == "auto": + swizzle_mode = _auto_swizzle_mode(dtype) + elif swizzle_mode == "none": + swizzle_mode = SwizzleMode.SWIZZLE_NONE + else: + raise ValueError( + f"Unsupported swizzle_mode={swizzle_mode!r}; expected 'auto', 'none', " + "or SwizzleMode" + ) + _validate_mma_alloc_shape(shape, dtype, swizzle_mode) + layout = mma_shared_layout(dtype, swizzle_mode, shape) + return self.alloc(shape, dtype, align=align, layout=layout) + + def move_base_to(self, offset): + self.offset = offset + if self._owns_buffer: + self.max_offset = max(self.max_offset, self.offset) + + def commit(self, size=None): + """Emit pool size annotation into the IR. + + Must be called after all ``alloc()`` / ``move_base_to()`` calls. + + Parameters + ---------- + size : int, optional + Explicit shared memory size in bytes. When *None* (the default), + the high-water mark ``max_offset`` tracked by the allocator is used. + """ + if not self._owns_buffer: + return + ir = _get_ir() + frame_mod = _get_frame() + resolved = size if size is not None else self.max_offset + assert resolved >= self.max_offset, ( + f"Specified smem size ({resolved}) is smaller than " + f"the pool high-water mark ({self.max_offset})" + ) + attr_frame = ir.attr(self.ptr, "tirx.pool_max_bytes", resolved) + if isinstance(attr_frame, frame_mod.AttrFrame): + from functools import partial + + attr_frame.add_callback(partial(attr_frame.__exit__, None, None, None)) + attr_frame.__enter__() diff --git a/python/tvm/backend/cuda/lang/pipeline.py b/python/tvm/backend/cuda/lang/pipeline.py new file mode 100644 index 000000000000..ee86090398e9 --- /dev/null +++ b/python/tvm/backend/cuda/lang/pipeline.py @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Reusable pipeline state and mbarrier helpers for SM100 kernels. + +These classes emit TIR via @T.inline. Decorate with @T.meta_class so that +instances are automatically treated as meta values inside @T.prim_func. +""" + +from tvm.script import tirx as T + + +@T.meta_class +class PipelineState: + """Tracks stage and phase for a software-pipelined ring buffer. + + This class does not know anything about full/empty barriers. Use it when + the kernel manually waits/signals barriers, or when the stage/phase drives + a ring not wrapped in a ``Pipeline``. + + Parameters + ---------- + depth : int + Number of stages in the ring. + phase : int, optional + Initial phase. Omit when initialization should happen later. + """ + + def __init__(self, depth: int, phase=None): + self.stage = T.local_scalar("int32") + self.phase = T.local_scalar("int32") + self.depth = depth + if phase is not None: + self.init(phase) + + @T.inline + def init(self, phase): + self.stage = 0 + self.phase = phase + + @T.inline + def advance(self): + if self.depth > 1: + self.stage = self.stage + 1 + if self.stage == self.depth: + self.stage = 0 + self.phase = self.phase ^ 1 + else: + self.phase = self.phase ^ 1 + + +@T.meta_class +class MBarrier: + """Mbarrier wrapper with regular ``mbarrier.arrive``. + + Parameters + ---------- + pool : SMEMPool + Shared memory pool allocator. + depth : int + Number of barrier slots (one per pipeline stage). + phase_offset : int + XORed into the phase bit on every ``wait`` / ``arrive``. + leader : PrimExpr, optional + Boolean predicate selecting the single thread that runs + ``mbarrier.init``. Defaults to ``T.cuda.thread_rank() == 0`` -- + thread 0 of the enclosing CTA, which always picks exactly one + thread regardless of which scope_id vars the caller declared. + Override only when you want a different CTA-local thread to do + the init. + + Note: the default deliberately avoids ``T.warp_id()`` / + ``T.lane_id()``. Those introduce deferred ``cta->warp`` / + ``warp->thread`` ScopeIdDefs that the verifier cannot pin down + unless the kernel header declares the full warp/lane chain (e.g. a + single-CTA DSMEM kernel that only declares ``thread_id``). It also + avoids the synccheck false-deadlock on kernels that declare a + second warp-scope id. The generated CUDA is equivalent. + """ + + def __init__(self, pool, depth, phase_offset=0, leader=None): + self.buf = pool.alloc((depth,), "uint64", align=8) + self.depth = depth + self.phase_offset = phase_offset + self.leader = leader if leader is not None else (T.cuda.thread_rank() == 0) + + @T.inline + def init(self, count): + if self.leader: + for i in T.unroll(self.depth): + T.ptx.mbarrier.init(self.buf.ptr_to([i]), count) + + @T.inline + def wait(self, stage, phase): + # Blocks: ``mbarrier.try_wait`` loops internally until the phase flips, + # so this returns only once the barrier has completed. + T.ptx.mbarrier.try_wait(self.buf.ptr_to([stage]), phase ^ self.phase_offset) + + @T.inline + def arrive(self, stage, cta_id=None, pred=None): + # Default: local-CTA arrive — emits the simple + # ``mbarrier.arrive.shared.b64`` form. To arrive on a remote + # CTA's mbarrier in a cluster kernel, callers must pass + # ``cta_id=`` explicitly (e.g. ``bar.arrive(stage, cta_id=0)``) + # or use ``MBarrier.remote_view(rank).arrive(stage)``. Defaulting + # the cross-CTA path was both surprising (``bar.arrive(stage)`` + # silently ``mapa`` ed across the cluster) and a per-call cost + # of ~3 PTX ops on every single-CTA kernel. + if cta_id is None: + T.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) + else: + actual_pred = True if pred is None else pred + T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) + + def ptr_to(self, idx): + return self.buf.ptr_to(idx) + + def remote_view(self, rank): + """Create a view of this barrier mapped to another CTA's shared memory. + + Arrive-only: the returned view is built with ``object.__new__`` and + never copies ``self.leader``, so calling ``.init()`` on it would fail. + Use it solely to ``arrive`` on a remote CTA's mbarrier. + """ + from tvm.ir import PointerType, PrimType + from tvm.tirx import Var as TIRVar + + expr = T.reinterpret("handle", T.ptx.map_shared_rank(self.buf.ptr_to([0]), rank)) + ptr = TIRVar("remote_mbar_ptr", PointerType(PrimType("uint64"))) + T.Bind(expr, var=ptr) + buf = T.decl_buffer([self.depth], "uint64", data=ptr, scope="shared") + remote = object.__new__(type(self)) + remote.buf = buf + remote.depth = self.depth + remote.phase_offset = self.phase_offset + return remote + + +class TMABar(MBarrier): + """Barrier signaled by TMA (mbarrier.arrive.expect_tx). + + When ``tx_count`` is None, falls back to a remote mbarrier.arrive + (matching MBarrier.arrive defaults). + """ + + @T.inline + def arrive(self, stage, tx_count=None, cta_id=None, pred=None): + # NOTE: this arrive() kwarg set intentionally differs from + # MBarrier.arrive (hardware necessity, LSP-incompatible by design). + # ``tx_count``: TMA byte count for ``mbarrier.arrive.expect_tx``. + # ``cta_id`` / ``pred``: forwarded to the underlying + # ``mbarrier.arrive`` (cluster path) when set; otherwise the + # arrive is local-CTA only. See ``MBarrier.arrive`` for the + # full default-local rationale. + if tx_count is not None: + T.ptx.mbarrier.arrive.expect_tx(self.buf.ptr_to([stage]), tx_count) + elif cta_id is None: + T.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) + else: + actual_pred = True if pred is None else pred + T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) + + +class TCGen05Bar(MBarrier): + """Barrier signaled by ``tcgen05`` commit. + + The caller is responsible for ensuring only one thread issues the + commit, e.g. by wrapping the call in ``if T.ptx.elect_sync():``. + """ + + @T.inline + def arrive(self, stage, cta_group=1, cta_mask=None): + # NOTE: this arrive() kwarg set intentionally differs from + # MBarrier.arrive (hardware necessity, LSP-incompatible by design). + if cta_mask is None and cta_group == 1: + T.ptx.tcgen05.commit(self.buf.ptr_to([stage])) + else: + T.ptx.tcgen05.commit(self.buf.ptr_to([stage]), cta_group=cta_group, cta_mask=cta_mask) + + +# Barrier-type tags accepted by Pipeline's ``full=`` / ``empty=`` arguments. +_BAR_KINDS = {"tma": TMABar, "tcgen05": TCGen05Bar, "mbar": MBarrier} + + +@T.meta_class +class Pipeline: + """A full/empty mbarrier pair for a software-pipelined data flow. + + Pass barrier-type tags and ``Pipeline`` constructs and ``init``\\ s the + barriers itself. Tags: ``"tma"`` (TMABar), ``"tcgen05"`` (TCGen05Bar), + ``"mbar"`` (MBarrier). The barrier type and arrival count of each event + stay explicit at the call site -- e.g. ``Pipeline(pool, n, full="tma", + empty="tcgen05", init_empty=NUM_CONSUMER)``. + + Both signals are required: a ``Pipeline`` is a *pair*. For a one-way event + (a pure "X happened" signal with no slot to recycle) use a bare barrier + (``TMABar``/``TCGen05Bar``/``MBarrier``) directly -- it has no empty side. + + Parameters + ---------- + pool : SMEMPool + Shared memory pool allocator. + stages : int + Number of pipeline stages (barrier slots). + full, empty : str + Barrier-type tag for the full / empty signal (see above). + init_full, init_empty : int + Expected arrival count for the full / empty barrier. + empty_phase_offset : int + XORed into the empty barrier's phase bit on every wait / arrive. + leader : PrimExpr, optional + Propagated to both barriers; defaults to thread 0 of the CTA. + """ + + def __init__( + self, + pool, + stages, + *, + full, + empty, + init_full=1, + init_empty=1, + empty_phase_offset=0, + leader=None, + ): + self.stages = stages + self.full = _BAR_KINDS[full](pool, stages, leader=leader) + self.full.init(init_full) + self.empty = _BAR_KINDS[empty](pool, stages, phase_offset=empty_phase_offset, leader=leader) + self.empty.init(init_empty) diff --git a/python/tvm/backend/cuda/lang/smem_desc.py b/python/tvm/backend/cuda/lang/smem_desc.py new file mode 100644 index 000000000000..d2561d17dce8 --- /dev/null +++ b/python/tvm/backend/cuda/lang/smem_desc.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""SMEM matrix descriptor helper for tcgen05 / wgmma.""" + +from tvm.backend.cuda.operator.tile_primitive.common import smem_desc_add_16B_offset +from tvm.script import tirx as T + + +@T.meta_class +class SmemDescriptor: + """Encoded once via :meth:`init`, reused via :meth:`add_16B_offset`.""" + + def __init__(self): + self._buf = T.alloc_local([1], "uint64") + + @property + def desc(self): + return self._buf[0] + + @T.inline + def init(self, smem_ptr, ldo, sdo, swizzle): + T.ptx.tcgen05.encode_matrix_descriptor( + T.address_of(self._buf[0]), smem_ptr, ldo, sdo, swizzle + ) + + def add_16B_offset(self, offset): + return smem_desc_add_16B_offset(self._buf[0], offset) + + def make_lo_uniform(self): + """Broadcast the lower 32 bits to all warp lanes via ``__shfl_sync``.""" + func_name = "smem_desc_make_lo_uniform" + source_code = f""" +__forceinline__ __device__ void {func_name}(uint64_t* desc) {{ + SmemDescriptor* d = reinterpret_cast(desc); + d->lo = __shfl_sync(0xffffffff, d->lo, 0); +}} +""" + return T.cuda.func_call( + func_name, T.address_of(self._buf[0]), source_code=source_code, return_type="void" + ) diff --git a/python/tvm/backend/cuda/lang/tile_scheduler.py b/python/tvm/backend/cuda/lang/tile_scheduler.py new file mode 100644 index 000000000000..3fd27f25ee5f --- /dev/null +++ b/python/tvm/backend/cuda/lang/tile_scheduler.py @@ -0,0 +1,816 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Reusable tile scheduler helpers for TIR tests/kernels. + +These classes emit TIR via @T.inline. Decorate with @T.meta_class so that +instances are automatically treated as meta values inside @T.prim_func. +""" + +from tvm.script import tirx as T + + +@T.meta_class +class BaseTileScheduler: + """Base class for tile schedulers with common state and macros.""" + + def __init__(self, prefix: str): + self.m_idx = T.local_scalar("int32") + self.n_idx = T.local_scalar("int32") + self.linear_idx = T.local_scalar("int32") + + @T.inline + def update_current_m_n_idx(self, linear_idx): + # To be implemented by subclasses + pass + + @T.inline + def init(self, linear_init): + self.linear_idx = linear_init + self.update_current_m_n_idx(linear_init) + + @T.inline + def next_tile(self, step): + self.linear_idx = self.linear_idx + step + self.update_current_m_n_idx(self.linear_idx) + + def valid(self, total_tiles): + return self.linear_idx < total_tiles + + +class ClusterPersistentScheduler2D(BaseTileScheduler): + """ + Tile scheduler for cluster-based persistent kernels. + + Distributes a 2D tile grid across persistent clusters using group-major ordering + for L2 cache locality. Each cluster starts at its cluster_id and strides by + num_clusters to process tiles. + + Tile Ordering (group-major for L2 locality): + - Tiles are grouped into "L2 groups" of `l2_group_size` rows + - Within a group, tiles are visited in column-major order within the group + - Groups are processed in row-major order + + Example with 4x4 tiles, l2_group_size=2: + Group 0 (rows 0-1): 0 2 4 6 + 1 3 5 7 + Group 1 (rows 2-3): 8 10 12 14 + 9 11 13 15 + + Serpentine Mode (serpentine=True): + - Uses CUTLASS-style 2D block swizzle with serpentine traversal + - Grid is divided into swizzle_size x swizzle_size blocks + - Within each block, tiles are visited in row-major order + - Blocks are traversed in serpentine order (even block-rows forward, odd backward) + - This provides better L2 locality by reusing both A and B tiles + + Example with 4x4 tiles, swizzle_size=2, serpentine=True: + Block layout: + Block(0,0) Block(0,1) + Block(1,0) Block(1,1) + + Tile numbering with serpentine: + n=0 n=1 n=2 n=3 + m=0 0 1 14 15 + m=1 2 3 12 13 + m=2 4 5 10 11 + m=3 6 7 8 9 + + Traversal: Block(0,0) -> Block(1,0) -> Block(1,1) -> Block(0,1) + (serpentine: down in col 0, then up in col 1) + + Parameters + ---------- + prefix : str + Prefix for TIR variable names + num_m_tiles : int | T.ExprLike + Total number of tiles in M dimension (can be runtime expression) + num_n_tiles : int + Total number of tiles in N dimension + num_clusters : int + Number of persistent clusters (determines stride) + l2_group_size : int + Number of M-tile rows per L2 locality group (default: 8) + When serpentine=True, this is used as swizzle_size for 2D blocks + cluster_m : int + Cluster dimension in M for hierarchical scheduling (default: 1) + cluster_n : int + Cluster dimension in N for hierarchical scheduling (default: 1) + serpentine : bool + If True, use CUTLASS-style 2D block swizzle with serpentine traversal (default: False) + + Attributes + ---------- + m_idx : T.local_scalar + Current M tile index (output) + n_idx : T.local_scalar + Current N tile index (output) + work_idx : T.local_scalar + Global work item index for this cluster + tile_count : T.local_scalar + Number of tiles processed by this cluster so far + + Usage + ----- + ```python + scheduler = ClusterPersistentScheduler2D( + "sched", num_m_tiles=M_TILES, num_n_tiles=N_TILES, + num_clusters=NUM_CLUSTERS, l2_group_size=8 + ) + scheduler.init(cluster_id) # cluster_id = cta_idx // CLUSTER_SIZE + + while scheduler.valid(): + m = T.meta_var(scheduler.m_idx) # current M tile + n = T.meta_var(scheduler.n_idx) # current N tile + # ... process tile (m, n) ... + scheduler.next_tile() + ``` + + Examples + -------- + Example 1: Basic persistent kernel + ``` + num_m_tiles=4, num_n_tiles=4, num_clusters=3, l2_group_size=2 + cluster_m=1, cluster_n=1 (default, no tile subdivision) + + Group-major tile numbering (l2_group_size=2): + n=0 n=1 n=2 n=3 + m=0 0 2 4 6 ┐ L2 group 0 + m=1 1 3 5 7 ┘ + m=2 8 10 12 14 ┐ L2 group 1 + m=3 9 11 13 15 ┘ + + Work distribution (cluster starts at cluster_id, strides by num_clusters=3): + cluster 0: work_idx 0,3,6,9,12,15 -> tiles 0,3,6,9,12,15 + cluster 1: work_idx 1,4,7,10,13 -> tiles 1,4,7,10,13 + cluster 2: work_idx 2,5,8,11,14 -> tiles 2,5,8,11,14 + + Tile grid (which cluster handles each tile): + n=0 n=1 n=2 n=3 + m=0 C0 C2 C1 C0 ┐ L2 group 0 + m=1 C1 C0 C2 C1 ┘ + m=2 C2 C1 C0 C2 ┐ L2 group 1 + m=3 C0 C2 C1 C0 ┘ + + Tile sequence per cluster (in execution order): + cluster 0: (0,0)->(1,1)->(0,3)->(2,0)->(2,3)->(3,3) + cluster 1: (1,0)->(0,2)->(1,3)->(2,1)->(3,2) + cluster 2: (0,1)->(1,2)->(2,0)->(3,1)->(2,3) + ``` + + Example 2: 2SM GEMM (typical B200 config) + ``` + M=1024, N=512, CTA_M=128, MMA_N=128, CLUSTER_M=2, CLUSTER_N=1 + => M_TILES=8, N_TILES=4 + => CLUSTER_M_TILES=4, CLUSTER_N_TILES=4 (scheduler at cluster granularity) + + Scheduler params: + num_m_tiles=4, num_n_tiles=4, num_clusters=74, l2_group_size=8 + cluster_m=1, cluster_n=1 + + Key: Scheduler outputs CLUSTER-level tiles. + All CTAs in same cluster get SAME (m_idx, n_idx) from scheduler. + CTAs differentiate via cluster_rank (computed OUTSIDE scheduler): + cluster_rank = cta_idx % CLUSTER_SIZE + cb_m = cluster_rank % CLUSTER_M # 0 or 1 for 2SM + cb_n = cluster_rank // CLUSTER_M # 0 for 2SM + + Final CTA tile: + cta_m = m_idx * CLUSTER_M + cb_m + cta_n = n_idx * CLUSTER_N + cb_n + + Example: cluster 5 gets scheduler tile (1,2) + CTA rank=0 (cb_m=0): actual tile (2,2) + CTA rank=1 (cb_m=1): actual tile (3,2) + ``` + """ + + def __init__( + self, + prefix: str, + num_m_tiles, + num_n_tiles: int, + num_clusters: int, + l2_group_size: int = 8, + cluster_m: int = 1, + cluster_n: int = 1, + serpentine: bool = False, + ): + super().__init__(prefix) + self._num_m_tiles = num_m_tiles + self._num_n_tiles = num_n_tiles + self._num_clusters = num_clusters + self._l2_group_size = l2_group_size + self._cluster_m = cluster_m + self._cluster_n = cluster_n + self._serpentine = serpentine + + # Rename internal state for clarity + self.work_idx = self.linear_idx # alias: global work item index + self.tile_count = T.local_scalar("int32") + self.tile_idx = self.tile_count # alias for backward compatibility + + is_static_m = isinstance(num_m_tiles, int) + + # Number of tile columns after accounting for cluster_n + n_tile_cols = (num_n_tiles + cluster_n - 1) // cluster_n + self._N_TILE_COLS = n_tile_cols + + if is_static_m: + self._M_TILE_ROWS = (num_m_tiles + cluster_m - 1) // cluster_m + self._FULL_GROUPS = self._M_TILE_ROWS // l2_group_size + else: + # Dynamic expressions for runtime M + self._M_TILE_ROWS = T.truncdiv(self._num_m_tiles + self._cluster_m - 1, self._cluster_m) + self._FULL_GROUPS = T.truncdiv(self._M_TILE_ROWS, self._l2_group_size) + + self._TAIL_ROWS = self._M_TILE_ROWS - self._FULL_GROUPS * l2_group_size + self._TOTAL_TILES = self._M_TILE_ROWS * n_tile_cols * cluster_m * cluster_n + + # For serpentine mode: precompute block counts + if serpentine: + self._N_BLOCKS = n_tile_cols // l2_group_size # full blocks in N + self._M_BLOCKS = ( + self._M_TILE_ROWS // l2_group_size + if is_static_m + else T.truncdiv(self._M_TILE_ROWS, l2_group_size) + ) + self._BLOCK_SIZE = l2_group_size * l2_group_size # tiles per block + self._FULL_BLOCK_TILES = self._M_BLOCKS * self._N_BLOCKS * self._BLOCK_SIZE + # Residual tiles (not covered by full blocks) + self._RESIDUAL_N = n_tile_cols - self._N_BLOCKS * l2_group_size + self._RESIDUAL_M = self._M_TILE_ROWS - self._M_BLOCKS * l2_group_size + + # fmt: off + @T.inline + def update_current_m_n_idx(self, work_idx): + """Convert global work index to (m_idx, n_idx) tile coordinates.""" + CLUSTER_M = T.meta_var(self._cluster_m) + CLUSTER_N = T.meta_var(self._cluster_n) + + # Extract hierarchical cluster-local offsets + cluster_m_offset = T.meta_var(work_idx % CLUSTER_M) + t = T.meta_var(work_idx // CLUSTER_M) + cluster_n_offset = T.meta_var(t % CLUSTER_N) + tile_linear = T.meta_var(t // CLUSTER_N) + + @T.inline + def set_tile_coords(tile_row, tile_col): + self.m_idx = tile_row * CLUSTER_M + cluster_m_offset + self.n_idx = tile_col * CLUSTER_N + cluster_n_offset + + if self._serpentine: + self._update_serpentine(tile_linear, set_tile_coords) + else: + self._update_group_major(tile_linear, set_tile_coords) + + def _update_group_major(self, tile_linear, set_tile_coords): + """Group-major ordering with parse-time pruning of statically-dead branches. + + The TIR script parser does not constant-fold ``if False: ...``, so a + Python-literal ``FULL_GROUPS == 0`` would otherwise produce + ``T.bitwise_and(T.bool(False), tile_linear < 0)`` IR plus the dead + then-leg. Branch in plain Python here and only invoke the inline + emitter that can actually fire. + """ + full_zero = isinstance(self._FULL_GROUPS, int) and self._FULL_GROUPS == 0 + tail_zero = isinstance(self._TAIL_ROWS, int) and self._TAIL_ROWS == 0 + if full_zero and tail_zero: + self._gm_emit_zero(set_tile_coords) + elif full_zero: + self._gm_emit_tail_only(tile_linear, set_tile_coords) + elif tail_zero: + self._gm_emit_full_only(tile_linear, set_tile_coords) + else: + self._gm_emit_full_and_tail(tile_linear, set_tile_coords) + + @T.inline + def _gm_emit_zero(self, set_tile_coords): + set_tile_coords(0, 0) + + @T.inline + def _gm_emit_full_only(self, tile_linear, set_tile_coords): + FULL_GROUPS = T.meta_var(self._FULL_GROUPS) + GROUP_SIZE = T.meta_var(self._l2_group_size) + GROUP_SPAN = T.meta_var(self._l2_group_size * self._N_TILE_COLS) + if (FULL_GROUPS > 0) & (tile_linear < FULL_GROUPS * GROUP_SPAN): + group_id: T.let = tile_linear // GROUP_SPAN + within_group: T.let = tile_linear % GROUP_SPAN + tile_row: T.let = group_id * GROUP_SIZE + (within_group % GROUP_SIZE) + tile_col: T.let = within_group // GROUP_SIZE + set_tile_coords(tile_row, tile_col) + else: + set_tile_coords(0, 0) + + @T.inline + def _gm_emit_tail_only(self, tile_linear, set_tile_coords): + FULL_GROUPS = T.meta_var(self._FULL_GROUPS) + TAIL_ROWS = T.meta_var(self._TAIL_ROWS) + GROUP_SIZE = T.meta_var(self._l2_group_size) + GROUP_SPAN = T.meta_var(self._l2_group_size * self._N_TILE_COLS) + if TAIL_ROWS > 0: + rem: T.let = tile_linear - FULL_GROUPS * GROUP_SPAN + tile_row: T.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS) + tile_col: T.let = rem // TAIL_ROWS + set_tile_coords(tile_row, tile_col) + else: + set_tile_coords(0, 0) + + @T.inline + def _gm_emit_full_and_tail(self, tile_linear, set_tile_coords): + FULL_GROUPS = T.meta_var(self._FULL_GROUPS) + TAIL_ROWS = T.meta_var(self._TAIL_ROWS) + GROUP_SIZE = T.meta_var(self._l2_group_size) + GROUP_SPAN = T.meta_var(self._l2_group_size * self._N_TILE_COLS) + if (FULL_GROUPS > 0) & (tile_linear < FULL_GROUPS * GROUP_SPAN): + group_id: T.let = tile_linear // GROUP_SPAN + within_group: T.let = tile_linear % GROUP_SPAN + tile_row: T.let = group_id * GROUP_SIZE + (within_group % GROUP_SIZE) + tile_col: T.let = within_group // GROUP_SIZE + set_tile_coords(tile_row, tile_col) + elif TAIL_ROWS > 0: + rem: T.let = tile_linear - FULL_GROUPS * GROUP_SPAN + tile_row: T.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS) + tile_col: T.let = rem // TAIL_ROWS + set_tile_coords(tile_row, tile_col) + else: + set_tile_coords(0, 0) + + @T.inline + def _update_serpentine(self, tile_linear, set_tile_coords): + """CUTLASS-style 2D block swizzle with serpentine traversal. + + Algorithm: + 1. Divide grid into swizzle_size x swizzle_size blocks + 2. Within each block, visit tiles in row-major order + 3. Blocks are traversed column by column (along N) + 4. Within each column of blocks, use serpentine: + - Even columns: top to bottom + - Odd columns: bottom to top + + This maximizes L2 reuse for both A and B matrices. + """ + S = T.meta_var(self._l2_group_size) # swizzle_size + M_BLOCKS = T.meta_var(self._M_BLOCKS) + N_BLOCKS = T.meta_var(self._N_BLOCKS) + BLOCK_SIZE = T.meta_var(self._BLOCK_SIZE) # S * S + FULL_BLOCK_TILES = T.meta_var(self._FULL_BLOCK_TILES) + M_TILE_ROWS = T.meta_var(self._M_TILE_ROWS) + T.meta_var(self._N_TILE_COLS) + RESIDUAL_N = T.meta_var(self._RESIDUAL_N) + RESIDUAL_M = T.meta_var(self._RESIDUAL_M) + + # Check if we're in the full block region + if (M_BLOCKS > 0) & (N_BLOCKS > 0) & (tile_linear < FULL_BLOCK_TILES): + # Which block (in linear order along columns of blocks) + block_linear: T.let = tile_linear // BLOCK_SIZE + within_block: T.let = tile_linear % BLOCK_SIZE + + # Block column and row + block_col: T.let = block_linear // M_BLOCKS + block_row_raw: T.let = block_linear % M_BLOCKS + + # Serpentine: odd columns go bottom-to-top + block_row: T.let = T.Select( + block_col % 2 == 0, + block_row_raw, + M_BLOCKS - 1 - block_row_raw + ) + + # Position within block (row-major within block) + local_row: T.let = within_block // S + local_col: T.let = within_block % S + + tile_row: T.let = block_row * S + local_row + tile_col: T.let = block_col * S + local_col + set_tile_coords(tile_row, tile_col) + + elif RESIDUAL_N > 0: + # Residual tiles in the rightmost partial column of blocks + # These are tiles where n >= N_BLOCKS * S + rem: T.let = tile_linear - FULL_BLOCK_TILES + + # First handle the right residual strip (full M height, partial N width) + right_strip_tiles: T.let = M_TILE_ROWS * RESIDUAL_N + if rem < right_strip_tiles: + # Row-major within the right strip + tile_row: T.let = rem // RESIDUAL_N + tile_col: T.let = N_BLOCKS * S + (rem % RESIDUAL_N) + set_tile_coords(tile_row, tile_col) + elif RESIDUAL_M > 0: + # Bottom residual strip (already covered in right strip overlap) + # This handles corner case - shouldn't normally reach here + # as right strip already covers full M height + set_tile_coords(0, 0) + else: + set_tile_coords(0, 0) + + elif RESIDUAL_M > 0: + # Bottom residual strip only (no right residual) + rem: T.let = tile_linear - FULL_BLOCK_TILES + bottom_strip_tiles: T.let = RESIDUAL_M * (N_BLOCKS * S) + if rem < bottom_strip_tiles: + tile_row: T.let = M_BLOCKS * S + (rem % RESIDUAL_M) + tile_col: T.let = rem // RESIDUAL_M + set_tile_coords(tile_row, tile_col) + else: + set_tile_coords(0, 0) + else: + # Fallback + set_tile_coords(0, 0) + + @T.inline + def init(self, cluster_id): + """Initialize scheduler for a given cluster. + + Parameters + ---------- + cluster_id : int + The cluster's index (typically cta_idx // CLUSTER_SIZE) + """ + self.linear_idx = cluster_id + self.tile_count = 0 + self.update_current_m_n_idx(cluster_id) + + @T.inline + def next_tile(self): + """Advance to the next tile for this cluster.""" + self.linear_idx = self.linear_idx + self._num_clusters + self.tile_count = self.tile_count + 1 + self.update_current_m_n_idx(self.linear_idx) + + @T.inline + def next_tile_stride(self, stride: int): + """Advance by a custom stride (for non-standard scheduling).""" + self.linear_idx = self.linear_idx + stride + self.tile_count = self.tile_count + 1 + self.update_current_m_n_idx(self.linear_idx) + # fmt: on + + def valid(self): + """Check if this cluster has more tiles to process.""" + return self.linear_idx < self._TOTAL_TILES + + +class GroupMajor3D(BaseTileScheduler): + """ + 3D grouped-row scheduler (M,N,K) with tail handling on M. + + Args + ---- + prefix: str + m_tiles: int | T PrimExpr # tiles along M (static or runtime) + n_tiles: int # tiles along N (static) + k_tiles: int # tiles along K (static) + group_rows: int # rows per group along M + step: int = 1 # default stride for next_tile() + """ + + def __init__( + self, prefix: str, m_tiles, n_tiles: int, k_tiles: int, group_rows: int, step: int = 1 + ): + super().__init__(prefix) + self._step = step + self.tile_idx = T.local_scalar("int32") + self.k_idx = T.local_scalar("int32") + + # ---- constants / primexprs baked once ---- + self._G = group_rows + self._N = n_tiles + self._K = k_tiles + + if isinstance(m_tiles, int): + self._GROUPS = m_tiles // group_rows + self._FINAL_ROWS = m_tiles - self._GROUPS * group_rows + self._SAFE_FINAL_ROWS = max(self._FINAL_ROWS, 1) + self._GROUP_SIZE = group_rows * n_tiles * k_tiles + self._TOTAL = m_tiles * n_tiles * k_tiles + else: + self._GROUPS = T.truncdiv(m_tiles, group_rows) + self._FINAL_ROWS = m_tiles - self._GROUPS * group_rows + self._SAFE_FINAL_ROWS = T.max(self._FINAL_ROWS, 1) + self._GROUP_SIZE = self._G * self._N * self._K + self._TOTAL = m_tiles * n_tiles * k_tiles + + # handy composites used in macro + self._FULL_BOUND = self._GROUPS * self._GROUP_SIZE + self._HAS_FULL = self._GROUPS > 0 + self._HAS_TAIL = self._FINAL_ROWS > 0 + + # fmt: off + @T.inline + def update_current_m_n_idx(self, linear_idx): + # full-group formulas + full_m: T.let = T.floordiv(linear_idx, self._GROUP_SIZE) * self._G + T.floormod( + linear_idx, self._G + ) + full_n: T.let = T.floormod(T.floordiv(linear_idx, self._G), self._N) + full_k: T.let = T.floordiv(T.floormod(linear_idx, self._GROUP_SIZE), self._G * self._N) + + # tail formulas (relative to FULL_BOUND) + # Use _SAFE_FINAL_ROWS (max(FINAL_ROWS, 1)) to avoid divide-by-zero when there is no tail + rem: T.let = linear_idx - self._FULL_BOUND + tail_m: T.let = self._GROUPS * self._G + T.floormod(rem, self._SAFE_FINAL_ROWS) + tail_n: T.let = T.floordiv(rem, self._SAFE_FINAL_ROWS) % self._N + tail_k: T.let = T.floordiv(rem, self._SAFE_FINAL_ROWS * self._N) + + # choose phase + if self._HAS_FULL & (linear_idx < self._FULL_BOUND): + self.m_idx = full_m + self.n_idx = full_n + self.k_idx = full_k + elif self._HAS_TAIL: + self.m_idx = tail_m + self.n_idx = tail_n + self.k_idx = tail_k + else: + self.m_idx = 0 + self.n_idx = 0 + self.k_idx = 0 + + @T.inline + def init(self, linear_init): + self.linear_idx = linear_init + self.tile_idx = 0 + self.update_current_m_n_idx(linear_init) + + @T.inline + def next_tile(self): + self.linear_idx = self.linear_idx + self._step + self.tile_idx = self.tile_idx + 1 + self.update_current_m_n_idx(self.linear_idx) + + @T.inline + def next_tile_stride(self, stride: int): + self.linear_idx = self.linear_idx + stride + self.tile_idx = self.tile_idx + 1 + self.update_current_m_n_idx(self.linear_idx) + # fmt: on + + def valid(self): + return self.linear_idx < self._TOTAL + + +class RankAwareGroupMajorTileScheduler(BaseTileScheduler): + """ + Group-major scheduler that applies a rank-aware remapping (remote rows first). + Kept as a thin adapter because it depends on NVSHMEM rank at device-side. + """ + + def __init__( + self, prefix: str, m_clusters: int, n_clusters: int, group_size: int, world_size: int + ): + super().__init__(prefix) + self._m_clusters = m_clusters + self._n_clusters = n_clusters + self._group_size = group_size + self._world_size = world_size + + @T.inline + def update_current_m_n_idx(self, linear_idx): + my_rank: T.let = T.nvshmem.my_pe() + remote_m_clusters: T.let = self._m_clusters - self._m_clusters // self._world_size + group_rows: T.let = (remote_m_clusters // self._group_size) * self._group_size + final_rows: T.let = remote_m_clusters - group_rows + group_repeat: T.let = self._group_size * self._n_clusters + if linear_idx < group_rows * self._n_clusters and group_rows > 0: + self.m_idx = ( + (linear_idx // group_repeat) * self._group_size + + (linear_idx % self._group_size) + + (my_rank + 1) * self._m_clusters // self._world_size + ) % self._m_clusters + self.n_idx = (linear_idx % group_repeat) // self._group_size + elif linear_idx < remote_m_clusters * self._n_clusters: + remainder_idx: T.let = linear_idx - group_rows * self._n_clusters + self.m_idx = ( + group_rows + + remainder_idx % final_rows + + (my_rank + 1) * self._m_clusters // self._world_size + ) % self._m_clusters + self.n_idx = remainder_idx // final_rows + else: + remainder_idx: T.let = linear_idx - remote_m_clusters * self._n_clusters + self.m_idx = ( + remote_m_clusters + + remainder_idx % (self._m_clusters // self._world_size) + + (my_rank + 1) * self._m_clusters // self._world_size + ) % self._m_clusters + self.n_idx = remainder_idx // (self._m_clusters // self._world_size) + + @T.inline + def next_tile(self, stride: int): + self.linear_idx = self.linear_idx + stride + self.update_current_m_n_idx(self.linear_idx) + + def valid(self): + return self.linear_idx < self._m_clusters * self._n_clusters + + +class IndexedTripleTileScheduler(BaseTileScheduler): + """Scheduler that maps linear_idx to (b_idx, h_idx, q_idx) via index lists.""" + + def __init__(self, prefix: str, b_indices, h_indices, q_indices, tiles_indptr): + super().__init__(prefix) + self.b_indices = b_indices + self.h_indices = h_indices + self.q_indices = q_indices + self.tiles_indptr = tiles_indptr + self.q_idx = T.local_scalar("int32") + self.h_idx = T.local_scalar("int32") + self.b_idx = T.local_scalar("int32") + self.linear_lim = T.local_scalar("int32") + + @T.inline + def _load(self): + self.q_idx = self.q_indices[self.linear_idx] + self.h_idx = self.h_indices[self.linear_idx] + self.b_idx = self.b_indices[self.linear_idx] + + @T.inline + def init(self, sm): + self.linear_idx = self.tiles_indptr[sm] + self.linear_lim = self.tiles_indptr[sm + 1] + self._load() + + @T.inline + def next_tile(self): + self.linear_idx = self.linear_idx + 1 + self._load() + + def valid(self): + return self.linear_idx < self.linear_lim + + +class FlashAttentionLinearScheduler(BaseTileScheduler): + """Linear 3D scheduler for flash attention (batch, head, m_block). + + Used for non-causal attention with simple linear decomposition. + Maps linear_idx -> (batch_idx, head_idx, m_block_idx) using: + batch = linear_idx // (num_heads * num_m_blocks) + head = (linear_idx % (num_heads * num_m_blocks)) // num_m_blocks + m_block = linear_idx % num_m_blocks + + Parameters + ---------- + prefix : str + Prefix for TIR variable names + num_batches : int + Number of batches + num_heads : int + Number of KV heads + num_m_blocks : int + Number of Q blocks (M dimension tiles) + num_ctas : int + Number of CTAs for persistent kernel stride + """ + + def __init__( + self, prefix: str, num_batches: int, num_heads: int, num_m_blocks: int, num_ctas: int + ): + super().__init__(prefix) + self._num_batches = num_batches + self._num_heads = num_heads + self._num_m_blocks = num_m_blocks + self._num_ctas = num_ctas + self._total_tasks = num_batches * num_heads * num_m_blocks + + # Output indices + self.batch_idx = T.local_scalar("int32") + self.head_idx = T.local_scalar("int32") + self.m_block_idx = T.local_scalar("int32") + + # fmt: off + @T.inline + def update_current_m_n_idx(self, linear_idx): + """Convert linear index to (batch, head, m_block) coordinates.""" + NUM_HEADS = T.meta_var(self._num_heads) + NUM_M_BLOCKS = T.meta_var(self._num_m_blocks) + HEAD_M_PRODUCT = T.meta_var(NUM_HEADS * NUM_M_BLOCKS) + + self.batch_idx = linear_idx // HEAD_M_PRODUCT + self.head_idx = (linear_idx % HEAD_M_PRODUCT) // NUM_M_BLOCKS + self.m_block_idx = linear_idx % NUM_M_BLOCKS + + @T.inline + def init(self, cta_id): + """Initialize scheduler with CTA ID.""" + self.linear_idx = cta_id + self.update_current_m_n_idx(cta_id) + + @T.inline + def next_tile(self): + """Advance to next tile by striding by num_ctas.""" + self.linear_idx = self.linear_idx + self._num_ctas + self.update_current_m_n_idx(self.linear_idx) + # fmt: on + + def valid(self): + """Check if there are more tiles to process.""" + return self.linear_idx < self._total_tasks + + +class FlashAttentionLPTScheduler(BaseTileScheduler): + """LPT scheduler with L2 swizzle for causal flash attention. + + Processes high-work Q blocks (with more KV blocks to attend to) first using + Longest Processing Time (LPT) scheduling. Also applies L2 cache swizzle + for better cache locality across batch*head dimensions. + + The LPT aspect comes from reversing m_block order: lower Q blocks have more + KV blocks to process due to causal masking, so processing them first balances load. + + The scheduler is only applied to non-persistent kernels. + + L2 Swizzle: Groups consecutive batch*head indices together for L2 locality. + + Parameters + ---------- + prefix : str + Prefix for TIR variable names + num_batches : int + Number of batches + num_heads : int + Number of KV heads + num_m_blocks : int + Number of Q blocks (M dimension tiles) + num_ctas : int + Number of CTAs (should equal total_tasks for causal) + l2_swizzle : int + L2 swizzle factor for cache locality + """ + + def __init__( + self, prefix: str, num_batches: int, num_heads: int, num_m_blocks: int, l2_swizzle: int + ): + super().__init__(prefix) + self._num_batches = num_batches + self._num_heads = num_heads + self._num_m_blocks = num_m_blocks + self._l2_swizzle = l2_swizzle + self._total_tasks = num_batches * num_heads * num_m_blocks + + # Derived constants for L2 swizzle + self._num_hb = num_batches * num_heads + self._l2_major = l2_swizzle * num_m_blocks + self._num_hb_quotient = self._num_hb // l2_swizzle + + # Output indices + self.batch_idx = T.local_scalar("int32") + self.head_idx = T.local_scalar("int32") + self.m_block_idx = T.local_scalar("int32") + + # fmt: off + @T.inline + def update_current_m_n_idx(self, linear_idx): + """Convert linear index to (batch, head, m_block) with LPT + L2 swizzle.""" + L2_SWIZZLE = T.meta_var(self._l2_swizzle) + L2_MAJOR = T.meta_var(self._l2_major) + NUM_HB_QUOTIENT = T.meta_var(self._num_hb_quotient) + NUM_HB = T.meta_var(self._num_hb) + NUM_HEADS = T.meta_var(self._num_heads) + NUM_M_BLOCKS = T.meta_var(self._num_m_blocks) + + # L2 swizzle decomposition + bidhb: T.let = linear_idx // L2_MAJOR + l2_mod: T.let = linear_idx % L2_MAJOR + + # Handle residual section (last partial swizzle group) + num_hb_remainder: T.let = T.max(NUM_HB % L2_SWIZZLE, 1) + m_block_raw: T.let = T.Select(bidhb < NUM_HB_QUOTIENT, l2_mod // L2_SWIZZLE, l2_mod // num_hb_remainder) # noqa: E501 + bidhb_residual: T.let = T.Select(bidhb < NUM_HB_QUOTIENT, l2_mod % L2_SWIZZLE, l2_mod % num_hb_remainder) # noqa: E501 + bidhb_actual: T.let = bidhb * L2_SWIZZLE + bidhb_residual + + self.batch_idx = bidhb_actual // NUM_HEADS + self.head_idx = bidhb_actual % NUM_HEADS + + # LPT: Reverse block order so high-work blocks are processed first + self.m_block_idx = (NUM_M_BLOCKS - 1) - m_block_raw + + @T.inline + def init(self, cta_id): + """Initialize scheduler with CTA ID.""" + self.linear_idx = cta_id + self.update_current_m_n_idx(cta_id) + + @T.inline + def next_tile(self): + """Advance to next tile by striding by num_ctas.""" + self.linear_idx = self._total_tasks + # fmt: on + + def valid(self): + """Check if there are more tiles to process.""" + return self.linear_idx < self._total_tasks diff --git a/python/tvm/backend/cuda/lang/warp_role.py b/python/tvm/backend/cuda/lang/warp_role.py new file mode 100644 index 000000000000..0258013bab1a --- /dev/null +++ b/python/tvm/backend/cuda/lang/warp_role.py @@ -0,0 +1,144 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Warp role helpers for SM100 kernels. + +Simplifies the common pattern of dispatching warps to named roles +with register budgets. + +Example:: + + # Declare roles + tma_warp = WarpRole(warp_id, 1, regs=48) + store_warp = WarpRole(warp_id, 2, regs=48) + mma_warp = WarpRole(warp_id, 0, regs=232, increase=True) + + # Use with context manager + with tma_warp: + # TMA load code + with store_warp: + # TMA store code + with mma_warp: + # MMA compute code +""" + +from tvm.script import tirx as T + + +class WarpRole: + """A warp-level role that guards a block of code by warp_id comparison + with optional register budget. + + Generates:: + + if == : + T.ptx.setmaxnreg(, ) # if regs specified + + + The ``if`` guard narrows the active set to the single warp; individual + tile-primitive calls inside ```` carry their own exec scope via + a scope-namespace prefix (e.g. ``Tx.warp.copy(...)``). + + Parameters + ---------- + warp_id_var : Var + The warp_id variable (from ``T.warp_id(...)``). + warp_id_val : int + Which warp index this role corresponds to. + regs : int, optional + Register budget (passed to ``T.ptx.setmaxnreg``). + If None, no setmaxnreg is emitted. + increase : bool + Direction for ``setmaxnreg`` (default False = decrease). + """ + + def __init__(self, warp_id_var, warp_id_val, regs=None, increase=False): + self.warp_id_var = warp_id_var + self.warp_id_val = warp_id_val + self.regs = regs + self.increase = increase + + def __enter__(self): + self._if_frame = T.If(self.warp_id_var == self.warp_id_val) + self._if_frame.__enter__() + self._then_frame = T.Then() + self._then_frame.__enter__() + if self.regs is not None: + T.evaluate(T.ptx.setmaxnreg(self.increase, self.regs)) + return self + + def __exit__(self, *exc): + self._then_frame.__exit__(*exc) + self._if_frame.__exit__(*exc) + return False + + +class WarpgroupRole: + """A warpgroup-level role that guards by wg_id comparison, + with optional register budget. + + Generates (single wg_id):: + + if == : + T.ptx.setmaxnreg(, ) # if regs specified + + + Generates (range of wg_ids, e.g. ``wg_id_val=(0, 2)``):: + + if 0 <= and < 2: + T.ptx.setmaxnreg(, ) + + + The ``if`` guard narrows the active set to the target warpgroup(s); + individual tile-primitive calls inside ```` carry their own exec + scope via a scope-namespace prefix (e.g. ``Tx.wg.copy(...)``). + + Parameters + ---------- + wg_id_var : Var + The warpgroup_id variable (from ``T.warpgroup_id(...)``). + wg_id_val : int or tuple[int, int] + Which warpgroup index (int) or range ``(start, stop)`` this role + corresponds to. + regs : int, optional + Register budget. + increase : bool + Direction for ``setmaxnreg`` (default False = decrease). + """ + + def __init__(self, wg_id_var, wg_id_val, regs=None, increase=False): + self.wg_id_var = wg_id_var + self.wg_id_val = wg_id_val + self.regs = regs + self.increase = increase + + def __enter__(self): + if isinstance(self.wg_id_val, tuple): + start, stop = self.wg_id_val + self._if_frame = T.If(start <= self.wg_id_var and self.wg_id_var < stop) + else: + self._if_frame = T.If(self.wg_id_var == self.wg_id_val) + self._if_frame.__enter__() + self._then_frame = T.Then() + self._then_frame.__enter__() + if self.regs is not None: + T.evaluate(T.ptx.setmaxnreg(self.increase, self.regs)) + return self + + def __exit__(self, *exc): + self._then_frame.__exit__(*exc) + self._if_frame.__exit__(*exc) + return False diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py new file mode 100644 index 000000000000..e76d5fbe2452 --- /dev/null +++ b/python/tvm/backend/cuda/op.py @@ -0,0 +1,4283 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, too-many-arguments +"""CUDA, PTX, and NVSHMEM TIR intrinsic builders.""" + +from __future__ import annotations + +from tvm import tirx +from tvm.ir import Op, PrimExpr +from tvm.runtime import const +from tvm.tirx.expr import Call +from tvm.tirx.op import bitwise_and, call_intrin, tvm_access_ptr +from tvm.tirx.operator.intrinsics._common import CLUSTER_BARRIER_SEM as _CLUSTER_BARRIER_SEM +from tvm.tirx.operator.intrinsics._common import ( + CP_ASYNC_BULK_CACHE_HINT as _CP_ASYNC_BULK_CACHE_HINT, +) +from tvm.tirx.operator.intrinsics._common import CP_ASYNC_BULK_RED_OP as _CP_ASYNC_BULK_RED_OP +from tvm.tirx.operator.intrinsics._common import CP_ASYNC_CACHE_HINT as _CP_ASYNC_CACHE_HINT +from tvm.tirx.operator.intrinsics._common import CP_ASYNC_FILL_MODE as _CP_ASYNC_FILL_MODE +from tvm.tirx.operator.intrinsics._common import CP_ASYNC_PREFETCH_SIZE as _CP_ASYNC_PREFETCH_SIZE +from tvm.tirx.operator.intrinsics._common import F32X2_ROUND as _F32X2_ROUND +from tvm.tirx.operator.intrinsics._common import FENCE_PROXY_ASYNC_SPACE as _FENCE_PROXY_ASYNC_SPACE +from tvm.tirx.operator.intrinsics._common import FENCE_SCOPE as _FENCE_SCOPE +from tvm.tirx.operator.intrinsics._common import FENCE_SEM as _FENCE_SEM +from tvm.tirx.operator.intrinsics._common import LDMATRIX_DTYPE as _LDMATRIX_DTYPE +from tvm.tirx.operator.intrinsics._common import LDMATRIX_NUM as _LDMATRIX_NUM +from tvm.tirx.operator.intrinsics._common import NVSHMEM_CMP as _NVSHMEM_CMP +from tvm.tirx.operator.intrinsics._common import NVSHMEM_SIG_OP as _NVSHMEM_SIG_OP +from tvm.tirx.operator.intrinsics._common import TCGEN05_CP_DECOMPRESS as _TCGEN05_CP_DECOMPRESS +from tvm.tirx.operator.intrinsics._common import TCGEN05_CP_MULTICAST as _TCGEN05_CP_MULTICAST +from tvm.tirx.operator.intrinsics._common import TCGEN05_CP_SHAPES as _TCGEN05_CP_SHAPES +from tvm.tirx.operator.intrinsics._common import TCGEN05_CTA_GROUP as _TCGEN05_CTA_GROUP +from tvm.tirx.operator.intrinsics._common import TCGEN05_LDST_SHAPES as _TCGEN05_LDST_SHAPES + +tir = tirx + +######################################################## +# CUDA native builtins +######################################################## + + +def cuda_func_call(func_name, *args, source_code, return_type="void"): + """TVM intrinsic to call a CUDA function. Source code is provided as a string. + + Parameters + ---------- + func_name: str + The name of the CUDA function. + + args: PrimExpr + The arguments to the CUDA function. + + source_code: str + The source code of the CUDA function. + + return_type: str + The return type of the CUDA function. + """ + return call_intrin(return_type, "tirx.cuda_func_call", func_name, *args, source_code) + + +def cuda_warp_reduce(value, op, width=32): + """Warp-level butterfly shuffle-XOR reduction. + + Reduces ``value`` across ``width`` adjacent lanes using the specified + operation. Codegen emits ``log2(width)`` steps of + ``__shfl_xor_sync(0xFFFFFFFF, val, mask)`` with descending XOR masks. + + Parameters + ---------- + value : PrimExpr + The per-thread scalar value to reduce. + + op : str + Reduction operation: ``"sum"``, ``"max"``, or ``"min"``. + + width : int + Number of lanes participating in each reduction group. + Must be a power of two in [2, 32]. Defaults to 32 (full warp). + + Returns + ------- + call : PrimExpr + The reduced value (same dtype as *value*). + """ + return call_intrin(value.dtype, "tirx.cuda_warp_reduce", value, op, width) + + +def cuda_warp_sum(value, width=32): + """Convenience wrapper: ``cuda_warp_reduce(value, "sum", width)``.""" + return cuda_warp_reduce(value, "sum", width) + + +def cuda_warp_max(value, width=32): + """Convenience wrapper: ``cuda_warp_reduce(value, "max", width)``.""" + return cuda_warp_reduce(value, "max", width) + + +def cuda_warp_min(value, width=32): + """Convenience wrapper: ``cuda_warp_reduce(value, "min", width)``.""" + return cuda_warp_reduce(value, "min", width) + + +def cuda_cta_reduce(value, op, num_warps, scratch): + """CTA-wide reduction via warp shuffle + shared memory. + + Two-step reduction: (1) intra-warp shuffle reduction, (2) warp-0 + collects per-warp partials from ``scratch``, reduces, broadcasts via + ``__syncthreads()``. All CTA threads must participate. + + Parameters + ---------- + value : PrimExpr + Per-thread scalar value to reduce. + + op : str + Reduction operation: ``"sum"``, ``"max"``, or ``"min"``. + + num_warps : int + Number of warps in the CTA. Must be a power of two in [1, 32]. + + scratch : Var + Data pointer to shared-memory scratch space (>= num_warps elements). + + Returns + ------- + call : PrimExpr + The reduced value broadcast to all threads (same dtype as *value*). + """ + return call_intrin(value.dtype, "tirx.cuda_cta_reduce", value, op, num_warps, scratch) + + +def cuda_cta_sum(value, num_warps, scratch): + """Convenience wrapper: ``cuda_cta_reduce(value, "sum", num_warps, scratch)``.""" + return cuda_cta_reduce(value, "sum", num_warps, scratch) + + +def cuda_cta_max(value, num_warps, scratch): + """Convenience wrapper: ``cuda_cta_reduce(value, "max", num_warps, scratch)``.""" + return cuda_cta_reduce(value, "max", num_warps, scratch) + + +def cuda_cta_min(value, num_warps, scratch): + """Convenience wrapper: ``cuda_cta_reduce(value, "min", num_warps, scratch)``.""" + return cuda_cta_reduce(value, "min", num_warps, scratch) + + +def cuda_copy_bytes(dst, src, num_bytes): + """Typed load/store copy of ``num_bytes`` bytes. + + Copies ``num_bytes`` bytes from ``src`` to ``dst`` using a single + typed load/store instruction. Codegen selects the appropriate C++ + vector type (``uint4``, ``uint2``, ``unsigned int``, etc.). + + Parameters + ---------- + dst : Var + Destination pointer. + + src : Var + Source pointer. + + num_bytes : int + Number of bytes to copy. Must be one of {1, 2, 4, 8, 16}. + + Returns + ------- + call : PrimExpr + A void call expression. + """ + return call_intrin("void", "tirx.cuda_copy_bytes", dst, src, num_bytes) + + +def cuda_copy_128b(dst, src): + """Convenience wrapper: ``cuda_copy_bytes(dst, src, 16)`` — copies 128 bits.""" + return cuda_copy_bytes(dst, src, 16) + + +def cuda_copy_64b(dst, src): + """Convenience wrapper: ``cuda_copy_bytes(dst, src, 8)`` — copies 64 bits.""" + return cuda_copy_bytes(dst, src, 8) + + +def cuda_copy_32b(dst, src): + """Convenience wrapper: ``cuda_copy_bytes(dst, src, 4)`` — copies 32 bits.""" + return cuda_copy_bytes(dst, src, 4) + + +def cuda_copy_16b(dst, src): + """Convenience wrapper: ``cuda_copy_bytes(dst, src, 2)`` — copies 16 bits.""" + return cuda_copy_bytes(dst, src, 2) + + +def cuda_copy_8b(dst, src): + """Convenience wrapper: ``cuda_copy_bytes(dst, src, 1)`` — copies 8 bits.""" + return cuda_copy_bytes(dst, src, 1) + + +def cuda_warp_sync(): + """TVM intrinsic to synchronize threads within the current warp. + + This lowers to a CUDA `__syncwarp()` call. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_warp_sync") + + +def cuda_cta_sync(): + """TVM intrinsic to call CUDA syncthreads (block-wide barrier) + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_cta_sync") + + +def cuda_grid_sync(): + """TVM intrinsic to call CUDA grid-wide sync (cooperative groups) + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_grid_sync") + + +def cuda_cluster_sync(): + """TVM intrinsic to call CUDA cluster-wide barrier sync + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_cluster_sync") + + +def cuda_thread_rank(): + """TVM intrinsic that returns ``cooperative_groups::thread_rank()`` + for the enclosing CTA -- the linear thread index within the block. + + Useful for building "single thread of CTA" predicates without + referencing user-declared scope_id vars. For example, the idiomatic + mbarrier.init leader predicate is:: + + T.cuda.thread_rank() == 0 + + Returns + ------- + call : PrimExpr + The call expression (``int32``). + """ + return call_intrin("int32", "tirx.cuda_thread_rank") + + +def cuda_half2float(src): + """TVM intrinsic to convert half to float + + Parameters + ---------- + src : PrimExpr + Source pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("float32", "tirx.cuda_half2float", src) + + +def cuda_bfloat162float(src): + """TVM intrinsic to convert bfloat16 to float + + Parameters + ---------- + src : PrimExpr + Source pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("float32", "tirx.cuda_bfloat162float", src) + + +def cuda_float22half2(dst, src): + """TVM intrinsic to convert float2 to half2 with rounding + + Parameters + ---------- + dst : PrimExpr + Destination pointer. + + src : PrimExpr + Source pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_float22half2", dst, src) + + +def cuda_trap_when_assert_failed(cond): + """TVM intrinsic to trap when assertion failed (cond == false) + + Parameters + ---------- + cond : PrimExpr + Condition to check. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_trap_when_assert_failed", cond) + + +def cuda_runtime_instr_desc(desc, sf_id): + """TVM intrinsic to update runtime instruction descriptor + + Parameters + ---------- + desc : PrimExpr + Pointer to the descriptor (uint32*). + + sf_id : PrimExpr + The subfragment id. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_runtime_instr_desc", desc, sf_id) + + +def cuda_half8tofloat8(src_addr, dst_addr): + """TVM intrinsic to convert 8 half2s to 8 float2s + + Parameters + ---------- + src_addr : PrimExpr + Source pointer. + + dst_addr : PrimExpr + Destination pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_half8tofloat8", src_addr, dst_addr) + + +def cuda_float8tohalf8(src_addr, dst_addr): + """TVM intrinsic to convert 8 float2s to 8 half2s + + Parameters + ---------- + src_addr : PrimExpr + Source pointer. + + dst_addr : PrimExpr + Destination pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_float8tohalf8", src_addr, dst_addr) + + +def ptx_mma_sp( + dtype, + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + metadata, + meta_index, + sparse_selector, + saturate, +): + """TVM intrinsic for sparse tensor core ptx instructions + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma + + Parameters + ---------- + dtype : str + The data type of the result. + + shape : str + The shape of mma fragment. + + A_layout : Literal["row", "col"] + The layout of multiplicand fragment A. + + B_layout : Literal["row", "col"] + The layout of multiplicand fragment B. + + A_dtype : str + The data type of multiplicand fragment A. + + B_dtype : str + The data type of multiplicand fragment B. + + C_dtype : str + The data type of multiplicand fragment C. + + multiplicand_a : Var + The multiplicand fragment A variable. + + a_index : Expr + The index of multiplicand fragment A. + + multiplicand_b : Var + The multiplicand fragment B variable. + + b_index : Expr + The index of multiplicand fragment B. + + accumulator : Var + The accumulator fragment C variable. + + c_index : Expr + The index of accumulator fragment C. + + metadata : Expr + The metadata of operand. + + meta_index : Expr + The metadata index of operand. + + sparse_selector : Expr + The sparse selector indicating the thread that stores the metadata. + + saturate : bool + The optional saturation at the output. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + "tirx.ptx_mma_sp", + shape, + A_layout, + B_layout, + A_dtype, + B_dtype, + C_dtype, + multiplicand_a, + a_index, + multiplicand_b, + b_index, + accumulator, + c_index, + metadata, + meta_index, + sparse_selector, + saturate, + ) + + +def ptx_cp_async_bulk( + dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id +): + """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk + + Parameters + ---------- + dtype : str + The data type of the result. + + shared_ptr : Var + The shared memory pointer variable. + + shared_offset : Expr + The offset of shared memory pointer. + + global_ptr : Var + The global memory pointer variable. + + global_offset : Expr + The offset of global memory pointer. + + bytes : int + The data size to copy. + + barrier_id : int + The ID of the barrier shared memory pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin( + dtype, + "tirx.ptx_cp_async_bulk", + shared_ptr, + shared_offset, + global_ptr, + global_offset, + bytes, + barrier_id, + ) + + +def ptx_cp_async_bulk_shared_to_cluster(dst_ptr, src_ptr, size, mbar): + """PTX cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes + + Asynchronous bulk copy from executing CTA's shared memory to a remote + CTA's shared memory within the same cluster. + + Parameters + ---------- + dst_ptr : PrimExpr + Destination pointer in shared::cluster address space (remote CTA). + + src_ptr : PrimExpr + Source pointer in shared::cta address space (local CTA). + + size : PrimExpr + Number of bytes to copy (must be multiple of 16). + + mbar : PrimExpr + Mbarrier address in shared::cluster space for completion signaling, + usually produced by ``T.ptx.map_shared_rank``. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_cp_async_bulk_shared_to_cluster", dst_ptr, src_ptr, size, mbar) + + +def ptx_cp_async_mbarrier_arrive(barrier_id): + """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive + + Parameters + ---------- + barrier_id : int + The ID of the barrier shared memory pointer. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_cp_async_mbarrier_arrive", barrier_id) + + +def ptx_fence(sem: str, scope: str): + """TVM intrinsic for PTX fence instruction. + + Generates: fence.{sem}.{scope}; + + Parameters + ---------- + sem : str + The semantics of the fence. One of "sc", "acq_rel". + scope : str + The scope of the fence. One of "cta", "cluster", "gpu", "sys". + + Returns + ------- + call : PrimExpr + The call expression. + """ + _choice("sem", sem, _FENCE_SEM) + _choice("scope", scope, _FENCE_SCOPE) + return call_intrin("", "tirx.ptx_fence", sem, scope) + + +def ptx_fence_proxy_async(space: str = ""): + """TVM intrinsic for PTX fence.proxy.async instruction. + + Generates: fence.proxy.async[.{space}]; + + Parameters + ---------- + space : str + The address space qualifier. One of "", "global", "shared::cta", "shared::cluster". + Empty string means no qualifier. + + Returns + ------- + call : PrimExpr + The call expression. + """ + _choice("space", space, _FENCE_PROXY_ASYNC_SPACE) + return call_intrin("", "tirx.ptx_fence_proxy_async", space) + + +def ptx_mbarrier_init(bar, thread_count): + """TVM intrinsic to call mbarrier.init.shared::cta.b64 + + Parameters + ---------- + bar : Var + The pointer to barrier variable. + + thread_count : int + The number of threads expected to arrive at the barrier. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_mbarrier_init", bar, thread_count) + + +def ptx_mbarrier_arrive(bar, cta_id=None, pred=None): + """TVM intrinsic to call + mbarrier.arrive.shared::cta.b64 + or + @p mapa.shared::cluster.u32 + @p mbarrier.arrive.shared::cluster.b64 + + Parameters + ---------- + bar : Var + The pointer to barrier variable. + + cta_id : Optional[PrimExpr] + The cta id. + + pred : Optional[PrimExpr] + The predicate to guard the operation. + """ + if cta_id is None and pred is None: + return call_intrin("", "tirx.ptx_mbarrier_arrive", bar) + assert cta_id is not None and pred is not None + return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred) + + +def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): + """TVM intrinsic to call + mbarrier.arrive_expect_tx.shared::cta.b64 + or + @p mapa.shared::cluster.u32 + @p mbarrier.arrive_expect_tx.shared::cluster.b64 + + Parameters + ---------- + bar : Var + The pointer to barrier variable. + + byte_count : int + Increases the tx count of the mbarrier object to track completion of + addtional async transactions. + + cta_id : Optional[PrimExpr] + The cta id. + + pred : Optional[PrimExpr] + The predicate to guard the operation. + + Returns + ------- + call : PrimExpr + The call expression. + """ + if cta_id is None and pred is None: + return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count) + assert cta_id is not None and pred is not None + return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count, cta_id, pred) + + +def ptx_mbarrier_try_wait(bar, phase): + """TVM intrinsic to call mbarrier.try_wait.parity repeatedly until it returns true + + Parameters + ---------- + bar : Var + The pointer to barrier variable. + + phase : int + The phase of the barrier. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_mbarrier_try_wait", bar, phase) + + +def ptx_mbarrier_try_wait_once(bar, phase, ticks): + """TVM intrinsic for one-shot non-blocking ``mbarrier.try_wait.parity``. + + Returns ``1`` if the requested parity has been reached and ``0`` otherwise. + This is intended for bounded debug waits; production waits should use + :func:`ptx_mbarrier_try_wait`. + """ + return call_intrin("uint32", "tirx.ptx_mbarrier_try_wait_once", bar, phase, ticks) + + +def ptx_bar_arrive(name_bar_id, thread_count): + """TVM intrinsic to call bar.arrive a, b + + Parameters + ---------- + name_bar_id : int + The ID of the named barrier. + + thread_count : int + The number of threads expected to arrive at the barrier. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_bar_arrive", name_bar_id, thread_count) + + +def ptx_bar_sync(name_bar_id, thread_count): + """TVM intrinsic to call bar.sync a, {b} + + Parameters + ---------- + name_bar_id : int + The ID of the named barrier. + + thread_count : int + The number of threads expected to arrive at the barrier. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_bar_sync", name_bar_id, thread_count) + + +def ptx_cp_async( + dst_ptr, + src_ptr, + cp_size, + *, + cache_hint="", + cache_policy=None, + prefetch_size=-1, + predicate=-1, + fill_mode="", +): + """TVM intrinsic for ptx async copy from global to shared memory using cp.async + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async + + Dispatches to one of three PTX-form-aligned ops: + + * ``ptx_cp_async_src_size`` for ``fill_mode == "zero"`` (zero-fill via + ``src_size = pred ? cp_size : 0``). + * ``ptx_cp_async_ignore_src`` for a non-empty ``predicate`` with no + fill_mode (``setp+@p`` guards the asm). + * ``ptx_cp_async_plain`` for the no-predicate / no-fill_mode case. + + Parameters + ---------- + shared_ptr : PrimExpr + The pointer to the shared memory. + + global_ptr : PrimExpr + The pointer to the global memory. + + cp_size : int + The data size to copy. + + cache_hint : str["evict_last", "evict_first", "evict_normal", ""] + The cache hint. + + prefetch_size : int[-1, 64, 128, 256] + The prefetch size. + + predicate : PrimExpr + The predicate to guard the operation. + + fill_mode : str["zero", ""] + The fill mode. + + Returns + ------- + call : PrimExpr + The call expression. + """ + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + _choice("prefetch_size", prefetch_size, _CP_ASYNC_PREFETCH_SIZE) + _choice("fill_mode", fill_mode, _CP_ASYNC_FILL_MODE) + return call_intrin( + "", + "tirx.ptx_cp_async", + dst_ptr, + src_ptr, + cp_size, + cache_policy, + int(has_cache_policy), + prefetch_size, + predicate, + fill_mode, + ) + + +def ptx_cp_async_legacy(*all_args): + """Legacy ``ptx_cp_async`` API taking explicit src/dst offsets. + + Signature: ``(dst_ptr, dst_offset, src_ptr, src_offset, cp_size)``. + Offsets are folded into the pointers via ``tvm_access_ptr`` then + dispatched to fork-native :func:`ptx_cp_async`. + + ``T.ptx.cp_async_legacy`` runs through ``_dtype_forward`` which + prepends a ``dtype=`` kwarg as a leading positional. The dtype names + the *element* type of the buffer (offsets are in elements of that + dtype, not bytes), so this function accepts either 5 or 6 positional + args. + """ + args = list(all_args) + elem_dtype = "int8" + if len(args) == 6: + # Leading positional is the buffer element dtype, used to scale + # offsets correctly when folding via ``tvm_access_ptr``. + elem_dtype = args.pop(0) + if len(args) != 5: + raise ValueError( + f"ptx_cp_async_legacy expects 5 args (or 6 with dtype= kwarg " + f"prepended); got {len(all_args)}" + ) + dst_ptr, dst_offset, src_ptr, src_offset, cp_size = args + dst_ptr = tvm_access_ptr(elem_dtype, dst_ptr, dst_offset, 1, 1) + src_ptr = tvm_access_ptr(elem_dtype, src_ptr, src_offset, 1, 1) + return ptx_cp_async(dst_ptr, src_ptr, cp_size) + + +def ptx_cp_async_commit_group(): + """TVM intrinsic for ptx async copy commit + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_cp_async_commit_group") + + +def ptx_cp_async_wait_group(num=0): + """TVM intrinsic for ptx async copy wait + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group + + Parameters + ---------- + num : int, optional + The number of the most recent uncommitted pending cp.async groups to wait. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_cp_async_wait_group", num) + + +def ptx_cp_async_bulk_tensor_global_to_cluster( + dim, dst_ptr, bar, tensormap_addr, cta_mask, cta_group, cache_hint, *coords, cache_policy=None +): + """TVM intrinsic to call cp.async.bulk.tensor.dim.shared::cluster.global.tile.mbarrier::complete_tx::bytes + + Parameters + ---------- + dim : int + The dimension of the source tensor. + + dst_ptr : PrimExpr + The destination pointer to the shared memory. + + bar : PrimExpr + The pointer to mbarrier variable. + + tensormap_addr : PrimExpr + The generic address of the tensor map object. + + cta_mask : int + The mask of the cta for multicast. + + cta_group : int + Must be either 1 or 2. + If set to 1, mbarrier must be in the shared memory of the same CTA as the shared memory destination + If set to 2, mbarrier can be in shared memory of either the same CTA as the shared memory destination + or the shared memory of the peer CTA. + + cache_hint : str + The cache hint. + + coords : List[PrimExpr] + specifies the starting coordinates in the tensor data in the global memory + + Returns + ------- + call : PrimExpr + The call expression. + """ # noqa: E501 + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + if isinstance(cache_hint, PrimExpr): + has_cache_policy, *coords = coords + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", + dim, + dst_ptr, + bar, + tensormap_addr, + cta_mask, + cta_group, + cache_hint, + has_cache_policy, + *coords, + ) + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", + dim, + dst_ptr, + bar, + tensormap_addr, + cta_mask, + cta_group, + cache_policy, + int(has_cache_policy), + *coords, + ) + + +def ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster( + dim, dst_ptr, bar, tensormap_addr, cta_mask, cta_group, cache_hint, *coords, cache_policy=None +): + """TVM intrinsic to call + cp.async.bulk.tensor.dim.shared::cluster.global.tile::gather4.mbarrier::complete_tx::bytes + + Parameters + ---------- + dim : int + The dimension of the source tensor. + + dst_ptr : PrimExpr + The destination pointer to the shared memory. + + bar : PrimExpr + The pointer to mbarrier variable. + + tensormap_addr : PrimExpr + The generic address of the tensor map object. + + cta_mask : int + The mask of the cta for multicast. + + cta_group : int + Must be either 1 or 2. + + cache_hint : str + The cache hint. + + coords : List[PrimExpr] + The TMA coordinates followed by the 4 gather row indices. + + Returns + ------- + call : PrimExpr + The call expression. + """ + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + if isinstance(cache_hint, PrimExpr): + has_cache_policy, *coords = coords + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster", + dim, + dst_ptr, + bar, + tensormap_addr, + cta_mask, + cta_group, + cache_hint, + has_cache_policy, + *coords, + ) + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster", + dim, + dst_ptr, + bar, + tensormap_addr, + cta_mask, + cta_group, + cache_policy, + int(has_cache_policy), + *coords, + ) + + +def ptx_cp_async_bulk_tensor_shared_to_global( + dim, src_ptr, tensormap_addr, cache_hint, *coords, cache_policy=None +): + """TVM intrinsic to call cp.async.bulk.tensor.dim.global.shared::cta.tile.bulk_group + + Parameters + ---------- + dim : int + The dimension of the copy tensor. + + src_ptr : PrimExpr + The source pointer to the shared memory. + + tensormap_addr : PrimExpr + The generic address of the tensor map object. + + cache_hint : str + The cache hint. + + coords : List[PrimExpr] + specifies the starting coordinates in the tensor data in the global memory + + Returns + ------- + call : PrimExpr + The call expression. + """ + if isinstance(cache_hint, PrimExpr): + has_cache_policy, *coords = coords + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_shared_to_global", + dim, + src_ptr, + tensormap_addr, + cache_hint, + has_cache_policy, + *coords, + ) + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_shared_to_global", + dim, + src_ptr, + tensormap_addr, + cache_policy, + int(has_cache_policy), + *coords, + ) + + +def ptx_cp_async_bulk_tensor_global_to_cluster_prefetch( + dim, tensormap_addr, cache_hint, *coords, cache_policy=None +): + """TVM intrinsic to call cp.async.bulk.prefetch.tensor.dim.L2.global.tile + + Parameters + ---------- + dim : int + The dimension of the source tensor. + + tensormap_addr : PrimExpr + The generic address of the tensor map object. + + cache_hint : str + The cache hint. + + coords : List[PrimExpr] + specifies the starting coordinates in the tensor data in the global memory + + Returns + ------- + call : PrimExpr + The call expression. + """ + if isinstance(cache_hint, PrimExpr): + has_cache_policy, *coords = coords + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch", + dim, + tensormap_addr, + cache_hint, + has_cache_policy, + *coords, + ) + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch", + dim, + tensormap_addr, + cache_policy, + int(has_cache_policy), + *coords, + ) + + +def ptx_cp_async_bulk_tensor_shared_to_global_reduce( + dim, src_ptr, tensormap_addr, cache_hint, red_op, *coords, cache_policy=None +): + """TVM intrinsic to call cp.reduce.async.bulk.tensor.dim.dst.src.redOp + + Parameters + ---------- + dim : int + The dimension of the copy tensor. + + src_ptr : PrimExpr + The source pointer to the shared memory. + + tensormap_addr : PrimExpr + The generic address of the tensor map object. + + cache_hint: str + The cache hint. + + red_op: str + The reduction operator. + + coords: List[PrimExpr] + The coordinates of the tensor. + + Returns + ------- + call : PrimExpr + The call expression. + """ + if isinstance(cache_hint, PrimExpr): + has_cache_policy = red_op + red_op, *coords = coords + _choice("red_op", red_op, _CP_ASYNC_BULK_RED_OP) + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_shared_to_global_reduce", + dim, + src_ptr, + tensormap_addr, + cache_hint, + has_cache_policy, + red_op, + *coords, + ) + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + _choice("red_op", red_op, _CP_ASYNC_BULK_RED_OP) + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_shared_to_global_reduce", + dim, + src_ptr, + tensormap_addr, + cache_policy, + int(has_cache_policy), + red_op, + *coords, + ) + + +def ptx_cp_async_bulk_commit_group(): + """TVM intrinsic to call cp.async.bulk.tensor.commit_group + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_cp_async_bulk_commit_group") + + +def ptx_cp_async_bulk_wait_group(n=0, read=True): + """TVM intrinsic to call cp.async.bulk.tensor.wait_group + + Parameters + ---------- + n : int + The number of the most recent uncommitted pending cp.async groups to wait. + + read : bool + Whether the wait is for read. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_cp_async_bulk_wait_group", n, read) + + +def ptx_barrier_cluster_arrive(sem="", aligned=True): + """TVM intrinsic to call barrier.cluster.arrive{.sem}{.aligned} + + Parameters + ---------- + sem : str + Either release or relaxed or empty string. + + aligned : bool + Whether all threads in the warp must execute the same instruction. + """ + _choice("sem", sem, _CLUSTER_BARRIER_SEM) + return call_intrin("", "tirx.ptx_barrier_cluster_arrive", sem, aligned) + + +def ptx_barrier_cluster_wait(acquire=False, aligned=True): + """TVM intrinsic to call barrier.cluster.wait{.acquire}{.aligned} + + Parameters + ---------- + acquire : bool + The memory synchronization + + aligned : bool + Whether all threads in the warp must execute the same instruction. + """ + return call_intrin("", "tirx.ptx_barrier_cluster_wait", acquire, aligned) + + +def ptx_elect_sync(): + """TVM intrinsic to call elect.sync""" + return call_intrin("uint32", "tirx.ptx_elect_sync") + + +def ptx_fence_mbarrier_init(): + """TVM intrinsic for PTX fence.mbarrier_init.release.cluster instruction. + + Generates: fence.mbarrier_init.release.cluster; + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_fence_mbarrier_init") + + +def ptx_fetch_register(bits, reg_name): + """TVM intrinsic to tvm instrinsics to fetch PTX pre-defined registers + + Parameters + ---------- + bits : int + The number of bits of the register. + + reg_name : str + The name of the register. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("int" + str(bits), "tirx.ptx_fetch_register", bits, reg_name) + + +def ptx_mma( + shape, + a_layout, + b_layout, + d_type, + a_type, + b_type, + c_type, + d_ptrs, + a_ptrs, + b_ptrs, + c_ptrs=None, + saturate=False, + bit_op=None, +): + """TVM intrinsic for ptx tensor core mma instructions. + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma + + Each per-thread register of every operand is addressed by its OWN pointer + (one ``void*`` per b32/f32 register), so the register fragments need not be + contiguous in the register file. ``d_ptrs`` / ``a_ptrs`` / ``b_ptrs`` / + ``c_ptrs`` are lists of one pointer per 32-bit register (b32 for + fp16/bf16/tf32/int8 multiplicands, f32/f64 for the accumulator), enumerated + in the fixed PTX register order (see the gemm dispatch / + ``tests/python/tirx-base/test_tir_ptx_mma.py``). + + Within one b32 register the packed elements (e.g. 2 fp16 along k_pack) + must stay contiguous (stride 1); only the b32 registers themselves may be + scattered. + + Parameters + ---------- + shape : str + The shape of mma fragment. + + a_layout : Literal["row", "col"] + The layout of multiplicand fragment A. + + b_layout : Literal["row", "col"] + The layout of multiplicand fragment B. + + d_type : str + The data type of result fragment D. + + a_type : str + The data type of multiplicand fragment A. + + b_type : str + The data type of multiplicand fragment B. + + c_type : str + The data type of accumulator fragment C. + + d_ptrs : List[PrimExpr] + One pointer per result-fragment D register, in PTX order. + + a_ptrs : List[PrimExpr] + One pointer per multiplicand-A register, in PTX order. + + b_ptrs : List[PrimExpr] + One pointer per multiplicand-B register, in PTX order. + + c_ptrs : Optional[List[PrimExpr]] + One pointer per accumulator-C register, in PTX order. ``None`` (the + default) means the accumulator is not used (beta == 0): codegen feeds + a literal 0 for each C slot. + + saturate : bool + The optional saturation at the output. + + bit_op : Optional[Literal["xor", "and"]] + The 1-bit operator (for the b1 subbyte form). ``None`` means unused. + + Returns + ------- + call : PrimExpr + The call expression. + """ + d_ptrs = list(d_ptrs) + a_ptrs = list(a_ptrs) + b_ptrs = list(b_ptrs) + has_c = c_ptrs is not None + c_ptrs = list(c_ptrs) if has_c else [] + + # Encode group register counts as leading attrs so codegen can slice the + # flat pointer tail. ``no_c_ptr`` mirrors the legacy IntImm(0) sentinel. + no_c_ptr = not has_c + # Flattened pointer list: D regs, A regs, B regs, then C regs (if any). + ptrs = [*d_ptrs, *a_ptrs, *b_ptrs, *c_ptrs] + + base = [ + "", + "tirx.ptx_mma", + shape, + a_layout, + b_layout, + d_type, + a_type, + b_type, + c_type, + len(d_ptrs), + len(a_ptrs), + len(b_ptrs), + len(c_ptrs), + no_c_ptr, + *ptrs, + saturate, + ] + if bit_op is None: + return call_intrin(*base) + return call_intrin(*base, bit_op) + + +def ptx_mma_legacy(*all_args, operator=None): + """Legacy ``ptx_mma`` API. + + Signature: ``(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, + multiplicand_a, a_index, multiplicand_b, b_index, accumulator, + c_index, saturate, operator=None)``. The accumulator is reused as + both input and output (no separate ``d``/``c`` slot), unlike + fork-native :func:`ptx_mma` which distinguishes them. Translation: + + * ``a_dtype, b_dtype, c_dtype`` → fork ``a_type, b_type, c_type`` + (and reuse ``c_dtype`` as fork ``d_type`` since the accumulator + dtype is the output dtype here). + * ``(a_ptr, a_offset)`` and ``(b_ptr, b_offset)`` → folded via + :func:`tvm_access_ptr`. + * ``(accumulator, c_index)`` → folded; passed for both ``d_ptr`` and + ``c_ptr`` since the accumulator is reused as the output. + + ``T.ptx.mma.legacy`` runs through ``_dtype_forward`` which prepends a + ``dtype=`` kwarg as a leading positional, so this function accepts + either 13 or 14 positional args. + """ + args = list(all_args) + # ``T.ptx.mma.legacy(..., dtype="...")`` has the dtype prepended by + # ``_dtype_forward``; strip it here. + if len(args) in (14, 15): + _ = args.pop(0) + if len(args) == 14: + # operator passed positionally as the trailing arg. + operator = args.pop() + if len(args) != 13: + raise ValueError( + f"ptx_mma_legacy expects 13-15 positional args (with optional " + f"leading ``call_dtype`` from dtype= kwarg and optional trailing " + f"``operator``); got {len(all_args)}" + ) + ( + shape, + a_layout, + b_layout, + a_dtype, + b_dtype, + c_dtype, + a_ptr, + a_offset, + b_ptr, + b_offset, + acc_ptr, + c_offset, + saturate, + ) = args + # Emit tirx.ptx_mma_legacy directly with separate (ptr_var, offset) + # pairs. codegen_cuda.cc uses C pointer arithmetic ``ptr + offset`` + # so element offsets stay element-accurate, and lower_warp_memory + # rewrites the offset's group component to a thread-local index. + call_args = [ + shape, + a_layout, + b_layout, + a_dtype, + b_dtype, + c_dtype, + a_ptr, + a_offset, + b_ptr, + b_offset, + acc_ptr, + c_offset, + saturate, + ] + if operator is not None: + call_args.append(operator) + return call_intrin("", "tirx.ptx_mma_legacy", *call_args) + + +def ptx_mma_sp_legacy(*all_args): + """Legacy ``ptx_mma_sp`` API. + + Signature: ``(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, + multiplicand_a, a_index, multiplicand_b, b_index, accumulator, + c_index, metadata, meta_index, sparse_selector, saturate)``. + + ``T.ptx.mma_sp.legacy`` runs through ``_dtype_forward`` which prepends + a ``dtype=`` kwarg as a leading positional, so this function accepts + either 16 or 17 positional args. + """ + args = list(all_args) + if len(args) == 17: + _ = args.pop(0) + if len(args) != 16: + raise ValueError( + f"ptx_mma_sp_legacy expects 16 args (or 17 with dtype= kwarg " + f"prepended); got {len(all_args)}" + ) + ( + shape, + a_layout, + b_layout, + a_dtype, + b_dtype, + c_dtype, + a_ptr, + a_offset, + b_ptr, + b_offset, + acc_ptr, + c_offset, + meta_ptr, + meta_offset, + sparse_selector, + saturate, + ) = args + return ptx_mma_sp( + c_dtype, + shape, + a_layout, + b_layout, + a_dtype, + b_dtype, + c_dtype, + a_ptr, + a_offset, + b_ptr, + b_offset, + acc_ptr, + c_offset, + meta_ptr, + meta_offset, + sparse_selector, + saturate, + ) + + +def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): + """Store the result of PTX MMA into a destination pointer.""" + + return call_intrin(dtype, "tirx.mma_store", m, n, dst_ptr, src_ptr, src_offset, dst_stride) + + +def mma_store_legacy(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): + """mma_store with apache-style pointer/offset semantics.""" + + return call_intrin( + dtype, + "tirx.mma_store_legacy", + m, + n, + dst_ptr, + src_ptr, + src_offset, + dst_stride, + ) + + +def mma_fill(dtype, local_size, local_ptr, offset): + """Zero-initialize an MMA accumulation register.""" + + return call_intrin(dtype, "tirx.mma_fill", local_size, local_ptr, offset) + + +def mma_fill_legacy(dtype, local_size, local_ptr, offset): + """mma_fill with apache-style pointer/offset semantics.""" + + return call_intrin(dtype, "tirx.mma_fill_legacy", local_size, local_ptr, offset) + + +def ptx_ldmatrix(trans, num, dtype, smem_ptr, *dst_handles): + """TVM intrinsic for ldmatrix.sync.aligned.m8n8.x{num}{.trans}.shared.{dtype}. + + Mirrors the PTX ISA destination form: each output register is a separate + operand. Pass ``T.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for + each destination — the slots may be non-contiguous. + + Parameters + ---------- + trans : bool + Apply the ``.trans`` modifier. + num : int + One of 1, 2, 4 — number of m8n8 fragments. + dtype : str + ``"b16"`` (4 bytes per fragment register) or ``"b8"`` (2 bytes per). + smem_ptr : PrimExpr + Generic pointer to source shared memory. + *dst_handles : PrimExpr + N pointer-to-uint32 destinations, where + ``N = num if dtype == "b16" else num // 2``. + + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix + """ + _choice("num", num, _LDMATRIX_NUM) + _choice("dtype", dtype, _LDMATRIX_DTYPE) + # _LDMATRIX_DTYPE entries carry leading dot (".b16" / ".b8"). + dtype_bare = dtype.lstrip(".") if isinstance(dtype, str) else dtype + n_regs = int(num) if dtype_bare == "b16" else int(num) // 2 + if len(dst_handles) != n_regs: + raise ValueError( + f"ldmatrix .x{int(num)}.{dtype_bare} expects {n_regs} destination " + f"handles, got {len(dst_handles)}" + ) + return call_intrin("", "tirx.ptx_ldmatrix", trans, num, dtype, smem_ptr, *dst_handles) + + +_PTX_TO_NUMPY_DTYPE = { + "fp16": "float16", + "fp32": "float32", + "fp64": "float64", + "bf16": "bfloat16", + "tf32": "float32", + "s8": "int8", + "u8": "uint8", + "s32": "int32", + "s4": "int4", + "u4": "uint4", + "b1": "int1", + "b16": "uint16", + "e4m3": "float8_e4m3fn", + "e5m2": "float8_e5m2", +} + + +def _ptx_to_numpy_dtype(dtype_str): + """Map a PTX-abbreviation or numpy dtype string to a numpy dtype string + suitable for ``tvm_access_ptr`` (which scales the offset by the element + bit width). Unknown strings pass through unchanged so a caller may also + pass an already-numpy dtype.""" + s = dtype_str if isinstance(dtype_str, str) else str(dtype_str) + return _PTX_TO_NUMPY_DTYPE.get(s, s) + + +def _wrap_or_fold_access_ptr(ptr, offset, elem_dtype): + """Wrap ``ptr`` with ``tvm_access_ptr`` unless it already is one. + + Several s_tir tensor intrinsics already pass ``buffer.access_ptr(...)`` + (an ``tvm_access_ptr`` Call) for the pointer argument. Naively wrapping + that again yields a nested ``tvm_access_ptr(... access_ptr(...) ...)`` + whose ``args[1]`` is a Call rather than a Var, which crashes the + lowering rule (Downcast at intrin_rule.cc) and several s_tir + passes that assume a raw buffer var. Detect that case and fold the + outer offset into the inner one. + """ + + is_access_ptr_call = ( + isinstance(ptr, Call) and isinstance(ptr.op, Op) and ptr.op.name == "tirx.tvm_access_ptr" + ) + if is_access_ptr_call: + # Inner Call already wraps the buffer var. Reuse its inner var and + # inner element dtype (the marker type_annotation), and add the + # outer offset (which is in `elem_dtype` units, same convention as + # the inner since both come from the same buffer). + inner_args = ptr.args + inner_marker = inner_args[0] + inner_var = inner_args[1] + inner_offset = inner_args[2] + rw_mask = inner_args[4] + return call_intrin( + "handle", + "tirx.tvm_access_ptr", + inner_marker, + inner_var, + inner_offset + offset, + 1, + rw_mask, + ) + return tvm_access_ptr(elem_dtype, ptr, offset, 1, 1) + + +def ptx_ldmatrix_legacy(*all_args): + """Legacy ``ptx_ldmatrix`` API taking explicit offsets. + + Signature: ``(trans, num, dtype, local_ptr, local_offset, smem_ptr, + smem_offset)``. Offsets are folded into the pointers via + ``tvm_access_ptr`` and dispatched to the fork-native + :func:`ptx_ldmatrix`. + + ``T.ptx.ldmatrix_legacy`` runs through ``_dtype_forward`` which + prepends a ``dtype=`` kwarg as a leading positional naming the buffer + element type — offsets are in elements of that dtype, not bytes, so + we forward it to ``tvm_access_ptr`` for correct scaling. + """ + if len(all_args) == 8: + elem_dtype, trans, num, dtype, local_ptr, local_offset, smem_ptr, smem_offset = all_args + elif len(all_args) == 7: + trans, num, dtype, local_ptr, local_offset, smem_ptr, smem_offset = all_args + elem_dtype = "int8" + else: + raise ValueError( + f"ptx_ldmatrix_legacy expects 7 args (or 8 with dtype= kwarg " + f"prepended); got {len(all_args)}" + ) + # Call.dtype carries the buffer element type so codegen can pick the + # int8+trans manual-loop fallback (ldmatrix can't transpose int8). + return call_intrin( + elem_dtype, + "tirx.ptx_ldmatrix_legacy", + trans, + num, + dtype, + local_ptr, + local_offset, + smem_ptr, + smem_offset, + ) + + +def ptx_stmatrix(trans, num, dtype, smem_ptr, *src_handles, shape="m8n8", space="shared"): + """TVM intrinsic for ``stmatrix.sync.aligned.shape.x{num}{.trans}.space.{dtype}``. + + Mirrors :func:`ptx_ldmatrix`: each source register is a separate operand. + Pass ``T.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for each + source — the slots may be non-contiguous. + + Parameters + ---------- + trans : bool + Apply the ``.trans`` modifier (required for ``shape == "m16n8"``). + num : int + One of 1, 2, 4 — number of m8n8 fragments per warp. + dtype : str + ``".b16"`` (4 bytes per fragment register) or ``".b8"`` (2 bytes per). + smem_ptr : PrimExpr + Destination pointer in shared memory. + *src_handles : PrimExpr + ``num`` pointer-to-uint32 sources. + shape : str, keyword-only, default "m8n8" + ``"m8n8"`` or ``"m16n8"``. + space : str, keyword-only, default "shared" + ``"shared"`` or ``"shared::cta"``. + + https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-stmatrix + """ + _choice("num", num, _LDMATRIX_NUM) + _choice("dtype", dtype, _LDMATRIX_DTYPE) + if shape not in ("m8n8", "m16n8"): + raise ValueError(f"Unsupported stmatrix shape {shape!r}") + if space not in ("shared", "shared::cta"): + raise ValueError(f"Unsupported stmatrix state space {space!r}") + if shape == "m16n8" and not trans: + raise ValueError("stmatrix .m16n8 requires .trans") + n_regs = int(num) + if len(src_handles) != n_regs: + dtype_bare = dtype.lstrip(".") if isinstance(dtype, str) else dtype + raise ValueError( + f"stmatrix .x{int(num)}.{dtype_bare} expects {n_regs} source " + f"handles, got {len(src_handles)}" + ) + return call_intrin( + "", "tirx.ptx_stmatrix", trans, num, dtype, shape, space, smem_ptr, *src_handles + ) + + +def ptx_wgmma_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle): + """TVM intrinsic to create memory descriptor for wgmma instructions + + Parameters + ---------- + desc : PrimExpr + The pointer to the shared memory descriptor. + + addr : PrimExpr + The address of the matrix. + + ldo : PrimExpr + The leading dimension offset. + + sdo : PrimExpr + The stride dimension offset. + + swizzle : int + The swizzle value (CUtensorMapSwizzle_enum). + """ + return call_intrin("", "tirx.ptx_wgmma_encode_matrix_descriptor", desc, addr, ldo, sdo, swizzle) + + +def ptx_wgmma_noop_barrier(reg): + """TVM intrinsic to call "" : "+{format}"(reg)::"memory" + + Parameters + ---------- + reg : PrimExpr + The register to fence. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_wgmma_noop_barrier", reg) + + +def ptx_wgmma_mma_async_ss( + descA, descB, *accums, M, N, K, in_dtype, out_dtype, transA, transB, scaleA, scaleB, scaleD +): + """TVM intrinsic to call wgmma.mma_async.sync.aligned.shape.dtype.atype.btype over 2 smem operators + + Parameters + ---------- + M : int + The number of rows in matrix A and D. + + N : int + The number of columns in matrix B and D. + + K : int + The number of columns in matrix A and rows in matrix B. + + in_dtype : str + The data type of the input matrices. + + out_type : str + The data type of the output matrices. + + transA : bool + True for M/N major, False for K major. + + transB : bool + True for M/N major, False for K major. + + scaleA : float + The scaling factor for matrix A. + + scaleB : float + The scaling factor for matrix B. + + scaleD : PrimExpr + True: D = A * B + D, False: D = A * B. + + descA : PrimExpr + The SMEM descriptor of matrix A + + descB : PrimExpr + The SMEM descriptor of matrix B + + accums : list + The accumulators registers. + """ # noqa: E501 + return call_intrin( + "", + "tirx.ptx_wgmma_mma_async_ss", + M, + N, + K, + in_dtype, + out_dtype, + transA, + transB, + scaleA, + scaleB, + scaleD, + descA, + descB, + *accums, + ) + + +def ptx_wgmma_mma_async_rs( + descB, *reg_list, M, N, K, in_dtype, out_dtype, transA, transB, scaleA, scaleB, scaleD +): + """TVM intrinsic to call wgmma.mma_async.sync.aligned.shape.dtype.atype.btype + When A is in register and B is in shared memory + + Parameters + ---------- + M : int + The number of rows in matrix A and D. + + N : int + The number of columns in matrix B and D. + + K : int + The number of columns in matrix A and rows in matrix B. + + in_dtype : str + The data type of the input matrices. + + out_type : str + The data type of the output matrices. + + transA : bool + True for M/N major, False for K major. + + transB : bool + True for M/N major, False for K major. + + scaleA : float + The scaling factor for matrix A. + + scaleB : float + The scaling factor for matrix B. + + scaleD : PrimExpr + True: D = A * B + D, False: D = A * B. + + descB : PrimExpr + The SMEM descriptor of matrix B + + reg_list : list + The A registers and accumulators registers. + """ + return call_intrin( + "", + "tirx.ptx_wgmma_mma_async_rs", + M, + N, + K, + in_dtype, + out_dtype, + transA, + transB, + scaleA, + scaleB, + scaleD, + descB, + *reg_list, + ) + + +def ptx_wgmma_fence(): + """TVM intrinsic to call wgmma.fence.sync.aligned + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_wgmma_fence") + + +def ptx_wgmma_commit_group(): + """TVM intrinsic to call wgmma.commit_group.sync.aligned + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_wgmma_commit_group") + + +def ptx_wgmma_wait_group(n): + """TVM intrinsic to call wgmma.wait_group.sync.aligned + + Parameters + ---------- + n : int + The number of the most recent uncommitted pending wgmma groups to wait. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_wgmma_wait_group", n) + + +def ptx_setmaxnreg(inc: bool, reg_count): + """TVM intrinsic to call setmaxnreg.action.sync.aligned.u32 imm-reg-count + + Parameters + ---------- + inc : bool + True to increase the register count, False to decrease. + + reg_count : int + The register count. + """ + return call_intrin("", "tirx.ptx_setmaxnreg", inc, reg_count) + + +def ptx_tcgen05_alloc(dst_ptr, n_cols, cta_group=1): + """TVM intrinsic to call tcgen05.alloc.cta_group.sync.aligned + Dynamically allocates the number of cols in tensor memory, and write + the address of allocated memory to shared memory. + + Parameters + ---------- + dst_ptr : Var + The pointer to the destination shared memory. + + n_cols : int + The number of columns to allocate in tensor memory. + Must be a multiple of 32 and a power of 2, and within the range [32, 512]. + + cta_group : int + The number of CTA groups involved in the allocation. + If cta_group=1, one warp from CTA performs the allocation. Else, if cta_group=2, + one warp from each of the peer CTAs perform the allocation. + """ + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + return call_intrin("", "tirx.ptx_tcgen05_alloc", dst_ptr, n_cols, cta_group) + + +def ptx_tcgen05_dealloc(taddr, n_cols, cta_group=1): + """TVM intrinsic to call tcgen05.dealloc.cta_group.sync.aligned + Deallocates the tensor memory specified by the tensor memory address taddr. + + Parameters + ---------- + taddr : PrimExpr + The address of previously allocated tensor memory, should be uint32_t. + + n_cols : int + The number of columns to deallocate in tensor memory. + Must be a multiple of 32 and a power of 2, and within the range [32, 512]. + + cta_group : int + The number of CTA groups involved in the deallocation. + If cta_group=1, one warp from CTA performs the deallocation. Else, if cta_group=2, + one warp from each of the peer CTAs perform the deallocation. + """ + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + return call_intrin("", "tirx.ptx_tcgen05_dealloc", taddr, n_cols, cta_group) + + +def ptx_tcgen05_relinquish_alloc_permit(cta_group=1): + """TVM intrinsic to call tcgen05.relinquish_alloc_permit.cta_group.sync.aligned + The CTA of the executing thread is relinquishing the right to allocate + Tensor Memory after calling this op. + + Parameters + ---------- + cta_group : int + The number of CTA groups involved in relinquishing. + If cta_group=1, one warp from CTA performs the relinquishing. Else, if cta_group=2, + one warp from each of the peer CTAs perform the relinquishing. + """ + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + return call_intrin("", "tirx.ptx_tcgen05_relinquish_alloc_permit", cta_group) + + +def ptx_tcgen05_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle): + """TVM intrinsic to create memory descriptor for tcgen05 instructions + + Parameters + ---------- + desc : PrimExpr + The pointer to the shared memory descriptor. + + addr : PrimExpr + The address of the matrix. + + ldo : PrimExpr + The leading dimension offset. + + sdo : PrimExpr + The stride dimension offset. + + swizzle : int + The swizzle value (CUtensorMapSwizzle_enum). + """ + return call_intrin( + "", "tirx.ptx_tcgen05_encode_matrix_descriptor", desc, addr, ldo, sdo, swizzle + ) + + +def ptx_tcgen05_encode_instr_descriptor( + desc, + *, + d_dtype, + a_dtype, + b_dtype, + M, + N, + K, + trans_a, + trans_b, + n_cta_groups=1, + neg_a=False, + neg_b=False, + sat_d=False, + is_sparse=False, +): + """TVM intrinsic to create instruction descriptor for tcgen05 MMA without block scaling + + Parameters + ---------- + desc : PrimExpr + The pointer to the instruction descriptor. + + d_dtype : str + The datatype of resultant matrix D. + + a_dtype : str + The datatype of multiplicand matrix A. + + b_dtype : str + The datatype of multiplicand matrix B. + + M : int + The size of non-reduction dimension of Matrix A. + + N : int + The size of non-reduction dimension of Matrix B. + + K : int + The size of reduction dimension of Matrix A/B. + + trans_a : bool + Whether the multiplicand matrix A is transposed. + True for M/N major, False for K major. + + trans_b : bool + Whether the multiplicand matrix B is transposed. + True for M/N major, False for K major. + + n_cta_groups : int + The number of CTA groups involved in the MMA operation. + + neg_a : bool + Whether to negate the multiplicand matrix A. + + neg_b : bool + Whether to negate the multiplicand matrix B. + + sat_d : bool + Whether to saturate the resultant matrix D. + + is_sparse : bool + Whether the MMA operation is sparse. + """ + _choice("n_cta_groups", n_cta_groups, _TCGEN05_CTA_GROUP) + return call_intrin( + "", + "tirx.ptx_tcgen05_encode_instr_descriptor", + desc, + d_dtype, + a_dtype, + b_dtype, + M, + N, + K, + trans_a, + trans_b, + n_cta_groups, + neg_a, + neg_b, + sat_d, + is_sparse, + ) + + +def ptx_tcgen05_encode_instr_descriptor_block_scaled( + desc, + *, + d_dtype, + a_dtype, + b_dtype, + sfa_dtype, + sfb_dtype, + sfa_tmem_addr, + sfb_tmem_addr, + M, + N, + K, + trans_a, + trans_b, + n_cta_groups=1, + neg_a=False, + neg_b=False, + is_sparse=False, +): + """TVM intrinsic to create instruction descriptor for tcgen05 MMA with block scaling + + Parameters + ---------- + desc : PrimExpr + The pointer to the instruction descriptor. + + d_dtype : str + The datatype of resultant matrix D. + + a_dtype : str + The datatype of multiplicand matrix A. + + b_dtype : str + The datatype of multiplicand matrix B. + + sfa_dtype : str + The datatype of scale factor matrix A. + + sfb_dtype : str + The datatype of scale factor matrix B. + + sfa_tmem_addr : PrimExpr + The address of the scale factor matrix A in tensor memory, should be uint32_t. + + sfb_tmem_addr : PrimExpr + The address of the scale factor matrix B in tensor memory, should be uint32_t. + + M : int + The size of non-reduction dimension of Matrix A. + + N : int + The size of non-reduction dimension of Matrix B. + + K : int + The size of reduction dimension of Matrix A/B. + + trans_a : bool + Whether the multiplicand matrix A is transposed. + True for M/N major, False for K major. + + trans_b : bool + Whether the multiplicand matrix B is transposed. + True for M/N major, False for K major. + + n_cta_groups : int + The number of CTA groups involved in the MMA operation. + + neg_a : bool + Whether to negate the multiplicand matrix A. + + neg_b : bool + Whether to negate the multiplicand matrix B. + + is_sparse : bool + Whether the MMA operation is sparse. + """ + _choice("n_cta_groups", n_cta_groups, _TCGEN05_CTA_GROUP) + return call_intrin( + "", + "tirx.ptx_tcgen05_encode_instr_descriptor_block_scaled", + desc, + d_dtype, + a_dtype, + b_dtype, + sfa_dtype, + sfb_dtype, + sfa_tmem_addr, + sfb_tmem_addr, + M, + N, + K, + trans_a, + trans_b, + n_cta_groups, + neg_a, + neg_b, + is_sparse, + ) + + +def ptx_tcgen05_mma( + d_tmem_addr, + a_operand, + b_desc, + i_desc, + *disable_output_lane, + d_dtype, + a_dtype, + b_dtype, + use_a_tmem, + cta_group, + enable_input_d=1, + scale_input_d=0, + pred=None, +): + """TVM intrinsic to call tcgen05.mma.cta_group.kind without block scaling. + + Parameters + ---------- + d_dtype : str + The datatype of resultant matrix D. + + a_dtype : str + The datatype of multiplicand matrix A. + + b_dtype : str + The datatype of multiplicand matrix B. + + d_tmem_addr : PrimExpr + The address of the resultant matrix D in tensor memory, should be uint32_t. + + a_operand : PrimExpr + Either the matrix descriptor of multiplicand matrix A in shared memory, + or the address of the multiplicand matrix A in tensor memory (uint32_t). + + b_desc : PrimExpr + The matrix descriptor of multiplicand matrix B in shared memory. + + i_desc : PrimExpr + The instruction descriptor of the MMA operation. + + use_a_tmem : bool + Whether the multiplicand matrix A is in tensor memory. + + cta_group : int + The number of CTA groups involved in the MMA operation. + + enable_input_d : PrimExpr + Scale operand for the input accumulator C/D. The inline asm tests + `enable_input_d != 0`: zero means D = A*B, non-zero means D = A*B + D. + + scale_input_d : int + The optional scaling factor to scale input matrix D. + D = A*B+D * (2 ^ - scale-input-d) + + disable_output_lane : list + The lanes that should not be updated in the resultant matrix D. + + pred : Optional[PrimExpr] + Runtime ``uint32`` instruction-level predicate. When given, emit + ``@p_issue tcgen05.mma...`` with ``p_issue = (pred != 0)``. Preserves + PTX-level predicate semantics (single predicated SASS instruction). + """ + + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + + # default value for disable_output_lane + if len(disable_output_lane) == 0: + disable_output_lane = [0] * (4 if cta_group == 1 else 8) + + args = [ + d_dtype, + a_dtype, + b_dtype, + d_tmem_addr, + a_operand, + b_desc, + i_desc, + use_a_tmem, + cta_group, + enable_input_d, + scale_input_d, + *disable_output_lane, + ] + if pred is not None: + args.append(pred) + return call_intrin("", "tirx.ptx_tcgen05_mma", *args) + + +def ptx_tcgen05_mma_block_scale( + d_tmem_addr, + a_operand, + b_desc, + sfa_tmem_addr, + sfb_tmem_addr, + i_desc, + *, + d_dtype, + a_dtype, + b_dtype, + sfa_dtype, + sfb_dtype, + use_a_tmem, + cta_group, + enable_input_d=1, +): + """TVM intrinsic to call tcgen05.mma.cta_group.kind.block_scale + Performs matrix multiplication with block scaling: + (A * scale_A) * (B * scale_B) + D + + Parameters + ---------- + d_dtype : str + The datatype of resultant matrix D. + + a_dtype : str + The datatype of multiplicand matrix A. + + b_dtype : str + The datatype of multiplicand matrix B. + + sfa_dtype : str + The datatype of scale factor matrix A. + + sfb_dtype : str + The datatype of scale factor matrix B. + + d_tmem_addr : PrimExpr + The address of the resultant matrix D in tensor memory, should be uint32_t. + + a_operand : PrimExpr + Either the matrix descriptor of multiplicand matrix A in shared memory, + or the address of the multiplicand matrix A in tensor memory (uint32_t). + + b_desc : PrimExpr + The matrix descriptor of multiplicand matrix B in shared memory. + + sfa_tmem_addr : PrimExpr + The address of the scale factor matrix A in tensor memory, should be uint32_t. + + sfb_tmem_addr : PrimExpr + The address of the scale factor matrix B in tensor memory, should be uint32_t. + + i_desc : PrimExpr + The instruction descriptor of the MMA operation. + + use_a_tmem : bool + Whether the multiplicand matrix A is in tensor memory. + + cta_group : int + The number of CTA groups involved in the MMA operation. + + enable_input_d : PrimExpr + Scale operand for the input accumulator C/D. Zero means D = A*B, + non-zero means D = A*B + D. + """ + + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + return call_intrin( + "", + "tirx.ptx_tcgen05_mma_block_scale", + d_dtype, + a_dtype, + b_dtype, + sfa_dtype, + sfb_dtype, + d_tmem_addr, + a_operand, + b_desc, + sfa_tmem_addr, + sfb_tmem_addr, + i_desc, + use_a_tmem, + cta_group, + enable_input_d, + ) + + +def ptx_tcgen05_mma_sp( + d_tmem_addr, + a_operand, + b_desc, + sp_tmem_addr, + i_desc, + *disable_output_lane, + d_dtype, + a_dtype, + b_dtype, + use_a_tmem, + cta_group, + enable_input_d=1, + scale_input_d=0, +): + """TVM intrinsic to call tcgen05.mma.sp.cta_group.kind without block scaling. + + Parameters + ---------- + d_dtype : str + The datatype of resultant matrix D. + + a_dtype : str + The datatype of multiplicand matrix A. + + b_dtype : str + The datatype of multiplicand matrix B. + + d_tmem_addr : PrimExpr + The address of the resultant matrix D in tensor memory, should be uint32_t. + + a_operand : PrimExpr + Either the matrix descriptor of multiplicand matrix A in shared memory, + or the address of the multiplicand matrix A in tensor memory (uint32_t). + + b_desc : PrimExpr + The matrix descriptor of multiplicand matrix B in shared memory. + + sp_tmem_addr : PrimExpr + The address of the metadata of sparse matrix in tensor memory, should be uint32_t. + + i_desc : PrimExpr + The instruction descriptor of the MMA operation. + + use_a_tmem : bool + Whether the multiplicand matrix A is in tensor memory. + + cta_group : int + The number of CTA groups involved in the MMA operation. + + enable_input_d : PrimExpr + Scale operand for the input accumulator C/D. The inline asm tests + `enable_input_d != 0`: zero means D = A*B, non-zero means D = A*B + D. + + scale_input_d : int + The optional scaling factor to scale input matrix D. + D = A*B+D * (2 ^ - scale-input-d) + + disable_output_lane : list + The lanes that should not be updated in the resultant matrix D. + """ + + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + + # default value for disable_output_lane + if len(disable_output_lane) == 0: + disable_output_lane = [0] * (4 if cta_group == 1 else 8) + + return call_intrin( + "", + "tirx.ptx_tcgen05_mma_sp", + d_dtype, + a_dtype, + b_dtype, + d_tmem_addr, + a_operand, + b_desc, + sp_tmem_addr, + i_desc, + use_a_tmem, + cta_group, + enable_input_d, + scale_input_d, + *disable_output_lane, + ) + + +def ptx_tcgen05_mma_sp_block_scale( + d_tmem_addr, + a_operand, + b_desc, + sfa_tmem_addr, + sfb_tmem_addr, + sp_tmem_addr, + i_desc, + *, + d_dtype, + a_dtype, + b_dtype, + sfa_dtype, + sfb_dtype, + use_a_tmem, + cta_group, + enable_input_d=1, +): + """TVM intrinsic to call tcgen05.mma.sp.cta_group.kind.block_scale + Performs sparse matrix multiplication with block scaling: + (A * scale_A) * (B * scale_B) + D + + Parameters + ---------- + d_dtype : str + The datatype of resultant matrix D. + + a_dtype : str + The datatype of multiplicand matrix A. + + b_dtype : str + The datatype of multiplicand matrix B. + + sfa_dtype : str + The datatype of scale factor matrix A. + + sfb_dtype : str + The datatype of scale factor matrix B. + + d_tmem_addr : PrimExpr + The address of the resultant matrix D in tensor memory, should be uint32_t. + + a_operand : PrimExpr + Either the matrix descriptor of multiplicand matrix A in shared memory, + or the address of the multiplicand matrix A in tensor memory (uint32_t). + + b_desc : PrimExpr + The matrix descriptor of multiplicand matrix B in shared memory. + + sfa_tmem_addr : PrimExpr + The address of the scale factor matrix A in tensor memory, should be uint32_t. + + sfb_tmem_addr : PrimExpr + The address of the scale factor matrix B in tensor memory, should be uint32_t. + + sp_tmem_addr : PrimExpr + The address of the metadata of sparse matrix in tensor memory, should be uint32_t. + + i_desc : PrimExpr + The instruction descriptor of the MMA operation. + + use_a_tmem : bool + Whether the multiplicand matrix A is in tensor memory. + + cta_group : int + The number of CTA groups involved in the MMA operation. + + enable_input_d : PrimExpr + Scale operand for the input accumulator C/D. Zero means D = A*B, + non-zero means D = A*B + D. + """ + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + return call_intrin( + "", + "tirx.ptx_tcgen05_mma_sp_block_scale", + d_dtype, + a_dtype, + b_dtype, + sfa_dtype, + sfb_dtype, + d_tmem_addr, + a_operand, + b_desc, + sfa_tmem_addr, + sfb_tmem_addr, + sp_tmem_addr, + i_desc, + use_a_tmem, + cta_group, + enable_input_d, + ) + + +def ptx_tcgen05_fence_before_thread_sync(): + """TVM intrinsic to call tcgen05.fence::before_thread_sync + Orders all prior asynchronous tcgen05 operations relative to subsequent operations. + """ + return call_intrin("", "tirx.ptx_tcgen05_fence_before_thread_sync") + + +def ptx_tcgen05_fence_after_thread_sync(): + """TVM intrinsic to call tcgen05.fence::after_thread_sync + Orders all subsequent asynchronous tcgen05 operations relative to previous operations. + """ + return call_intrin("", "tirx.ptx_tcgen05_fence_after_thread_sync") + + +def _choice(name: str, value, options): + """Validate `value` is one of `options`. Raise a clear ValueError otherwise. + + Symbolic values (Var, non-constant PrimExpr) are accepted without + validation; specialization later replaces them with concrete values + that the C-side intrinsic body re-checks. + """ + # Concrete int / IntImm value: validate. + try: + concrete = int(value) + except (TypeError, ValueError): + return # symbolic; defer check + if concrete not in options: + raise ValueError(f"invalid {name}={concrete!r}; expected one of {tuple(options)}") + + +# See top-of-file imports for `_FENCE_SEM` etc. (re-exported from _common). +# Note: TCGEN05_LDST_SHAPES values must stay in sync with the shape branches +# of codegen_ptx_tcgen05_ld/_st in intrinsics/cuda/tcgen05.py. + + +def ptx_tcgen05_cp( + taddr, src_desc, *, shape, cta_group=1, multicast="", decompress="", row=0, col=0 +): + """TVM intrinsic for the Blackwell `tcgen05.cp` PTX instruction. + + The emitted PTX is:: + + tcgen05.cp.cta_group::{cta_group}.{shape}[.{multicast}][.{decompress}] [taddr], src_desc; + + Each keyword argument maps 1:1 to a PTX token: read the call and you + know what instruction is emitted. + + Parameters + ---------- + taddr : PrimExpr + Destination tensor-memory address (uint32). Callers typically pass + ``tmem_base + column_offset_in_uint32s`` directly. Use the optional + ``row`` / ``col`` keyword arguments only when the address needs + runtime row/col composition via ``get_tmem_addr`` (high 16 bits row, + low 16 bits col). + + src_desc : PrimExpr + The 64-bit shared-memory matrix descriptor. + + shape : str + One of ``"32x128b"``, ``"4x256b"``, ``"128x128b"``, ``"128x256b"``, + ``"64x128b"``. + + cta_group : int + 1 or 2. + + multicast : str + One of ``""``, ``"warpx4"``, ``"warpx2::02_13"``, ``"warpx2::01_23"``. + ``"32x128b"`` requires ``"warpx4"``; ``"64x128b"`` requires one of the + ``warpx2::*`` values; other shapes require ``""``. + + decompress : str + Trailing PTX suffix for fp4/fp6 → fp8 on-the-fly decompression. + One of ``""``, ``"b8x16.b4x16_p64"``, ``"b8x16.b6x16_p32"``. + + row, col : PrimExpr + Optional row/col offsets added to ``taddr`` at runtime. Default 0. + """ + _choice("shape", shape, _TCGEN05_CP_SHAPES) + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + _choice("multicast", multicast, _TCGEN05_CP_MULTICAST) + _choice("decompress", decompress, _TCGEN05_CP_DECOMPRESS) + if shape == "32x128b" and multicast != "warpx4": + raise ValueError(f"shape=32x128b requires multicast='warpx4', got {multicast!r}") + if shape == "64x128b" and multicast not in ("warpx2::02_13", "warpx2::01_23"): + raise ValueError(f"shape=64x128b requires multicast in warpx2::*, got {multicast!r}") + if shape in ("128x128b", "128x256b", "4x256b") and multicast != "": + raise ValueError(f"shape={shape} requires multicast='', got {multicast!r}") + + return call_intrin( + "", + "tirx.ptx_tcgen05_cp", + taddr, + src_desc, + shape, + cta_group, + multicast, + decompress, + row, + col, + ) + + +def ptx_tcgen05_shift(taddr, cta_group=1): + """TVM intrinsic to call tcgen05.shift.cta_group.down + Asynchronously shift down the rows of the matrix in Tensor Memory for a warp. + + Parameters + ---------- + taddr : PrimExpr + The address of matrix in tensor memory, should be uint32_t. + + cta_group : int + The number of CTA groups involved in the shift. + If cta_group=1, shift operation is performed in the Tensor Memory of current CTA. + Else, shift operation is performed in the Tensor Memory of both the current CTA and + the peer CTA. + """ + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + return call_intrin("", "tirx.ptx_tcgen05_shift", taddr, cta_group) + + +def ptx_tcgen05_ld(src_addr, *regs, shape, num, row=0, col=0, pack=False): + """TVM intrinsic for tcgen05.ld.sync.aligned — async collective load from TMEM. + + Emits ``tcgen05.ld.sync.aligned.{shape}.x{num}[.pack::16b].b32 {regs}, [addr];`` + + Parameters + ---------- + src_addr : PrimExpr + Tensor-memory source address (uint32). + + regs : list[PrimExpr] + Destination registers. Count depends on shape x num. + + shape : str + One of ``"16x32bx2"``, ``"16x64b"``, ``"16x128b"``, ``"16x256b"``, ``"32x32b"``. + + num : int + Repeat factor along the columns. Power-of-two in [1, 128]. + + row, col : PrimExpr + Optional TMEM row/col offsets added to ``src_addr`` at runtime (row must be + a multiple of 32). Default 0. + + pack : bool + Pack two 16-bit chunks into a single 32-bit register. + """ + _choice("shape", shape, _TCGEN05_LDST_SHAPES) + return call_intrin("", "tirx.ptx_tcgen05_ld", src_addr, row, col, shape, num, pack, *regs) + + +def ptx_tcgen05_st(dst_addr, *regs, shape, num, row=0, col=0, unpack=False): + """TVM intrinsic for tcgen05.st.sync.aligned — async collective store to TMEM. + + Emits ``tcgen05.st.sync.aligned.{shape}.x{num}[.unpack::16b].b32 [addr], {regs};`` + + Parameters + ---------- + dst_addr : PrimExpr + Tensor-memory destination address (uint32). + + regs : list[PrimExpr] + Source registers. Count depends on shape x num. + + shape : str + One of ``"16x32bx2"``, ``"16x64b"``, ``"16x128b"``, ``"16x256b"``, ``"32x32b"``. + + num : int + Repeat factor along the columns. Power-of-two in [1, 128]. + + row, col : PrimExpr + Optional TMEM row/col offsets added to ``dst_addr`` at runtime (row must be + a multiple of 32). Default 0. + + unpack : bool + Unpack a 32-bit register into two 16-bit chunks. + """ + _choice("shape", shape, _TCGEN05_LDST_SHAPES) + return call_intrin("", "tirx.ptx_tcgen05_st", dst_addr, row, col, shape, num, unpack, *regs) + + +def ptx_tcgen05_wait_ld(): + """TVM intrinsic to call tcgen05.wait::ld.sync.aligned + Wait for the completion of all prior async tcgen05.ld operations. + """ + return call_intrin("", "tirx.ptx_tcgen05_wait_ld") + + +def ptx_tcgen05_wait_st(): + """TVM intrinsic to call tcgen05.wait::st.sync.aligned + Wait for the completion of all prior async tcgen05.st operations. + """ + return call_intrin("", "tirx.ptx_tcgen05_wait_st") + + +def ptx_tcgen05_commit(bar, cta_group=1, cta_mask=0, *, pred=None): + """TVM intrinsic to call tcgen05.commit.cta_group + + Parameters + ---------- + bar : PrimExpr + The pointer to mbarrier variable. + + cta_group: int + The number of CTA groups involved in previous tcgen05 operations. + + cta_mask : int + The mask of the CTAs in the cluster, used for multicast. + + pred : Optional[PrimExpr] + Runtime ``uint32`` predicate. When given, emit + ``@p tcgen05.commit...`` with ``p = (pred != 0)``. This preserves + PTX-level instruction predicate semantics (single predicated + instruction in SASS), distinct from a C-level ``if`` branch. + + Returns + ------- + call : PrimExpr + The call expression. + """ + _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) + args = [bar, cta_group, cta_mask] + if pred is not None: + args.append(pred) + return call_intrin("", "tirx.ptx_tcgen05_commit", *args) + + +def timer_init_cuda(profiler_buffer, profiler_tag, profiler_write_offset, num_groups, group_id): + """TVM intrinsic for initializing the CUDA profiler, and store profiling result in a buffer. + + Parameters + ---------- + profiler_buffer: Var + The buffer to store the profiling result. + + profiler_tag: Var + Buffer of length 1 storing the base tag of the current thread. + + profiler_write_offset: Var + Buffer of length 1 storing the offset in buffer to write the next + profiling result for the current thread. + + num_groups: int + The number of groups in the profiler. + + group_id: PrimExpr + The group id of the current thread. + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin( + "handle", + "tirx.timer_init_cuda", + profiler_buffer, + profiler_tag, + profiler_write_offset, + num_groups, + group_id, + ) + + +def timer_start_cuda( + event_type, + profiler_buffer, + profiler_tag, + profiler_write_offset, + profiler_write_stride, + leader_cond, +): + """TVM intrinsic for starting the timer for profiling a specific event, and storing profiling result in a buffer. + + Parameters + ---------- + event_type: Enum + The event to profile. + + profiler_buffer: Var + The buffer to store the profiling result. + + profiler_tag: Var + Buffer of length 1 storing the base tag of the current thread. + + profiler_write_offset: Var + Buffer of length 1 storing the offset in buffer to write the next + profiling result for the current thread. + + profiler_write_stride: int + The stride to advance in buffer in the next write. + + leader_cond: PrimExpr + The condition to check if the current thread is the leader. + + Returns + ------- + call : PrimExpr + The call expression. + """ # noqa: E501 + + return call_intrin( + "handle", + "tirx.timer_start_cuda", + event_type.value, + profiler_buffer, + profiler_tag, + profiler_write_offset, + profiler_write_stride, + leader_cond, + ) + + +def timer_end_cuda( + event_type, + profiler_buffer, + profiler_tag, + profiler_write_offset, + profiler_write_stride, + leader_cond, +): + """TVM intrinsic for ending the timer for profiling a specific event, and storing profiling result in a buffer. + + Parameters + ---------- + event_type: Enum + The event to profile. + + profiler_buffer: Var + The buffer to store the profiling result. + + profiler_tag: Var + Buffer of length 1 storing the base tag of the current thread. + + profiler_write_offset: Var + Buffer of length 1 storing the offset in buffer to write the next + profiling result for the current thread. + + profiler_write_stride: int + The stride to advance in buffer in the next write. + + leader_cond: PrimExpr + The condition to check if the current thread is the leader. + + Returns + ------- + call : PrimExpr + The call expression. + """ # noqa: E501 + + return call_intrin( + "handle", + "tirx.timer_end_cuda", + event_type.value, + profiler_buffer, + profiler_tag, + profiler_write_offset, + profiler_write_stride, + leader_cond, + ) + + +def timer_finalize_cuda( + profiler_buffer, profiler_tag, profiler_write_offset, profiler_write_stride, leader_cond +): + """TVM intrinsic for finalizing the CUDA profiler, and store profiling result in a buffer. + + Parameters + ---------- + profiler_buffer: Var + The buffer to store the profiling result. + + profiler_tag: Var + Buffer of length 1 storing the base tag of the current thread. + + profiler_write_offset: Var + Buffer of length 1 storing the offset in buffer to write the next + profiling result for the current thread. + + profiler_write_stride: int + The stride to advance in buffer in the next write. + + leader_cond: PrimExpr + The condition to check if the current thread is the leader. + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin( + "handle", + "tirx.timer_finalize_cuda", + profiler_buffer, + profiler_tag, + profiler_write_offset, + profiler_write_stride, + leader_cond, + ) + + +def cuda_atomic_add(res_addr, value): + """TVM intrinsic to call cuda atomic add instruction + + Parameters + ---------- + res_addr : PrimExpr + The result address. + + value: PrimExpr + The value to add. + + Returns + ------- + call : PrimExpr + The call expression. + """ + value = tir.convert(value) + return call_intrin(value.dtype, "tirx.cuda_atomic_add", res_addr, value) + + +def cuda_thread_fence(): + """TVM intrinsic to call cuda thread fence instruction + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_thread_fence") + + +def cuda_warpgroup_sync(bar_no): + """TVM intrinsic to synchronize a CUDA warpgroup via a named barrier. + + Parameters + ---------- + bar_no : PrimExpr + The named barrier id to use for the warpgroup. + + Notes + ----- + Synchronizes 128 threads in a warpgroup using `bar.sync bar_no, 128`. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_warpgroup_sync", bar_no) + + +def cuda_syncthreads_and(cond): + """TVM intrinsic to call cuda syncthreads_and instruction + + Parameters + ---------- + cond: PrimExpr + The condition. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("int64", "tirx.cuda_syncthreads_and", cond) + + +def cuda_syncthreads_or(cond): + """TVM intrinsic to call cuda syncthreads_or instruction + + Parameters + ---------- + cond: PrimExpr + The condition. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("int64", "tirx.cuda_syncthreads_or", cond) + + +def cuda_nano_sleep(time): + """TVM intrinsic to call cuda nano sleep instruction + + Parameters + ---------- + time: PrimExpr + The time to sleep. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_nano_sleep", time) + + +def cuda_printf(fmt, *args): + """TVM intrinsic to call cuda printf instruction + + Parameters + ---------- + fmt: str + The format string. + + *args: list + The arguments to the format string. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.cuda_printf", fmt, *args) + + +def cuda_ldg(addr, dtype): + """TVM intrinsic to call CUDA C++ __ldg() function + + Parameters + ---------- + addr : PrimExpr + The memory address to load. + + dtype : str + The data type of the loaded value. + + Returns + """ + return call_intrin(dtype, "tirx.cuda_ldg", addr, dtype) + + +def cuda_get_tmem_addr(addr, row_offset, col_offset): + """TVM intrinsic to call cuda tmem address calculation + + Parameters + ---------- + addr: PrimExpr + The memory address to calculate. + + row_offset: PrimExpr + The row offset to calculate. + + col_offset: PrimExpr + The column offset to calculate. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("uint32", "tirx.cuda_get_tmem_addr", addr, row_offset, col_offset) + + +def cuda_cvta_generic_to_shared(ptr): + """Convert a generic pointer to a shared-memory address (uint32). + + Wraps ``__cvta_generic_to_shared(ptr)``. Used by op-wrappers that + precompute the shared-memory address at the wrapper layer instead of + inside the asm helper body. + """ + return call_intrin("uint32", "tirx.cuda_cvta_generic_to_shared", ptr) + + +def cuda_smem_addr_from_uint64(cluster_addr): + """Narrow a 64-bit cluster-mapped SMEM address to a 32-bit SMEM address. + + Wraps ``static_cast(cluster_addr)``. Used by + cp.async.bulk.shared::cluster.* op-wrappers. + """ + return call_intrin("uint32", "tirx.cuda_smem_addr_from_uint64", cluster_addr) + + +def cuda_sm100_tma_2sm_mbarrier_addr(bar): + """Compute the SM100 2SM TMA mbarrier shared-address operand.""" + return bitwise_and(cuda_cvta_generic_to_shared(bar), const(0xFEFFFFFF, dtype="uint32")) + + +def ptx_exp2(x): + """TVM intrinsic for PTX fast exp2 approximation (ex2.approx.ftz.f32) + + Parameters + ---------- + x : PrimExpr + The float32 input value. + + Returns + ------- + call : PrimExpr + The call expression returning 2^x (approximate). + """ + return call_intrin("float32", "tirx.ptx_exp2", x) + + +def ptx_rcp(x): + """TVM intrinsic for PTX fast reciprocal approximation (rcp.approx.ftz.f32) + + Parameters + ---------- + x : PrimExpr + The float32 input value. + + Returns + ------- + call : PrimExpr + The call expression returning 1/x (approximate). + """ + return call_intrin("float32", "tirx.ptx_rcp", x) + + +def ptx_any_sync(mask, pred): + """TVM intrinsic for PTX warp-wide any predicate (__any_sync) + + Parameters + ---------- + mask : PrimExpr + The thread mask (uint32). + pred : PrimExpr + The predicate value (int32). + + Returns + ------- + call : PrimExpr + The call expression returning 1 if any thread in mask has pred != 0. + """ + return call_intrin("int32", "tirx.ptx_any_sync", mask, pred) + + +def ptx_reduce3_max_f32(a, b, c): + """TVM intrinsic to call 3-input max.f32 PTX instruction (sm_100a+) + + Parameters + ---------- + a, b, c : PrimExpr + The three float32 values to compare. + + Returns + ------- + call : PrimExpr + The call expression returning max(a, b, c). + """ + return call_intrin("float32", "tirx.ptx_reduce3_max_f32", a, b, c) + + +def ptx_reduce3_min_f32(a, b, c): + """TVM intrinsic to call 3-input min.f32 PTX instruction (sm_100a+) + + Parameters + ---------- + a, b, c : PrimExpr + The three float32 values to compare. + + Returns + ------- + call : PrimExpr + The call expression returning min(a, b, c). + """ + return call_intrin("float32", "tirx.ptx_reduce3_min_f32", a, b, c) + + +def _ptx_binary_arith(op_name, dtype, d, a, b, *, rounding="rn", ftz=False, sat=False): + """Shared helper for add/sub/mul over (f32 | f32x2 | f64), DPS form.""" + _choice("rounding", rounding, _F32X2_ROUND) + if dtype == "f64" and (ftz or sat): + raise ValueError(f"PTX {op_name}.f64 does not accept .ftz or .sat") + if dtype == "f32x2" and sat: + raise ValueError(f"PTX {op_name}.f32x2 does not accept .sat") + return call_intrin( + "", + f"tirx.ptx_{op_name}_{dtype}", + d, + a, + b, + rounding, + int(ftz), + int(sat), + ) + + +def _ptx_fma(dtype, d, a, b, c, *, rounding="rn", ftz=False, sat=False): + """Shared helper for fma over (f32 | f32x2 | f64), DPS form.""" + _choice("rounding", rounding, _F32X2_ROUND) + if dtype == "f64" and (ftz or sat): + raise ValueError("PTX fma.f64 does not accept .ftz or .sat") + if dtype == "f32x2" and sat: + raise ValueError("PTX fma.f32x2 does not accept .sat") + return call_intrin( + "", + f"tirx.ptx_fma_{dtype}", + d, + a, + b, + c, + rounding, + int(ftz), + int(sat), + ) + + +def ptx_add_f32(d_addr, a, b, *, rounding="rn", ftz=False, sat=False): + """PTX ``add{.rnd}{.ftz}{.sat}.f32 [d_addr], a, b`` — DPS form.""" + return _ptx_binary_arith("add", "f32", d_addr, a, b, rounding=rounding, ftz=ftz, sat=sat) + + +def ptx_add_f32x2(d_addr, a, b, *, rounding="rn", ftz=False): + """PTX ``add{.rnd}{.ftz}.f32x2 [d_addr], a, b`` — DPS form. + + a, b are packed-as-uint64 register operands (2 fp32 each). + """ + return _ptx_binary_arith("add", "f32x2", d_addr, a, b, rounding=rounding, ftz=ftz) + + +def ptx_add_f64(d_addr, a, b, *, rounding="rn"): + """PTX ``add{.rnd}.f64 [d_addr], a, b`` — DPS form (no .ftz / .sat).""" + return _ptx_binary_arith("add", "f64", d_addr, a, b, rounding=rounding) + + +def ptx_sub_f32(d_addr, a, b, *, rounding="rn", ftz=False, sat=False): + """PTX ``sub{.rnd}{.ftz}{.sat}.f32 [d_addr], a, b`` — DPS form.""" + return _ptx_binary_arith("sub", "f32", d_addr, a, b, rounding=rounding, ftz=ftz, sat=sat) + + +def ptx_sub_f32x2(d_addr, a, b, *, rounding="rn", ftz=False): + """PTX ``sub{.rnd}{.ftz}.f32x2 [d_addr], a, b`` — DPS form.""" + return _ptx_binary_arith("sub", "f32x2", d_addr, a, b, rounding=rounding, ftz=ftz) + + +def ptx_sub_f64(d_addr, a, b, *, rounding="rn"): + """PTX ``sub{.rnd}.f64 [d_addr], a, b`` — DPS form.""" + return _ptx_binary_arith("sub", "f64", d_addr, a, b, rounding=rounding) + + +def ptx_mul_f32(d_addr, a, b, *, rounding="rn", ftz=False, sat=False): + """PTX ``mul{.rnd}{.ftz}{.sat}.f32 [d_addr], a, b`` — DPS form.""" + return _ptx_binary_arith("mul", "f32", d_addr, a, b, rounding=rounding, ftz=ftz, sat=sat) + + +def ptx_mul_f32x2(d_addr, a, b, *, rounding="rn", ftz=False): + """PTX ``mul{.rnd}{.ftz}.f32x2 [d_addr], a, b`` — DPS form.""" + return _ptx_binary_arith("mul", "f32x2", d_addr, a, b, rounding=rounding, ftz=ftz) + + +def ptx_mul_f64(d_addr, a, b, *, rounding="rn"): + """PTX ``mul{.rnd}.f64 [d_addr], a, b`` — DPS form.""" + return _ptx_binary_arith("mul", "f64", d_addr, a, b, rounding=rounding) + + +def ptx_fma_f32(d_addr, a, b, c, *, rounding="rn", ftz=False, sat=False): + """PTX ``fma{.rnd}{.ftz}{.sat}.f32 [d_addr], a, b, c`` — DPS form.""" + return _ptx_fma("f32", d_addr, a, b, c, rounding=rounding, ftz=ftz, sat=sat) + + +def ptx_fma_f32x2(d_addr, a, b, c, *, rounding="rn", ftz=False): + """PTX ``fma{.rnd}{.ftz}.f32x2 [d_addr], a, b, c`` — DPS form. + + a, b, c are packed-as-uint64 register operands. + """ + return _ptx_fma("f32x2", d_addr, a, b, c, rounding=rounding, ftz=ftz) + + +def ptx_fma_f64(d_addr, a, b, c, *, rounding="rn"): + """PTX ``fma{.rnd}.f64 [d_addr], a, b, c`` — DPS form.""" + return _ptx_fma("f64", d_addr, a, b, c, rounding=rounding) + + +def ptx_max_f32(a, b, *, ftz=False, nan=False): + """TVM intrinsic for PTX ``max{.ftz}{.NaN}.f32 d, a, b``. + + 2-operand form (distinct from :func:`ptx_reduce3_max_f32` which is the + 3-operand SM_100+ form). ``.NaN`` qualifier propagates NaN inputs to + the output; without it, NaN inputs are silently ignored. + + Parameters + ---------- + a, b : PrimExpr + Float32 inputs. + ftz : bool + If True, flush subnormals to zero (``.ftz``). + nan : bool + If True, propagate NaN inputs (``.NaN``). + """ + return call_intrin("float32", "tirx.ptx_max_f32", a, b, int(ftz), int(nan)) + + +def ptx_griddepcontrol_wait(): + """TVM intrinsic for PTX ``griddepcontrol.wait`` (sm_90+). + + Blocks the current grid until prerequisite grids signalled via + :func:`ptx_griddepcontrol_launch_dependents` have finished. Acts as a + full memory barrier. + """ + return call_intrin("", "tirx.ptx_griddepcontrol_wait") + + +def ptx_griddepcontrol_launch_dependents(): + """TVM intrinsic for PTX ``griddepcontrol.launch_dependents`` (sm_90+). + + Signals that the current grid has reached a point where dependent + grids may begin execution. + """ + return call_intrin("", "tirx.ptx_griddepcontrol_launch_dependents") + + +_PTX_LD_SCOPE = {"cta", "cluster", "gpu", "sys"} +_PTX_LD_SPACE = {"global", "shared", "shared::cta", "shared::cluster", "local"} +_PTX_LD_VOLATILE_SPACE = _PTX_LD_SPACE | {"const"} +_PTX_LD_TYPE = {"b32", "u32", "u64", "s32", "f32"} +_PTX_LD_COP = {"", "ca", "cg", "cs", "lu", "cv"} +_PTX_MEM_SCOPE = {"", "cta", "cluster", "gpu", "sys"} +_PTX_MEM_SPACE = {"global", "shared", "shared::cta", "shared::cluster"} +_PTX_SCALAR_TYPE = {"b32", "b64", "u32", "u64", "s32", "s64", "f32", "f64"} +_PTX_RED_OP = {"and", "or", "xor", "add", "inc", "dec", "min", "max"} +_PTX_ATOM_OP = {"and", "or", "xor", "exch", "add", "inc", "dec", "min", "max"} +_PTX_ST_VEC = {"", "v2", "v4", "v8"} +_PTX_ST_COP = {"", "wb", "cg", "cs", "wt"} +_PTX_PREFETCH_TENSORMAP_SPACE = {"", "const", "param"} +_PTX_SCALAR_RETURN_TYPE = { + "b32": "uint32", + "u32": "uint32", + "s32": "int32", + "b64": "uint64", + "u64": "uint64", + "s64": "int64", + "f32": "float32", + "f64": "float64", +} +_PTX_CACHE_POLICY = { + "evict_normal": 0x1000000000000000, + "evict_first": 0x12F0000000000000, + "evict_last": 0x14F0000000000000, +} + + +def _resolve_cache_policy(cache_hint, cache_policy, choices=_CP_ASYNC_BULK_CACHE_HINT): + _choice("cache_hint", cache_hint, choices) + if cache_policy is not None: + return cache_policy, True + if cache_hint: + if cache_hint not in _PTX_CACHE_POLICY: + raise ValueError( + f"Unsupported built-in cache policy {cache_hint!r}; pass cache_policy explicitly" + ) + return const(_PTX_CACHE_POLICY[cache_hint], dtype="uint64"), True + return const(0, dtype="uint64"), False + + +def ptx_ld_acquire(addr, return_type, ptx_type, *, scope="gpu", space="global"): + """TVM intrinsic for scalar PTX ``ld.acquire.scope{.ss}.type`` loads. + + This wrapper covers the scalar no-cache-policy/no-vector instances of the + PTX ISA ``ld.acquire`` form. ``scope``, state ``space``, PTX ``type`` and + TVM ``return_type`` are explicit so callers can request either raw-bit or + typed loads. + + Parameters + ---------- + addr : PrimExpr + The memory address to load. + + return_type : str + TVM dtype returned by the load. + + ptx_type : str + PTX type suffix such as ``"b32"``, ``"u64"``, or ``"s32"``. + + scope : str + PTX memory scope: ``"cta"``, ``"cluster"``, ``"gpu"``, or ``"sys"``. + + space : str + PTX state space suffix. + + Returns + ------- + call : PrimExpr + The loaded value. + """ + _choice("scope", scope, _PTX_LD_SCOPE) + _choice("space", space, _PTX_LD_SPACE) + _choice("ptx_type", ptx_type, _PTX_LD_TYPE) + return call_intrin( + return_type, "tirx.ptx_ld_acquire", addr, return_type, ptx_type, scope, space + ) + + +def ptx_ld( + addr, + return_type, + ptx_type, + *, + weak=False, + space="global", + cop="", + cache_hint="", + cache_policy=None, +): + """TVM intrinsic for scalar PTX ``ld{.weak}{.ss}{.cop}{.level::cache_hint}.type``. + + This wrapper covers scalar no-prefetch/no-vector instances of the weak + generic load form. + """ + _choice("space", space, _PTX_LD_SPACE | {"const", "param::entry", "param::func"}) + _choice("cop", cop, _PTX_LD_COP) + _choice("ptx_type", ptx_type, _PTX_LD_TYPE) + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + return call_intrin( + return_type, + "tirx.ptx_ld", + addr, + cache_policy, + return_type, + int(bool(weak)), + space, + cop, + ptx_type, + int(has_cache_policy), + ) + + +def ptx_ld_volatile(addr, return_type, ptx_type, *, space="global"): + """TVM intrinsic for scalar PTX ``ld.volatile{.ss}.type`` loads. + + This wrapper covers scalar no-prefetch/no-vector instances. + """ + _choice("space", space, _PTX_LD_VOLATILE_SPACE) + _choice("ptx_type", ptx_type, _PTX_LD_TYPE) + return call_intrin(return_type, "tirx.ptx_ld_volatile", addr, return_type, ptx_type, space) + + +def ptx_ld_global_acquire(res, addr): + """TVM intrinsic to call the legacy ptx ld.global.acquire helper. + + Parameters + ---------- + res : PrimExpr + The result of the load. + + addr : PrimExpr + The memory address to load. + + Returns + ------- + call : PrimExpr + The call expression. + """ + return call_intrin("", "tirx.ptx_ld_global_acquire", res, addr) + + +def ptx_red_scalar( + address, + value, + *, + sem="", + scope="", + space="global", + op, + ptx_type, + cache_hint="", + cache_policy=None, +): + _choice("scope", scope, _PTX_MEM_SCOPE) + _choice("space", space, _PTX_MEM_SPACE) + _choice("op", op, _PTX_RED_OP) + _choice("ptx_type", ptx_type, _PTX_SCALAR_TYPE) + cache_policy, has_cache_policy = _resolve_cache_policy( + cache_hint, cache_policy, _CP_ASYNC_CACHE_HINT + ) + if sem not in ("", "relaxed", "release"): + raise ValueError(f"Unsupported PTX red sem {sem!r}") + return call_intrin( + "", + "tirx.ptx_red_scalar", + address, + value, + cache_policy, + sem, + scope, + space, + op, + ptx_type, + int(has_cache_policy), + ) + + +def ptx_atom_scalar( + address, + value, + *, + sem="", + scope="", + space="global", + op, + ptx_type, + cache_hint="", + cache_policy=None, +): + _choice("scope", scope, _PTX_MEM_SCOPE) + _choice("space", space, _PTX_MEM_SPACE) + _choice("op", op, _PTX_ATOM_OP) + _choice("ptx_type", ptx_type, _PTX_SCALAR_TYPE) + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + if sem not in ("", "relaxed", "acquire", "release", "acq_rel"): + raise ValueError(f"Unsupported PTX atom sem {sem!r}") + return call_intrin( + _PTX_SCALAR_RETURN_TYPE[ptx_type], + "tirx.ptx_atom_scalar", + address, + value, + cache_policy, + sem, + scope, + space, + op, + ptx_type, + int(has_cache_policy), + ) + + +def ptx_st( + address, + *values, + weak=False, + space="shared", + cop="", + vec="", + ptx_type, + cache_hint="", + cache_policy=None, +): + _choice("space", space, _PTX_MEM_SPACE | {"local", "param::func"}) + _choice("cop", cop, _PTX_ST_COP) + _choice("vec", vec, _PTX_ST_VEC) + _choice("ptx_type", ptx_type, _PTX_SCALAR_TYPE) + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + return call_intrin( + "", + "tirx.ptx_st", + address, + *values, + cache_policy, + int(bool(weak)), + space, + cop, + vec, + ptx_type, + int(has_cache_policy), + ) + + +def ptx_st_bulk(ptr, num_bytes, *, weak=False, space="shared::cta"): + if space not in ("", "shared::cta"): + raise ValueError(f"Unsupported PTX st.bulk space {space!r}") + return call_intrin("", "tirx.ptx_st_bulk", ptr, num_bytes, int(bool(weak)), space) + + +def ptx_prefetch_tensormap(tensormap_addr, space=""): + _choice("space", space, _PTX_PREFETCH_TENSORMAP_SPACE) + return call_intrin("", "tirx.ptx_prefetch_tensormap", tensormap_addr, space) + + +def ptx_mbarrier_test_wait_parity(barrier, phase, *, sem="", scope="", space="shared::cta"): + if sem not in ("", "acquire", "relaxed"): + raise ValueError(f"Unsupported mbarrier.test_wait.parity sem {sem!r}") + if scope not in ("", "cta", "cluster"): + raise ValueError(f"Unsupported mbarrier.test_wait.parity scope {scope!r}") + if bool(sem) != bool(scope): + raise ValueError("mbarrier.test_wait.parity sem and scope must be set together") + if space not in ("shared", "shared::cta"): + raise ValueError(f"Unsupported mbarrier.test_wait.parity space {space!r}") + return call_intrin( + "uint32", "tirx.ptx_mbarrier_test_wait_parity", barrier, phase, sem, scope, space + ) + + +def ptx_cp_async_bulk_g2s_cta( + dst_ptr, + src_ptr, + num_bytes, + mbarrier_ptr, + *, + cache_hint="", + cache_policy=None, + ignore_oob=False, + ignore_bytes_left=0, + ignore_bytes_right=0, +): + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_g2s_cta", + dst_ptr, + src_ptr, + num_bytes, + ignore_bytes_left, + ignore_bytes_right, + mbarrier_ptr, + cache_policy, + int(has_cache_policy), + int(bool(ignore_oob)), + ) + + +def ptx_cp_async_bulk_g2s_cluster( + dst_ptr, + src_ptr, + num_bytes, + mbarrier_ptr, + *, + cache_hint="", + cache_policy=None, + multicast=False, + cta_mask=0, +): + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_g2s_cluster", + dst_ptr, + src_ptr, + num_bytes, + mbarrier_ptr, + cta_mask, + cache_policy, + int(has_cache_policy), + int(bool(multicast)), + ) + + +def ptx_cp_async_bulk_s2s_cluster(dst_ptr, src_ptr, num_bytes, mbarrier): + return call_intrin( + "", "tirx.ptx_cp_async_bulk_s2s_cluster", dst_ptr, src_ptr, num_bytes, mbarrier + ) + + +def ptx_cp_async_bulk_s2g( + dst_ptr, src_ptr, num_bytes, *, cache_hint="", cache_policy=None, cp_mask=False, byte_mask=0 +): + cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) + return call_intrin( + "", + "tirx.ptx_cp_async_bulk_s2g", + dst_ptr, + src_ptr, + num_bytes, + byte_mask, + cache_policy, + int(has_cache_policy), + int(bool(cp_mask)), + ) + + +def ptx_fns_b32(mask, base, offset): + return call_intrin("uint32", "tirx.ptx_fns_b32", mask, base, offset) + + +def ptx_add_rn_f32_bf16(acc, x): + return call_intrin("float32", "tirx.ptx_add_rn_f32_bf16", acc, x) + + +def cuda_uint_as_float(bits): + return call_intrin("float32", "tirx.cuda_uint_as_float", bits) + + +def cuda_float_as_uint(x): + return call_intrin("uint32", "tirx.cuda_float_as_uint", x) + + +def cuda_ballot_sync(mask, pred): + return call_intrin("uint32", "tirx.cuda_ballot_sync", mask, pred) + + +def cuda_ffs_u32(value): + return call_intrin("int32", "tirx.cuda_ffs_u32", value) + + +def cuda_reduce_add_sync_u32(mask, value): + return call_intrin("uint32", "tirx.cuda_reduce_add_sync_u32", mask, value) + + +def cuda_reduce_min_sync_u32(mask, value): + return call_intrin("uint32", "tirx.cuda_reduce_min_sync_u32", mask, value) + + +def cuda_clock64(): + return call_intrin("uint64", "tirx.cuda_clock64") + + +def cuda_make_float2(x, y): + return call_intrin("uint64", "tirx.cuda_make_float2", x, y) + + +def cuda_float2_x(packed): + return call_intrin("float32", "tirx.cuda_float2_x", packed) + + +def cuda_float2_y(packed): + return call_intrin("float32", "tirx.cuda_float2_y", packed) + + +def cuda_fmul2_rn(a, b): + return call_intrin("uint64", "tirx.cuda_fmul2_rn", a, b) + + +def cuda_fadd2_rn(a, b): + return call_intrin("uint64", "tirx.cuda_fadd2_rn", a, b) + + +def cuda_float22bfloat162_rn(v0, v1): + return call_intrin("uint32", "tirx.cuda_float22bfloat162_rn", v0, v1) + + +def cuda_float22bfloat162_rn_from_float2(packed): + return call_intrin("uint32", "tirx.cuda_float22bfloat162_rn_from_float2", packed) + + +def cuda_bfloat1622float2(packed): + return call_intrin("uint64", "tirx.cuda_bfloat1622float2", packed) + + +def cuda_hmin2(a, b): + return call_intrin("uint32", "tirx.cuda_hmin2", a, b) + + +def cuda_hmax2(a, b): + return call_intrin("uint32", "tirx.cuda_hmax2", a, b) + + +def cuda_fp8x4_e4m3_from_float4(x, y, z, w): + return call_intrin("uint32", "tirx.cuda_fp8x4_e4m3_from_float4", x, y, z, w) + + +def ptx_map_shared_rank(ptr, rank): + """TVM intrinsic to call ptx map_shared_rank instruction + + Parameters + ---------- + ptr: PrimExpr + The generic pointer to the local shared memory, handle type + + rank: int + The rank of the distributed shared memory. + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return ptx_mapa(ptr, rank, space="", ptx_type="u64", return_type="uint64") + + +def ptx_mapa(ptr, rank, *, space="", ptx_type="u64", return_type="uint64"): + """TVM intrinsic for PTX ``mapa{.space}.type d, a, b``.""" + if space not in ("", "shared::cluster"): + raise ValueError(f"Unsupported mapa space {space!r}") + if ptx_type not in ("u32", "u64"): + raise ValueError(f"Unsupported mapa type {ptx_type!r}") + return call_intrin(return_type, "tirx.ptx_mapa", ptr, rank, space, ptx_type, return_type) + + +def cuda_atomic_cas(ptr, old_val, new_val): + """TVM intrinsic to call cuda atomic cas instruction + + Parameters + ---------- + ptr: PrimExpr + The pointer to the memory location. + + old_val: PrimExpr + The old value. + + new_val: PrimExpr + The new value. + + Returns + ------- + call : PrimExpr + The call expression. + """ + old_val = tir.convert(old_val) + return call_intrin(old_val.dtype, "tirx.cuda_atomic_cas", ptr, old_val, new_val) + + +######################################################## +# NVSHMEM builtins +######################################################## + + +def nvshmem_my_pe(): + """TVM intrinsic to call nvshmem_my_pe() + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin("int32", "tirx.nvshmem_my_pe") + + +def nvshmem_n_pes(): + """TVM intrinsic to call nvshmem_n_pes() + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin("int32", "tirx.nvshmem_n_pes") + + +def nvshmem_getmem_nbi(dst, src, nelems, pe): + """TVM intrinsic to call nvshmem_getmem_nbi() + + Parameters + ---------- + dst: PrimExpr + The pointer to the symmetric address or host/device address of the data object to be updated. + + src: PrimExpr + The pointer to the symmetric address of the source data object. + + nelems: int + The number of bytes to get per thread. + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ # noqa: E501 + + return call_intrin("", "tirx.nvshmem_getmem_nbi", dst, src, nelems, pe) + + +def nvshmem_putmem_nbi(dst, src, nelems, pe): + """TVM intrinsic to call nvshmem_putmem_nbi() + + Parameters + ---------- + dst: PrimExpr + The pointer to the symmetric address of the destination data object. + + src: PrimExpr + The pointer to the symmetric address or host/device address of the data object to be copied. + + nelems: int + The number of bytes to put per thread. + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin("", "tirx.nvshmem_putmem_nbi", dst, src, nelems, pe) + + +def nvshmem_getmem_nbi_warp(dst, src, nelems, pe): + """TVM intrinsic to call nvshmem_getmem_nbi_warp() + + Parameters + ---------- + dst: PrimExpr + The pointer to the symmetric address or host/device address of the data object to be updated. + + src: PrimExpr + The pointer to the symmetric address of the source data object. + + nelems: int + The number of bytes to get per warp. + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ # noqa: E501 + + return call_intrin("", "tirx.nvshmem_getmem_nbi_warp", dst, src, nelems, pe) + + +def nvshmem_putmem_nbi_warp(dst, src, nelems, pe): + """TVM intrinsic to call nvshmem_putmem_nbi_warp() + + Parameters + ---------- + dst: PrimExpr + The pointer to the symmetric address of the destination data object. + + src: PrimExpr + The pointer to the symmetric address or host/device address of the data object to be copied. + + nelems: int + The number of bytes to put per warp. + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin("", "tirx.nvshmem_putmem_nbi_warp", dst, src, nelems, pe) + + +def nvshmem_getmem_nbi_block(dst, src, nelems, pe): + """TVM intrinsic to call nvshmem_getmem_nbi_block() + + Parameters + ---------- + dst: PrimExpr + The pointer to the symmetric address or host/device address of the data object to be updated. + + src: PrimExpr + The pointer to the symmetric address of the source data object. + + nelems: int + The number of bytes to get per block. + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ # noqa: E501 + + return call_intrin("", "tirx.nvshmem_getmem_nbi_block", dst, src, nelems, pe) + + +def nvshmem_putmem_nbi_block(dst, src, nelems, pe): + """TVM intrinsic to call nvshmem_putmem_nbi_block() + + Parameters + ---------- + dst: PrimExpr + The pointer to the symmetric address of the destination data object. + + src: PrimExpr + The pointer to the symmetric address or host/device address of the data object to be copied. + + nelems: int + The number of bytes to put per block. + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin("", "tirx.nvshmem_putmem_nbi_block", dst, src, nelems, pe) + + +def nvshmem_signal_op(sig_addr, signal, sig_op, pe): + """TVM intrinsic to call nvshmem_signal_op() + + Parameters + ---------- + sig_addr: PrimExpr + The pointer to the symmetric address of the signal word to be updated, must be uint64_t*. + + signal: uint64_t + The value used to update sig_addr. + + sig_op: str + Operation used to update sig_addr with signal, typical sig_op values are "set" and "add". + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ + + _choice("sig_op", sig_op, _NVSHMEM_SIG_OP) + return call_intrin("", "tirx.nvshmem_signal_op", sig_addr, signal, sig_op, pe) + + +def nvshmem_wait_until(ivar, cmp, cmp_value, type="uint64_t"): + """TVM intrinsic to call nvshmem_wait_until() + + Parameters + ---------- + ivar: PrimExpr + The pointer to the symmetric address of a remotely accessible data object, must be TYPE*. + + cmp: str + The compare operator that compares ivar with cmp_value. + + cmp_value: TYPE + The value to be compared with ivar. + + type: str + The TYPE of ivar and cmp_value. + + Returns + ------- + call : PrimExpr + The call expression. + """ + + _choice("cmp", cmp, _NVSHMEM_CMP) + return call_intrin("", "tirx.nvshmem_wait_until", ivar, cmp, cmp_value, type) + + +def nvshmem_quiet(): + """TVM intrinsic to call nvshmem_quiet() + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin("", "tirx.nvshmem_quiet") + + +def nvshmem_putmem_signal_nbi(dst, src, nelems, sig_addr, signal, sig_op, pe): + """TVM intrinsic to call nvshmem_putmem_signal_nbi() + + Parameters + ---------- + dst: PrimExpr + The pointer to the symmetric address of the data object to be updated on the remote PE. + + src: PrimExpr + The pointer to the symmetric address or host/device address of data object containing the data to be copied. + + nelems: int + The number of bytes to put per thread. + + sig_addr: PrimExpr + The pointer to the symmetric address of the signal data object to be updated on the remote PE as a signal, must be uint64_t*. + + signal: uint64_t + The unsigned 64-bit value that is used for updating the remote sig_addr signal data object. + + sig_op: str + Signal operator that represents the type of update to be performed on the remote sig_addr signal data object. + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ # noqa: E501 + + return call_intrin( + "", "tirx.nvshmem_putmem_signal_nbi", dst, src, nelems, sig_addr, signal, sig_op, pe + ) + + +def nvshmem_putmem_signal_nbi_warp(dst, src, nelems, sig_addr, signal, sig_op, pe): + """TVM intrinsic to call nvshmem_putmem_signal_nbi_warp() + + Parameters + ---------- + dst: PrimExpr + The pointer to the symmetric address of the data object to be updated on the remote PE. + + src: PrimExpr + The pointer to the symmetric address or host/device address of data object containing the data to be copied. + + nelems: int + The number of bytes to put per warp. + + sig_addr: PrimExpr + The pointer to the symmetric address of the signal data object to be updated on the remote PE as a signal, must be uint64_t*. + + signal: uint64_t + The unsigned 64-bit value that is used for updating the remote sig_addr signal data object. + + sig_op: str + Signal operator that represents the type of update to be performed on the remote sig_addr signal data object. + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ # noqa: E501 + + return call_intrin( + "", "tirx.nvshmem_putmem_signal_nbi_warp", dst, src, nelems, sig_addr, signal, sig_op, pe + ) + + +def nvshmem_putmem_signal_nbi_block(dst, src, nelems, sig_addr, signal, sig_op, pe): + """TVM intrinsic to call nvshmem_putmem_signal_nbi_block() + + Parameters + ---------- + dst: PrimExpr + The pointer to the symmetric address of the data object to be updated on the remote PE. + + src: PrimExpr + The pointer to the symmetric address or host/device address of data object containing the data to be copied. + + nelems: int + The number of bytes to put per block. + + sig_addr: PrimExpr + The pointer to the symmetric address of the signal data object to be updated on the remote PE as a signal, must be uint64_t*. + + signal: uint64_t + The unsigned 64-bit value that is used for updating the remote sig_addr signal data object. + + sig_op: str + Signal operator that represents the type of update to be performed on the remote sig_addr signal data object. + + pe: int + The PE number of the remote PE. + + Returns + ------- + call : PrimExpr + The call expression. + """ # noqa: E501 + + return call_intrin( + "", "tirx.nvshmem_putmem_signal_nbi_block", dst, src, nelems, sig_addr, signal, sig_op, pe + ) + + +def nvshmem_fence(): + """TVM intrinsic to call nvshmem_fence() + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin("", "tirx.nvshmem_fence") + + +def nvshmem_barrier_all(): + """TVM intrinsic to call nvshmem_barrier_all() + + Returns + ------- + call : PrimExpr + The call expression. + """ + + return call_intrin("", "tirx.nvshmem_barrier_all") diff --git a/python/tvm/tirx/backend/adreno/__init__.py b/python/tvm/backend/cuda/operator/__init__.py similarity index 88% rename from python/tvm/tirx/backend/adreno/__init__.py rename to python/tvm/backend/cuda/operator/__init__.py index 06f85d091d5a..744c9bafc98e 100644 --- a/python/tvm/tirx/backend/adreno/__init__.py +++ b/python/tvm/backend/cuda/operator/__init__.py @@ -14,4 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""The TIR Adreno backend passes""" +"""CUDA backend operator registrations and helpers.""" + +__all__ = ["intrinsics", "tile_primitive"] diff --git a/python/tvm/tirx/operator/intrinsics/cuda/__init__.py b/python/tvm/backend/cuda/operator/intrinsics/__init__.py similarity index 100% rename from python/tvm/tirx/operator/intrinsics/cuda/__init__.py rename to python/tvm/backend/cuda/operator/intrinsics/__init__.py diff --git a/python/tvm/tirx/operator/intrinsics/_schema.py b/python/tvm/backend/cuda/operator/intrinsics/_schema.py similarity index 97% rename from python/tvm/tirx/operator/intrinsics/_schema.py rename to python/tvm/backend/cuda/operator/intrinsics/_schema.py index 57e409e9555c..67c8b8cebd31 100644 --- a/python/tvm/tirx/operator/intrinsics/_schema.py +++ b/python/tvm/backend/cuda/operator/intrinsics/_schema.py @@ -43,8 +43,8 @@ from collections.abc import Callable -from tvm.tirx.op import cuda_func_call -from tvm.tirx.operator.intrinsics.cuda.registry import register_codegen +from tvm.backend.cuda.op import cuda_func_call +from tvm.backend.cuda.operator.intrinsics.registry import register_codegen # C primitive type → TVM dtype string. Used when the caller specifies a # non-void ``return_type`` but no explicit ``tvm_return_type`` — the helper diff --git a/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py b/python/tvm/backend/cuda/operator/intrinsics/cp_async.py similarity index 99% rename from python/tvm/tirx/operator/intrinsics/cuda/cp_async.py rename to python/tvm/backend/cuda/operator/intrinsics/cp_async.py index 2eeb0821d666..3e6bc015e81f 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/cp_async.py +++ b/python/tvm/backend/cuda/operator/intrinsics/cp_async.py @@ -27,9 +27,9 @@ """ import tvm -from tvm.tirx.op import cuda_func_call +from tvm.backend.cuda.op import cuda_func_call -from .._schema import device_intrinsic +from ._schema import device_intrinsic from .registry import CODEGEN_REGISTRY, register_codegen from .utils import parse_str @@ -200,7 +200,8 @@ def codegen_ptx_cp_async(*args): Accepts three call shapes (sorted by arity): * 5 args ``(dst_ptr, dst_offset, src_ptr, src_offset, cp_size)`` — - the legacy form emitted by ``s_tir/transform/InjectPTXAsyncCopy``. + the legacy form emitted by + ``tvm.backend.cuda.transform.InjectPTXAsyncCopy``. Offsets are folded into the pointers via ``tvm_access_ptr`` (in bytes; offsets are pre-scaled by the pass) and the call is forwarded with default cache / predicate / fill_mode. diff --git a/python/tvm/tirx/operator/intrinsics/cuda/header.py b/python/tvm/backend/cuda/operator/intrinsics/header.py similarity index 100% rename from python/tvm/tirx/operator/intrinsics/cuda/header.py rename to python/tvm/backend/cuda/operator/intrinsics/header.py diff --git a/python/tvm/tirx/operator/intrinsics/cuda/math.py b/python/tvm/backend/cuda/operator/intrinsics/math.py similarity index 99% rename from python/tvm/tirx/operator/intrinsics/cuda/math.py rename to python/tvm/backend/cuda/operator/intrinsics/math.py index 37cd57d8714d..d93caa8145f8 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/math.py +++ b/python/tvm/backend/cuda/operator/intrinsics/math.py @@ -26,9 +26,9 @@ * warp / CTA reductions (templated butterfly shuffle-XOR). """ -from tvm.tirx.op import cuda_func_call +from tvm.backend.cuda.op import cuda_func_call -from .._schema import device_intrinsic +from ._schema import device_intrinsic from .registry import register_codegen from .utils import parse_str, validate_power_of_two_range diff --git a/python/tvm/tirx/operator/intrinsics/cuda/memory.py b/python/tvm/backend/cuda/operator/intrinsics/memory.py similarity index 99% rename from python/tvm/tirx/operator/intrinsics/cuda/memory.py rename to python/tvm/backend/cuda/operator/intrinsics/memory.py index 152e1434ca95..b91a7945f248 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/memory.py +++ b/python/tvm/backend/cuda/operator/intrinsics/memory.py @@ -33,9 +33,9 @@ """ from tvm import DataType -from tvm.tirx.op import cuda_func_call +from tvm.backend.cuda.op import cuda_func_call -from .._schema import device_intrinsic +from ._schema import device_intrinsic from .registry import CODEGEN_REGISTRY, register_codegen from .utils import parse_str diff --git a/python/tvm/tirx/operator/intrinsics/cuda/misc.py b/python/tvm/backend/cuda/operator/intrinsics/misc.py similarity index 99% rename from python/tvm/tirx/operator/intrinsics/cuda/misc.py rename to python/tvm/backend/cuda/operator/intrinsics/misc.py index 0cca2cd19456..2ee21ba57d06 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/misc.py +++ b/python/tvm/backend/cuda/operator/intrinsics/misc.py @@ -31,9 +31,9 @@ import json import tvm -from tvm.tirx.op import cuda_func_call +from tvm.backend.cuda.op import cuda_func_call -from .._schema import device_intrinsic +from ._schema import device_intrinsic from .registry import CODEGEN_REGISTRY, register_codegen from .utils import parse_str diff --git a/python/tvm/tirx/operator/intrinsics/cuda/mma.py b/python/tvm/backend/cuda/operator/intrinsics/mma.py similarity index 99% rename from python/tvm/tirx/operator/intrinsics/cuda/mma.py rename to python/tvm/backend/cuda/operator/intrinsics/mma.py index 7c5998736850..5d6bccdcb4b1 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/mma.py +++ b/python/tvm/backend/cuda/operator/intrinsics/mma.py @@ -32,7 +32,7 @@ from tvm import DataType -from .._schema import device_intrinsic +from ._schema import device_intrinsic from .registry import CODEGEN_REGISTRY, register_codegen from .types import PTXDataType from .utils import parse_str diff --git a/python/tvm/tirx/operator/intrinsics/cuda/nvshmem.py b/python/tvm/backend/cuda/operator/intrinsics/nvshmem.py similarity index 99% rename from python/tvm/tirx/operator/intrinsics/cuda/nvshmem.py rename to python/tvm/backend/cuda/operator/intrinsics/nvshmem.py index af7fa4c9905e..dc922a0e12d8 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/nvshmem.py +++ b/python/tvm/backend/cuda/operator/intrinsics/nvshmem.py @@ -17,7 +17,7 @@ # pylint: disable=redefined-builtin, invalid-name """NVSHMEM intrinsics. Each backend call is one ``device_intrinsic(...)``.""" -from .._schema import device_intrinsic +from ._schema import device_intrinsic from .registry import CODEGEN_REGISTRY, register_codegen _NVSHMEM = ("nvshmem",) diff --git a/python/tvm/tirx/operator/intrinsics/cuda/registry.py b/python/tvm/backend/cuda/operator/intrinsics/registry.py similarity index 100% rename from python/tvm/tirx/operator/intrinsics/cuda/registry.py rename to python/tvm/backend/cuda/operator/intrinsics/registry.py diff --git a/python/tvm/tirx/operator/intrinsics/cuda/sync.py b/python/tvm/backend/cuda/operator/intrinsics/sync.py similarity index 99% rename from python/tvm/tirx/operator/intrinsics/cuda/sync.py rename to python/tvm/backend/cuda/operator/intrinsics/sync.py index 4386336660a6..0fcdb31a46f1 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/sync.py +++ b/python/tvm/backend/cuda/operator/intrinsics/sync.py @@ -32,8 +32,14 @@ * warpgroup sync (``bar.sync``) """ -from .._common import CLUSTER_BARRIER_SEM, FENCE_PROXY_ASYNC_SPACE, FENCE_SCOPE, FENCE_SEM -from .._schema import device_intrinsic +from tvm.tirx.operator.intrinsics._common import ( + CLUSTER_BARRIER_SEM, + FENCE_PROXY_ASYNC_SPACE, + FENCE_SCOPE, + FENCE_SEM, +) + +from ._schema import device_intrinsic from .registry import CODEGEN_REGISTRY, register_codegen from .utils import parse_str diff --git a/python/tvm/tirx/operator/intrinsics/cuda/tcgen05.py b/python/tvm/backend/cuda/operator/intrinsics/tcgen05.py similarity index 99% rename from python/tvm/tirx/operator/intrinsics/cuda/tcgen05.py rename to python/tvm/backend/cuda/operator/intrinsics/tcgen05.py index ef30a85d0fd2..029a6bf87ce8 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/tcgen05.py +++ b/python/tvm/backend/cuda/operator/intrinsics/tcgen05.py @@ -25,7 +25,7 @@ import tvm -from .._schema import device_intrinsic +from ._schema import device_intrinsic from .registry import CODEGEN_REGISTRY, register_codegen from .types import PTXDataType from .utils import parse_str, validate_cta_group, validate_power_of_two_range diff --git a/python/tvm/tirx/operator/intrinsics/cuda/types.py b/python/tvm/backend/cuda/operator/intrinsics/types.py similarity index 100% rename from python/tvm/tirx/operator/intrinsics/cuda/types.py rename to python/tvm/backend/cuda/operator/intrinsics/types.py diff --git a/python/tvm/tirx/operator/intrinsics/cuda/utils.py b/python/tvm/backend/cuda/operator/intrinsics/utils.py similarity index 100% rename from python/tvm/tirx/operator/intrinsics/cuda/utils.py rename to python/tvm/backend/cuda/operator/intrinsics/utils.py diff --git a/python/tvm/tirx/operator/intrinsics/cuda/wgmma.py b/python/tvm/backend/cuda/operator/intrinsics/wgmma.py similarity index 99% rename from python/tvm/tirx/operator/intrinsics/cuda/wgmma.py rename to python/tvm/backend/cuda/operator/intrinsics/wgmma.py index 87666db58183..1ec52a5e4edb 100644 --- a/python/tvm/tirx/operator/intrinsics/cuda/wgmma.py +++ b/python/tvm/backend/cuda/operator/intrinsics/wgmma.py @@ -25,7 +25,7 @@ import tvm -from .._schema import device_intrinsic +from ._schema import device_intrinsic from .registry import CODEGEN_REGISTRY, register_codegen from .types import PTXDataType from .utils import parse_str diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/__init__.py similarity index 89% rename from python/tvm/tirx/operator/tile_primitive/cuda/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/__init__.py index cea930c362d1..c4f37b02536a 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/__init__.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/__init__.py @@ -16,5 +16,9 @@ # under the License. from .copy import * +from .copy_async import * from .elementwise import * +from .gemm import * +from .gemm_async import * +from .permute_layout import * from .reduction import * diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/common.py b/python/tvm/backend/cuda/operator/tile_primitive/common.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/common.py rename to python/tvm/backend/cuda/operator/tile_primitive/common.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/_common.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/_common.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy/_common.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy/_common.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/_swizzle_iter.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/_swizzle_iter.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy/_swizzle_iter.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy/_swizzle_iter.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/fallback.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/fallback.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy/fallback.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy/fallback.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/gmem_smem.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/gmem_smem.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy/gmem_smem.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy/gmem_smem.py index aee24c62e4f5..727f7a18df77 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/gmem_smem.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy/gmem_smem.py @@ -87,7 +87,7 @@ def _divides_thread_cnt( def _is_gmem_smem(op_call: TilePrimitiveCall, sctx: DispatchContext) -> tuple[bool, str | None]: - if not sctx.is_cuda(): + if not sctx.is_target("cuda"): return False, "non-cuda target" if sctx.scope_kind not in ("thread", "warp", "warpgroup", "cta"): return False, f"unsupported exec_scope {sctx.scope_kind}" diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/ld_stmatrix.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy/ld_stmatrix.py index 75b0c9015b55..96a4545c2b92 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy/ld_stmatrix.py @@ -61,7 +61,7 @@ def key(p): def _is_ldstmatrix(op_call: TilePrimitiveCall, sctx: DispatchContext) -> tuple[bool, str | None]: - if not sctx.is_cuda(): + if not sctx.is_target("cuda"): return False, "non-cuda target" if sctx.scope_kind not in ("warp", "warpgroup", "cta"): return False, f"unsupported exec_scope {sctx.scope_kind} (need warp, warpgroup, or cta)" diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/reg.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/reg.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy/reg.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy/reg.py index 5ea8d40e9382..3f30a83c038f 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/reg.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy/reg.py @@ -177,7 +177,7 @@ def _s_side_slice_ok(op_call: TilePrimitiveCall) -> tuple[bool, str | None]: def _is_reg_copy(op_call: TilePrimitiveCall, sctx: DispatchContext) -> tuple[bool, str | None]: - if not sctx.is_cuda(): + if not sctx.is_target("cuda"): return False, "non-cuda target" if sctx.scope_kind not in ("thread", "warp", "warpgroup", "cta"): return False, f"unsupported exec_scope {sctx.scope_kind}" diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/utils.py b/python/tvm/backend/cuda/operator/tile_primitive/copy/utils.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy/utils.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy/utils.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy_async/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy_async/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/dsmem.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/dsmem.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy_async/dsmem.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy_async/dsmem.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/ldgsts.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/ldgsts.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy_async/ldgsts.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy_async/ldgsts.py index 8c86f75ac8c4..e57228db6b2e 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/ldgsts.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/ldgsts.py @@ -89,7 +89,7 @@ def _divides_thread_cnt_ldgsts( def _is_ldgsts(op_call: TilePrimitiveCall, sctx: DispatchContext) -> tuple[bool, str | None]: - if not sctx.is_cuda(): + if not sctx.is_target("cuda"): return False, "non-cuda target" if sctx.scope_kind not in ("thread", "warp", "warpgroup", "cta"): return False, f"unsupported exec_scope {sctx.scope_kind}" diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_cp.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_cp.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_cp.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_cp.py index 3a9d81947804..5cccc307f573 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_cp.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_cp.py @@ -408,7 +408,7 @@ def copy_smem_tmem_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> Pr desc_buf = _get_or_create_desc(sctx, s_buf, LDO_field, SDO_field, sw) t_addr = t_buf.allocated_addr - from tvm.tirx.operator.tile_primitive.cuda.common import smem_desc_add_16B_offset + from tvm.backend.cuda.operator.tile_primitive.common import smem_desc_add_16B_offset # Flatten the N-D middle iteration into a single T.unroll. Each iteration's # per-dim index is (flat // stride) % extent, summed into the t/s offsets. diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tma.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tma.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy_async/tma.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/utils.py b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/utils.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/copy_async/utils.py rename to python/tvm/backend/cuda/operator/tile_primitive/copy_async/utils.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/__init__.py similarity index 95% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/__init__.py index 872cad0867ee..576dc437b91d 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/__init__.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/__init__.py @@ -30,7 +30,7 @@ from .register import * # Suppress submodule-attribute leakage. Without an explicit ``__all__`` here, -# ``from tvm.tirx.operator.tile_primitive.cuda.elementwise import *`` (run by +# ``from tvm.backend.cuda.operator.tile_primitive.elementwise import *`` (run by # tile_primitive/__init__.py) re-exports the implicit submodule attributes # (``ops``, ``reg``, ``smem``, ``vec_emit``) — and ``ops`` in particular # shadows the top-level ``tile_primitive/ops.py`` (BinaryReduce / UnaryReduce diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/_common.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/_common.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/_common.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/_common.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/binary.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/binary.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/binary.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/binary.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/cast.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/cast.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/cast.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/cast.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/fma.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/fma.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/fma.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/fma.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/unary.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/unary.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/ops/unary.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/ops/unary.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/reg.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/reg.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py index 063e4c397903..eddf9f3d8eac 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/reg.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py @@ -110,7 +110,7 @@ def is_reg_ewise(spec): """Predicate factory: dispatch accepted iff all operands in ``local`` scope.""" def check(op_call: TilePrimitiveCall, sctx: DispatchContext) -> tuple[bool, str | None]: - if not sctx.is_cuda: + if not sctx.is_target("cuda"): return False, "non-cuda target" if sctx.scope_kind not in ("thread", "warp", "warpgroup", "cta"): return False, f"unsupported scope {sctx.scope_kind}" diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/register.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/register.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/register.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/register.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/smem.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/smem.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py index 3ac7405d7734..c2cb5469be10 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/smem.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/smem.py @@ -60,7 +60,7 @@ def is_smem_ewise(spec): """Predicate factory: dispatch accepted iff all operands in ``shared*``.""" def check(op_call: TilePrimitiveCall, sctx: DispatchContext) -> tuple[bool, str | None]: - if not sctx.is_cuda: + if not sctx.is_target("cuda"): return False, "non-cuda target" if sctx.scope_kind not in ("thread", "warp", "warpgroup", "cta"): return False, f"unsupported scope {sctx.scope_kind}" diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/binary_f32x2.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/binary_f32x2.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/binary_f32x2.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/binary_f32x2.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/cast_vec2.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/cast_vec2.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/cast_vec2.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/cast_vec2.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/fma_f32x2.py b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/fma_f32x2.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/elementwise/vec_emit/fma_f32x2.py rename to python/tvm/backend/cuda/operator/tile_primitive/elementwise/vec_emit/fma_f32x2.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/exec_scope_utils.py b/python/tvm/backend/cuda/operator/tile_primitive/exec_scope_utils.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/exec_scope_utils.py rename to python/tvm/backend/cuda/operator/tile_primitive/exec_scope_utils.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/gemm/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/gemm/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/gemm/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm/mma_m16n8k_.py b/python/tvm/backend/cuda/operator/tile_primitive/gemm/mma_m16n8k_.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/gemm/mma_m16n8k_.py rename to python/tvm/backend/cuda/operator/tile_primitive/gemm/mma_m16n8k_.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/gemm_async/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py b/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py similarity index 83% rename from python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py rename to python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py index c19bcda622e8..0afa042b5e19 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_async/tcgen05.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/gemm_async/tcgen05.py @@ -46,7 +46,12 @@ from ..common import get_st_extent, smem_desc_add_16B_offset from ..exec_scope_utils import single_thread -from ..tma_utils import SwizzleMode, mma_atom_layout, mma_atom_shape +from ..tma_utils import ( + SwizzleMode, + get_swizzle_mode_from_layout, + mma_atom_layout, + mma_atom_shape, +) # Mirror of ``format_map`` in the dense ``encode_instr_descriptor`` codegen # (``python/tvm/tirx/operator/intrinsics/cuda/tcgen05.py``). Used to fold the @@ -516,80 +521,152 @@ def compute_canonical_params(buf, buf_region, dtype, is_transposed): Tuple of (swizzle_mode, ldo, sdo, is_mn_major). """ region = list(buf_region.region) - slice_layout = buf.layout.slice(buf.shape, region) + + def _match(slice_layout, shape_2d): + """Match ``slice_layout`` (of ``shape_2d``) against the swizzle atoms. + + Returns ``(swizzle_mode, ldo, sdo, is_mn_major)`` or ``None``. + """ + + def _try_atom(atom, atom_shape): + if any(s % a != 0 for s, a in zip(shape_2d, atom_shape)): + return None + atom_size = functools.reduce(operator.mul, atom_shape, 1) + tiler = atom.is_tile_inner(slice_layout, shape_2d, atom_shape) + if tiler is None: + return None + tiler_shape = [s // a for s, a in zip(shape_2d, atom_shape)] + tiler_grouped, seps = tiler.canonicalize().group(tiler_shape) + elem_per_128b = 128 // tvm.DataType(dtype).bits + ldo = (tiler_grouped.shard[-1].stride * atom_size) // elem_per_128b + sdo = (tiler_grouped.shard[-2].stride * atom_size) // elem_per_128b + return mode, ldo, sdo + + for mode in ( + SwizzleMode.SWIZZLE_128B_ATOM, + SwizzleMode.SWIZZLE_64B_ATOM, + SwizzleMode.SWIZZLE_32B_ATOM, + ): + swizzle_atom = mma_atom_layout(dtype, mode) + base_shape = mma_atom_shape(dtype, mode) # [8, T*s] + swapped_shape = [base_shape[1], base_shape[0]] # [T*s, 8] + + # MN-major atom: compose SwizzleLayout with stride-reversed TileLayout + # so the first dim (T*s) is contiguous instead of the second. + # Needed when the penultimate dim is physically contiguous. + mn_tile = TileLayout(S[tuple(swapped_shape) : (1, swapped_shape[0])]) + mn_atom = ComposeLayout(swizzle_atom, mn_tile) + + # Determine K-major vs MN-major based on which dim is contiguous. + # K-major: K dim contiguous (last dim for [MN,K], first dim for [K,MN]) + # MN-major: MN dim contiguous + # + # The plain swizzle_atom has last dim contiguous. + # The mn_atom has first dim contiguous. + # + # For non-transposed [MN, K]: K is last dim + # - K-major = swizzle_atom with [8, T*s] (K contiguous in last dim) + # - MN-major = mn_atom with [T*s, 8] (MN contiguous in first dim) + # For transposed [K, MN]: MN is last dim + # - K-major = mn_atom with [T*s, 8] (K contiguous in first dim) + # - MN-major = swizzle_atom with [8, T*s] (MN contiguous in last dim) + if is_transposed: + candidates = [ + (False, mn_atom, swapped_shape), # K-major: K in first dim + (True, swizzle_atom, base_shape), # MN-major: MN in last dim + ] + else: + candidates = [ + (False, swizzle_atom, base_shape), # K-major: K in last dim + (True, mn_atom, swapped_shape), # MN-major: MN in first dim + ] + + for is_mn_major, atom, atom_shape in candidates: + result = _try_atom(atom, atom_shape) + if result is not None: + sw, ldo_val, sdo_val = result + # shard[-1] = last-dim groups, shard[-2] = first-dim groups. + # LBO strides MN-groups for MN-major, K-groups for K-major. + # Non-transposed [MN,K]: last=K, first=MN → swap for MN-major + # Transposed [K,MN]: last=MN, first=K → swap for K-major + if is_mn_major != is_transposed: + ldo_val, sdo_val = sdo_val, ldo_val + return sw, ldo_val, sdo_val, is_mn_major + return None + + # The MMA SMEM descriptor describes the buffer's *physical* swizzle, which + # spans whole atoms and is a property of the buffer -- not of any sub-tile. + # Round the contiguous (innermost) axis up to a swizzle-atom multiple before + # matching, so the descriptor is always derived from a whole number of atoms: + # * for an already atom-aligned region (full-K, stride-axis P@V slices, + # non-swizzled buffers) this is a no-op -- desc_region == region, so the + # derived (swizzle, ldo, sdo, is_mn_major) are identical to matching the + # region directly; + # * for a sub-atom contiguous slice (fine K-major split-K, e.g. + # Asmem[..., :, lo:hi]) it rounds up to the smallest atom count covering + # the slice and describes it from the buffer origin. + # Either way the actual [lo:hi] range is addressed by K_iters + the per-MMA + # -tile 16B offset (from the *sliced* region in _a_operand / _b_desc_val), + # not by this descriptor. Verified hardware-correct for every MMA_K-aligned + # contiguous slice -- this is what enables fine K-major split-K. + cax = len(region) - 1 # innermost (contiguous) axis + elem_per_16b = 128 // DataType(dtype).bits + phys_mode = get_swizzle_mode_from_layout(buf.layout) + desc_region = list(region) + if phys_mode in ( + SwizzleMode.SWIZZLE_128B_ATOM, + SwizzleMode.SWIZZLE_64B_ATOM, + SwizzleMode.SWIZZLE_32B_ATOM, + ): + atom_inner = mma_atom_shape(dtype, phys_mode)[-1] + contig = int(region[cax].extent) + rounded = ((contig + atom_inner - 1) // atom_inner) * atom_inner + if rounded != contig: + # Sub-atom contiguous slice. The per-tile 16B offset that locates the + # slice start must be exact, so the start has to sit on a 16B + # (= elem_per_16b element) boundary; otherwise reject rather than + # silently mis-address. + if not analyzer.can_prove_equal( + tvm.tirx.floormod(region[cax].min, elem_per_16b), 0 + ): + raise ValueError( + f"gemm_async: contiguous-axis slice start {region[cax].min} is not " + f"16B-aligned (={elem_per_16b} elements, dtype {dtype}). A sub-atom " + f"contiguous slice must start on a 16B boundary; otherwise keep that " + f"axis full and split on a stride/outer axis instead (lay the operand " + f"out MN-major)." + ) + desc_region[cax] = tvm.ir.Range.from_min_extent(0, rounded) + + slice_layout = buf.layout.slice(buf.shape, desc_region) # Strip unit dims to get the 2D matrix shape. - shape_2d = [int(r.extent) for r in region if int(r.extent) != 1] + shape_2d = [int(r.extent) for r in desc_region if int(r.extent) != 1] assert len(shape_2d) == 2, ( - f"Expected exactly 2 non-unit dims in region {[int(r.extent) for r in region]}" + f"Expected exactly 2 non-unit dims in region {[int(r.extent) for r in desc_region]}" ) - - def _try_atom(atom, atom_shape): - if any(s % a != 0 for s, a in zip(shape_2d, atom_shape)): - return None - atom_size = functools.reduce(operator.mul, atom_shape, 1) - tiler = atom.is_tile_inner(slice_layout, shape_2d, atom_shape) - if tiler is None: - return None - tiler_shape = [s // a for s, a in zip(shape_2d, atom_shape)] - tiler_grouped, seps = tiler.canonicalize().group(tiler_shape) - elem_per_128b = 128 // tvm.DataType(dtype).bits - ldo = (tiler_grouped.shard[-1].stride * atom_size) // elem_per_128b - sdo = (tiler_grouped.shard[-2].stride * atom_size) // elem_per_128b - return mode, ldo, sdo - - for mode in ( + result = _match(slice_layout, shape_2d) + if result is not None: + return result + + # Genuinely unsupported: the layout doesn't tile any swizzle atom even at + # full atom width. Actionable error (the old generic "no swizzle mode" + # message read like a hard limit and was mistaken for one). + hint = "" + if phys_mode in ( SwizzleMode.SWIZZLE_128B_ATOM, SwizzleMode.SWIZZLE_64B_ATOM, SwizzleMode.SWIZZLE_32B_ATOM, ): - swizzle_atom = mma_atom_layout(dtype, mode) - base_shape = mma_atom_shape(dtype, mode) # [8, T*s] - swapped_shape = [base_shape[1], base_shape[0]] # [T*s, 8] - - # MN-major atom: compose SwizzleLayout with stride-reversed TileLayout - # so the first dim (T*s) is contiguous instead of the second. - # Needed when the penultimate dim is physically contiguous. - mn_tile = TileLayout(S[tuple(swapped_shape) : (1, swapped_shape[0])]) - mn_atom = ComposeLayout(swizzle_atom, mn_tile) - - # Determine K-major vs MN-major based on which dim is contiguous. - # K-major: K dim contiguous (last dim for [MN,K], first dim for [K,MN]) - # MN-major: MN dim contiguous - # - # The plain swizzle_atom has last dim contiguous. - # The mn_atom has first dim contiguous. - # - # For non-transposed [MN, K]: K is last dim - # - K-major = swizzle_atom with [8, T*s] (K contiguous in last dim) - # - MN-major = mn_atom with [T*s, 8] (MN contiguous in first dim) - # For transposed [K, MN]: MN is last dim - # - K-major = mn_atom with [T*s, 8] (K contiguous in first dim) - # - MN-major = swizzle_atom with [8, T*s] (MN contiguous in last dim) - if is_transposed: - candidates = [ - (False, mn_atom, swapped_shape), # K-major: K in first dim - (True, swizzle_atom, base_shape), # MN-major: MN in last dim - ] - else: - candidates = [ - (False, swizzle_atom, base_shape), # K-major: K in last dim - (True, mn_atom, swapped_shape), # MN-major: MN in first dim - ] - - for is_mn_major, atom, atom_shape in candidates: - result = _try_atom(atom, atom_shape) - if result is not None: - sw, ldo_val, sdo_val = result - # shard[-1] = last-dim groups, shard[-2] = first-dim groups. - # LBO strides MN-groups for MN-major, K-groups for K-major. - # Non-transposed [MN,K]: last=K, first=MN → swap for MN-major - # Transposed [K,MN]: last=MN, first=K → swap for K-major - if is_mn_major != is_transposed: - ldo_val, sdo_val = sdo_val, ldo_val - return sw, ldo_val, sdo_val, is_mn_major - + atom_inner = mma_atom_shape(dtype, phys_mode)[-1] + atom_bytes = atom_inner * DataType(dtype).bits // 8 + hint = ( + f" The buffer is physically {phys_mode.name}-swizzled (contiguous atom " + f"= {atom_inner} elements / {atom_bytes} B); the region's layout does not " + f"tile it." + ) raise ValueError( - f"No compatible swizzle mode found for dtype {dtype} with region shape {shape_2d}" + f"gemm_async: no MMA SMEM descriptor matches region shape {shape_2d} " + f"for dtype {dtype}.{hint}" ) if a_is_tmem: diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/gemm_utils.py b/python/tvm/backend/cuda/operator/tile_primitive/gemm_utils.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/gemm_utils.py rename to python/tvm/backend/cuda/operator/tile_primitive/gemm_utils.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/layout_utils.py b/python/tvm/backend/cuda/operator/tile_primitive/layout_utils.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/layout_utils.py rename to python/tvm/backend/cuda/operator/tile_primitive/layout_utils.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/permute_layout/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/permute_layout/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py b/python/tvm/backend/cuda/operator/tile_primitive/permute_layout/warp_xor_swizzle.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/permute_layout/warp_xor_swizzle.py rename to python/tvm/backend/cuda/operator/tile_primitive/permute_layout/warp_xor_swizzle.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/__init__.py b/python/tvm/backend/cuda/operator/tile_primitive/reduction/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/reduction/__init__.py rename to python/tvm/backend/cuda/operator/tile_primitive/reduction/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/local.py b/python/tvm/backend/cuda/operator/tile_primitive/reduction/local.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/reduction/local.py rename to python/tvm/backend/cuda/operator/tile_primitive/reduction/local.py index b05618f15371..3316aaf2a1fb 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/local.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/reduction/local.py @@ -66,10 +66,10 @@ from tvm.tirx import BufferRegion, PrimFunc from tvm.tirx.layout import TileLayout, laneid from tvm.tirx.operator.tile_primitive import DispatchContext, fail +from tvm.tirx.operator.tile_primitive.common import ReduceOpType from tvm.tirx.operator.tile_primitive.dispatcher import predicate, register_dispatch from tvm.tirx.stmt import TilePrimitiveCall -from ...common import ReduceOpType from ..common import get_indices, get_st_extent from ..layout_utils import get_local_region, get_sublayout_from_region from .utils import ( @@ -156,7 +156,7 @@ def validate_reduction_local( dst_br, src_br = op.output, op.input dst, src = dst_br.buffer, src_br.buffer - if not (src.scope() == "local" and dst.scope() == "local" and sctx.is_cuda()): + if not (src.scope() == "local" and dst.scope() == "local" and sctx.is_target("cuda")): return False, "expected local scope and CUDA target" if src.dtype != dst.dtype: return False, f"dtype mismatch: src={src.dtype} dst={dst.dtype}" diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/shared.py b/python/tvm/backend/cuda/operator/tile_primitive/reduction/shared.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/reduction/shared.py rename to python/tvm/backend/cuda/operator/tile_primitive/reduction/shared.py index 8bee09ecc3f0..38986df86bb0 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/shared.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/reduction/shared.py @@ -61,10 +61,10 @@ from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext, fail +from tvm.tirx.operator.tile_primitive.common import ReduceOpType from tvm.tirx.operator.tile_primitive.dispatcher import predicate, register_dispatch from tvm.tirx.stmt import TilePrimitiveCall -from ...common import ReduceOpType from ..common import get_indices, get_st_extent, next_power_of_2 from .utils import ( _analyze_axes, diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/sm100_packed.py b/python/tvm/backend/cuda/operator/tile_primitive/reduction/sm100_packed.py similarity index 99% rename from python/tvm/tirx/operator/tile_primitive/cuda/reduction/sm100_packed.py rename to python/tvm/backend/cuda/operator/tile_primitive/reduction/sm100_packed.py index 5b8540ecac25..fd506fa14bf9 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/sm100_packed.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/reduction/sm100_packed.py @@ -48,10 +48,10 @@ from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext +from tvm.tirx.operator.tile_primitive.common import ReduceOpType from tvm.tirx.operator.tile_primitive.dispatcher import predicate, register_dispatch from tvm.tirx.stmt import TilePrimitiveCall -from ...common import ReduceOpType from ..common import sm_version_ok from ..exec_scope_utils import exec_scope_ok from .utils import ( diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/utils.py b/python/tvm/backend/cuda/operator/tile_primitive/reduction/utils.py similarity index 97% rename from python/tvm/tirx/operator/tile_primitive/cuda/reduction/utils.py rename to python/tvm/backend/cuda/operator/tile_primitive/reduction/utils.py index b53b5d181068..6248cddea117 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/reduction/utils.py +++ b/python/tvm/backend/cuda/operator/tile_primitive/reduction/utils.py @@ -25,9 +25,9 @@ from tvm.script import tirx as T from tvm.tirx import BufferRegion from tvm.tirx.operator.tile_primitive import DispatchContext +from tvm.tirx.operator.tile_primitive.common import ReduceOpType from tvm.tirx.stmt import TilePrimitiveCall -from ...common import ReduceOpType from ..common import match_scope reduce_op_table = { @@ -250,7 +250,12 @@ def _local_scope_match(op: TilePrimitiveCall, sctx: DispatchContext): op = TilePrimitiveCall.downcast(op) src, dst = op.input.buffer, op.output.buffer ok = all( - [src.scope() == "local", dst.scope() == "local", src.dtype == dst.dtype, sctx.is_cuda()] + [ + src.scope() == "local", + dst.scope() == "local", + src.dtype == dst.dtype, + sctx.is_target("cuda"), + ] ) if not ok: return (False, "src/dst must be local scope with matching dtype on CUDA") diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/tma_utils.py b/python/tvm/backend/cuda/operator/tile_primitive/tma_utils.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/cuda/tma_utils.py rename to python/tvm/backend/cuda/operator/tile_primitive/tma_utils.py diff --git a/python/tvm/backend/cuda/script.py b/python/tvm/backend/cuda/script.py new file mode 100644 index 000000000000..a1148f9b67ee --- /dev/null +++ b/python/tvm/backend/cuda/script.py @@ -0,0 +1,571 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""CUDA TVMScript namespaces.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from tvm.backend.cuda import op as _cuda_op +from tvm.tirx import Buffer +from tvm.tirx import op as _tir_op +from tvm.tirx.script.builder.ir import _dtype_forward, _op_wrapper + +# pylint: disable=protected-access + + +def _ptx_ldg32(reg, guard, addr, local_addr): + if isinstance(addr, Buffer): + addr = addr[0] + return _tir_op.call_intrin(reg.dtype, "tirx.ptx.ldg32", reg, guard, addr, local_addr) + + +_ptx_ldg32.__tir_op_name__ = "ptx.ldg32" + + +class PTXNamespace: + """The PTX instruction submodule.""" + + def __init__(self): + self.ldg32 = _ptx_ldg32 + self.ldmatrix = _dtype_forward(_cuda_op.ptx_ldmatrix) + # Apache-compatible variant. Same lowered intrinsic as + # ``ldmatrix`` but accepts the historical ``(trans, num, dtype, + # local_ptr, local_offset, smem_ptr, smem_offset)`` form. Coexists + # with the fork-native version so upstream-derived tests keep + # working without rewriting their tirx code. + self.ldmatrix_legacy = _dtype_forward(_cuda_op.ptx_ldmatrix_legacy) + self.stmatrix = _op_wrapper(_cuda_op.ptx_stmatrix) + self.setmaxnreg: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_setmaxnreg) + self.elect_sync: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_elect_sync) + self.fetch_register: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_fetch_register) + self.ld = _op_wrapper(_cuda_op.ptx_ld) + self.ld_acquire = _op_wrapper(_cuda_op.ptx_ld_acquire) + self.ld_volatile = _op_wrapper(_cuda_op.ptx_ld_volatile) + self.ld_global_acquire = _op_wrapper(_cuda_op.ptx_ld_global_acquire) + self.red_scalar = _op_wrapper(_cuda_op.ptx_red_scalar) + self.atom_scalar = _op_wrapper(_cuda_op.ptx_atom_scalar) + self.prefetch_tensormap = _op_wrapper(_cuda_op.ptx_prefetch_tensormap) + self.mbarrier_test_wait_parity = _op_wrapper(_cuda_op.ptx_mbarrier_test_wait_parity) + self.cp_async_bulk_g2s_cta = _op_wrapper(_cuda_op.ptx_cp_async_bulk_g2s_cta) + self.cp_async_bulk_g2s_cluster = _op_wrapper(_cuda_op.ptx_cp_async_bulk_g2s_cluster) + self.cp_async_bulk_s2s_cluster = _op_wrapper(_cuda_op.ptx_cp_async_bulk_s2s_cluster) + self.cp_async_bulk_s2g = _op_wrapper(_cuda_op.ptx_cp_async_bulk_s2g) + self.st = _op_wrapper(_cuda_op.ptx_st) + self.st_bulk = _op_wrapper(_cuda_op.ptx_st_bulk) + self.fns_b32 = _op_wrapper(_cuda_op.ptx_fns_b32) + self.add_rn_f32_bf16 = _op_wrapper(_cuda_op.ptx_add_rn_f32_bf16) + self.mapa = _op_wrapper(_cuda_op.ptx_mapa) + self.map_shared_rank = _op_wrapper(_cuda_op.ptx_map_shared_rank) + self.any_sync = _op_wrapper(_cuda_op.ptx_any_sync) + # Math operations + self.exp2 = _op_wrapper(_cuda_op.ptx_exp2) + self.rcp = _op_wrapper(_cuda_op.ptx_rcp) + self.reduce3_min_f32 = _op_wrapper(_cuda_op.ptx_reduce3_min_f32) + self.reduce3_max_f32 = _op_wrapper(_cuda_op.ptx_reduce3_max_f32) + # add/sub/mul/fma DPS form: (d_addr, a, b[, c], *, rounding, ftz[, sat]) + self.add_f32 = _op_wrapper(_cuda_op.ptx_add_f32) + self.add_f32x2 = _op_wrapper(_cuda_op.ptx_add_f32x2) + self.add_f64 = _op_wrapper(_cuda_op.ptx_add_f64) + self.sub_f32 = _op_wrapper(_cuda_op.ptx_sub_f32) + self.sub_f32x2 = _op_wrapper(_cuda_op.ptx_sub_f32x2) + self.sub_f64 = _op_wrapper(_cuda_op.ptx_sub_f64) + self.mul_f32 = _op_wrapper(_cuda_op.ptx_mul_f32) + self.mul_f32x2 = _op_wrapper(_cuda_op.ptx_mul_f32x2) + self.mul_f64 = _op_wrapper(_cuda_op.ptx_mul_f64) + self.fma_f32 = _op_wrapper(_cuda_op.ptx_fma_f32) + self.fma_f32x2 = _op_wrapper(_cuda_op.ptx_fma_f32x2) + self.fma_f64 = _op_wrapper(_cuda_op.ptx_fma_f64) + self.max_f32 = _op_wrapper(_cuda_op.ptx_max_f32) + self.mma = MmaNamespace() + self.cp_async = CpAsyncNamespace() + self.wgmma = WgmmaNamespace() + self.mbarrier = MbarrierNamespace() + self.tcgen05 = Tcgen05Namespace() + self.bar = BarNamespace() + self.barrier = BarrierNamespace() + self.fence = FenceNamespace() + self.griddepcontrol = GriddepcontrolNamespace() + + +class MmaNamespace: + """The MMA instruction submodule.""" + + def __init__(self): + self.sp = _dtype_forward(_cuda_op.ptx_mma_sp) + # Apache-compatible variant of ptx_mma. Coexists with the + # fork-native ``__call__`` form (``T.ptx.mma(...)``). + self.legacy = _dtype_forward(_cuda_op.ptx_mma_legacy) + # __call__ corresponds to ptx_mma + self.__tir_call_op_name__ = "ptx_mma" + + def __call__(self, *args, **kwds): + return _dtype_forward(_cuda_op.ptx_mma)(*args, **kwds) + + +class CpAsyncNamespace: + """The CpAsync instruction submodule.""" + + def __init__(self): + self.commit_group = _op_wrapper(_cuda_op.ptx_cp_async_commit_group) + self.wait_group = _op_wrapper(_cuda_op.ptx_cp_async_wait_group) + # Legacy variant: takes (dst_ptr, dst_offset, src_ptr, src_offset, + # cp_size). Offsets are folded into the pointers; coexists with + # the fork-native ``__call__`` form. + self.legacy = _dtype_forward(_cuda_op.ptx_cp_async_legacy) + self.bulk = CpAsyncBulkNamespace() + self.mbarrier = CpAsyncMbarrierNamespace() + + def __call__(self, *args, **kwds): + # Accept the legacy 6-arg form ``(elem_dtype, dst, dst_off, src, + # src_off, cp_size)`` that the printer round-trips for the raw + # ``tirx.ptx_cp_async`` Call emitted by + # ``tvm.backend.cuda.transform.InjectPTXAsyncCopy``. The pass-emitted + # Call has 5 args (no ``tvm_access_ptr`` fold) and a + # per-element-dtype Call.dtype, so build it directly. + if len(args) == 6 and isinstance(args[0], str) and "dtype" not in kwds: + import tvm + + elem_dtype, dst, dst_off, src, src_off, cp_size = args + return tvm.tirx.Call( + tvm.DataType(elem_dtype), + tvm.ir.Op.get("tirx.ptx_cp_async"), + [dst, dst_off, src, src_off, cp_size], + ) + return _dtype_forward(_cuda_op.ptx_cp_async)(*args, **kwds) + + # __call__ corresponds to ptx_cp_async + __tir_call_op_name__ = "ptx_cp_async" + + +class CpAsyncBulkNamespace: + """The CpAsyncBulk instruction submodule.""" + + def __init__(self): + self.commit_group = _op_wrapper(_cuda_op.ptx_cp_async_bulk_commit_group) + self.wait_group = _op_wrapper(_cuda_op.ptx_cp_async_bulk_wait_group) + self.tensor = CpAsyncBulkTensorNamespace() + self.s2c = _op_wrapper(_cuda_op.ptx_cp_async_bulk_shared_to_cluster) + + def __call__(self, *args, **kwds): + return _dtype_forward(_cuda_op.ptx_cp_async_bulk)(*args, **kwds) + + # __call__ corresponds to ptx_cp_async_bulk + __tir_call_op_name__ = "ptx_cp_async_bulk" + + +class CpAsyncBulkTensorNamespace: + """The CpAsyncBulkTensor instruction submodule.""" + + def __init__(self): + self.g2c = _op_wrapper(_cuda_op.ptx_cp_async_bulk_tensor_global_to_cluster) + self.g2c_tile_gather4 = _op_wrapper( + _cuda_op.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster + ) + self.s2g = _op_wrapper(_cuda_op.ptx_cp_async_bulk_tensor_shared_to_global) + self.s2g_reduce = _op_wrapper(_cuda_op.ptx_cp_async_bulk_tensor_shared_to_global_reduce) + self.g2c_prefetch = _op_wrapper( + _cuda_op.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch + ) + + @staticmethod + def g2c_bar_addr( + dim, + dst_ptr, + bar_addr, + tensormap_addr, + cta_mask, + cta_group, + cache_hint, + *coords, + cache_policy=None, + ): + _cuda_op._choice("cta_group", cta_group, _cuda_op._TCGEN05_CTA_GROUP) + cache_policy, has_cache_policy = _cuda_op._resolve_cache_policy(cache_hint, cache_policy) + return _tir_op.call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", + dim, + dst_ptr, + bar_addr, + tensormap_addr, + cta_mask, + cta_group, + cache_policy, + int(has_cache_policy), + 1, + *coords, + ) + + @staticmethod + def g2c_tile_gather4_bar_addr( + dim, + dst_ptr, + bar_addr, + tensormap_addr, + cta_mask, + cta_group, + cache_hint, + *coords, + cache_policy=None, + ): + _cuda_op._choice("cta_group", cta_group, _cuda_op._TCGEN05_CTA_GROUP) + cache_policy, has_cache_policy = _cuda_op._resolve_cache_policy(cache_hint, cache_policy) + return _tir_op.call_intrin( + "", + "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster", + dim, + dst_ptr, + bar_addr, + tensormap_addr, + cta_mask, + cta_group, + cache_policy, + int(has_cache_policy), + 1, + *coords, + ) + + +class CpAsyncMbarrierNamespace: + """The CpAsyncMbarrier instruction submodule.""" + + def __init__(self): + self.arrive = _op_wrapper(_cuda_op.ptx_cp_async_mbarrier_arrive) + + +class WgmmaNamespace: + """The WGMMA instruction submodule.""" + + def __init__(self): + self.fence: Callable[..., Any] = _op_wrapper(_cuda_op.ptx_wgmma_fence) + self.commit_group = _op_wrapper(_cuda_op.ptx_wgmma_commit_group) + self.wait_group = _op_wrapper(_cuda_op.ptx_wgmma_wait_group) + self.noop_barrier = _op_wrapper(_cuda_op.ptx_wgmma_noop_barrier) + self.mma_async = WgmmaMmaAsyncNamespace() + self.encode_matrix_descriptor = _op_wrapper(_cuda_op.ptx_wgmma_encode_matrix_descriptor) + + +class WgmmaMmaAsyncNamespace: + """The WGMMA MMAAsync instruction submodule.""" + + def __init__(self): + self.ss = _op_wrapper(_cuda_op.ptx_wgmma_mma_async_ss) + self.rs = _op_wrapper(_cuda_op.ptx_wgmma_mma_async_rs) + + +class MbarrierNamespace: + """The Mbarrier instruction submodule.""" + + def __init__(self): + self.init = _op_wrapper(_cuda_op.ptx_mbarrier_init) + self.try_wait = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait) + self.try_wait_once = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait_once) + self.arrive = MbarrierArriveNamespace() + + +class MbarrierArriveNamespace: + """The Mbarrier Arrive instruction submodule.""" + + def __init__(self): + self.expect_tx = _op_wrapper(_cuda_op.ptx_mbarrier_arrive_expect_tx) + + def __call__(self, *args, **kwds): + return _op_wrapper(_cuda_op.ptx_mbarrier_arrive)(*args, **kwds) + + # __call__ corresponds to ptx_mbarrier_arrive + __tir_call_op_name__ = "ptx_mbarrier_arrive" + + +class Tcgen05Namespace: + """The Tcgen05 instruction submodule.""" + + def __init__(self): + self.alloc = _op_wrapper(_cuda_op.ptx_tcgen05_alloc) + self.dealloc = _op_wrapper(_cuda_op.ptx_tcgen05_dealloc) + self.relinquish_alloc_permit = _op_wrapper(_cuda_op.ptx_tcgen05_relinquish_alloc_permit) + self.encode_matrix_descriptor = _op_wrapper(_cuda_op.ptx_tcgen05_encode_matrix_descriptor) + self.encode_instr_descriptor = _op_wrapper(_cuda_op.ptx_tcgen05_encode_instr_descriptor) + self.encode_instr_descriptor_block_scaled = _op_wrapper( + _cuda_op.ptx_tcgen05_encode_instr_descriptor_block_scaled + ) + self.ld = _op_wrapper(_cuda_op.ptx_tcgen05_ld) + self.st = _op_wrapper(_cuda_op.ptx_tcgen05_st) + self.cp = _op_wrapper(_cuda_op.ptx_tcgen05_cp) + self.shift = _op_wrapper(_cuda_op.ptx_tcgen05_shift) + self.commit = _op_wrapper(_cuda_op.ptx_tcgen05_commit) + self.wait = Tcgen05WaitNamespace() + self.mma = Tcgen05MmaNamespace() + self.fence = Tcgen05FenceNamespace() + + +class Tcgen05FenceNamespace: + """The Tcgen05 Fence instruction submodule.""" + + def __init__(self): + self.before_thread_sync = _op_wrapper(_cuda_op.ptx_tcgen05_fence_before_thread_sync) + self.after_thread_sync = _op_wrapper(_cuda_op.ptx_tcgen05_fence_after_thread_sync) + + +class Tcgen05MmaNamespace: + """The Tcgen05 MMA instruction submodule.""" + + def __init__(self): + self.block_scale = _op_wrapper(_cuda_op.ptx_tcgen05_mma_block_scale) + self.sp = Tcgen05MmaSpNamespace() + + def __call__(self, *args, **kwds): + return _op_wrapper(_cuda_op.ptx_tcgen05_mma)(*args, **kwds) + + # __call__ corresponds to ptx_tcgen05_mma + __tir_call_op_name__ = "ptx_tcgen05_mma" + + +class Tcgen05MmaSpNamespace: + """Tcgen05 Sparse MMA instruction submodule.""" + + def __init__(self): + self.block_scale = _op_wrapper(_cuda_op.ptx_tcgen05_mma_sp_block_scale) + + def __call__(self, *args, **kwds): + return _op_wrapper(_cuda_op.ptx_tcgen05_mma_sp)(*args, **kwds) + + # __call__ corresponds to ptx_tcgen05_mma_sp + __tir_call_op_name__ = "ptx_tcgen05_mma_sp" + + +class Tcgen05WaitNamespace: + """The Tcgen05 Wait instruction submodule.""" + + def __init__(self): + self.ld = _op_wrapper(_cuda_op.ptx_tcgen05_wait_ld) + self.st = _op_wrapper(_cuda_op.ptx_tcgen05_wait_st) + + +class BarNamespace: + """The Bar instruction submodule.""" + + def __init__(self): + self.arrive = _op_wrapper(_cuda_op.ptx_bar_arrive) + self.sync = _op_wrapper(_cuda_op.ptx_bar_sync) + + +class BarrierNamespace: + """The Barrier instruction submodule.""" + + def __init__(self): + self.cluster = BarrierClusterNamespace() + + +class BarrierClusterNamespace: + """The BarrierCluster instruction submodule.""" + + def __init__(self): + self.arrive = _op_wrapper(_cuda_op.ptx_barrier_cluster_arrive) + self.wait = _op_wrapper(_cuda_op.ptx_barrier_cluster_wait) + + +class FenceNamespace: + """PTX fence instruction submodule.""" + + def __init__(self): + self.proxy_async = _op_wrapper(_cuda_op.ptx_fence_proxy_async) + self.mbarrier_init = _op_wrapper(_cuda_op.ptx_fence_mbarrier_init) + + def __call__(self, *args, **kwds): + return _op_wrapper(_cuda_op.ptx_fence)(*args, **kwds) + + __tir_call_op_name__ = "ptx_fence" + + +class GriddepcontrolNamespace: + """PTX griddepcontrol instruction submodule (sm_90+).""" + + def __init__(self): + self.wait = _op_wrapper(_cuda_op.ptx_griddepcontrol_wait) + self.launch_dependents = _op_wrapper(_cuda_op.ptx_griddepcontrol_launch_dependents) + + +class CUDANamespace: + """The CUDA intrinsics submodule.""" + + def __init__(self): + self.atomic_add = _op_wrapper(_cuda_op.cuda_atomic_add) + self.thread_fence = _op_wrapper(_cuda_op.cuda_thread_fence) + self.warpgroup_sync = _op_wrapper(_cuda_op.cuda_warpgroup_sync) + self.warp_sync = _op_wrapper(_cuda_op.cuda_warp_sync) + self.warp_reduce = _op_wrapper(_cuda_op.cuda_warp_reduce) + self.warp_sum = _op_wrapper(_cuda_op.cuda_warp_sum) + self.warp_max = _op_wrapper(_cuda_op.cuda_warp_max) + self.warp_min = _op_wrapper(_cuda_op.cuda_warp_min) + self.cta_reduce = _op_wrapper(_cuda_op.cuda_cta_reduce) + self.cta_sum = _op_wrapper(_cuda_op.cuda_cta_sum) + self.cta_max = _op_wrapper(_cuda_op.cuda_cta_max) + self.cta_min = _op_wrapper(_cuda_op.cuda_cta_min) + self.copy_bytes = _op_wrapper(_cuda_op.cuda_copy_bytes) + self.copy_128b = _op_wrapper(_cuda_op.cuda_copy_128b) + self.copy_64b = _op_wrapper(_cuda_op.cuda_copy_64b) + self.copy_32b = _op_wrapper(_cuda_op.cuda_copy_32b) + self.copy_16b = _op_wrapper(_cuda_op.cuda_copy_16b) + self.copy_8b = _op_wrapper(_cuda_op.cuda_copy_8b) + self.cta_sync = _op_wrapper(_cuda_op.cuda_cta_sync) + self.grid_sync = _op_wrapper(_cuda_op.cuda_grid_sync) + self.cluster_sync = _op_wrapper(_cuda_op.cuda_cluster_sync) + self.thread_rank = _op_wrapper(_cuda_op.cuda_thread_rank) + self.trap_when_assert_failed = _op_wrapper(_cuda_op.cuda_trap_when_assert_failed) + self.runtime_instr_desc = _op_wrapper(_cuda_op.cuda_runtime_instr_desc) + self.half2float = _op_wrapper(_cuda_op.cuda_half2float) + self.bfloat162float = _op_wrapper(_cuda_op.cuda_bfloat162float) + self.float22half2 = _op_wrapper(_cuda_op.cuda_float22half2) + self.half8tofloat8 = _op_wrapper(_cuda_op.cuda_half8tofloat8) + self.float8tohalf8 = _op_wrapper(_cuda_op.cuda_float8tohalf8) + self.syncthreads_and = _op_wrapper(_cuda_op.cuda_syncthreads_and) + self.syncthreads_or = _op_wrapper(_cuda_op.cuda_syncthreads_or) + self.nano_sleep = _op_wrapper(_cuda_op.cuda_nano_sleep) + self.atomic_cas = _op_wrapper(_cuda_op.cuda_atomic_cas) + self.func_call = _op_wrapper(_cuda_op.cuda_func_call) + self.printf = _op_wrapper(_cuda_op.cuda_printf) + self.ldg = _op_wrapper(_cuda_op.cuda_ldg) + self.get_tmem_addr = _op_wrapper(_cuda_op.cuda_get_tmem_addr) + self.cvta_generic_to_shared = _op_wrapper(_cuda_op.cuda_cvta_generic_to_shared) + self.smem_addr_from_uint64 = _op_wrapper(_cuda_op.cuda_smem_addr_from_uint64) + self.sm100_tma_2sm_mbarrier_addr = _op_wrapper(_cuda_op.cuda_sm100_tma_2sm_mbarrier_addr) + self.uint_as_float = _op_wrapper(_cuda_op.cuda_uint_as_float) + self.float_as_uint = _op_wrapper(_cuda_op.cuda_float_as_uint) + self.ballot_sync = _op_wrapper(_cuda_op.cuda_ballot_sync) + self.ffs_u32 = _op_wrapper(_cuda_op.cuda_ffs_u32) + self.reduce_add_sync_u32 = _op_wrapper(_cuda_op.cuda_reduce_add_sync_u32) + self.reduce_min_sync_u32 = _op_wrapper(_cuda_op.cuda_reduce_min_sync_u32) + self.clock64 = _op_wrapper(_cuda_op.cuda_clock64) + self.make_float2 = _op_wrapper(_cuda_op.cuda_make_float2) + self.float2_x = _op_wrapper(_cuda_op.cuda_float2_x) + self.float2_y = _op_wrapper(_cuda_op.cuda_float2_y) + self.fmul2_rn = _op_wrapper(_cuda_op.cuda_fmul2_rn) + self.fadd2_rn = _op_wrapper(_cuda_op.cuda_fadd2_rn) + self.float22bfloat162_rn = _op_wrapper(_cuda_op.cuda_float22bfloat162_rn) + self.float22bfloat162_rn_from_float2 = _op_wrapper( + _cuda_op.cuda_float22bfloat162_rn_from_float2 + ) + self.bfloat1622float2 = _op_wrapper(_cuda_op.cuda_bfloat1622float2) + self.hmin2 = _op_wrapper(_cuda_op.cuda_hmin2) + self.hmax2 = _op_wrapper(_cuda_op.cuda_hmax2) + self.fp8x4_e4m3_from_float4 = _op_wrapper(_cuda_op.cuda_fp8x4_e4m3_from_float4) + self.timer_init = _op_wrapper(_cuda_op.timer_init_cuda) + self.timer_start = _op_wrapper(_cuda_op.timer_start_cuda) + self.timer_end = _op_wrapper(_cuda_op.timer_end_cuda) + self.timer_finalize = _op_wrapper(_cuda_op.timer_finalize_cuda) + self.mma_store = _dtype_forward(_cuda_op.mma_store) + self.mma_fill = _dtype_forward(_cuda_op.mma_fill) + self.mma_store_legacy = _dtype_forward(_cuda_op.mma_store_legacy) + self.mma_fill_legacy = _dtype_forward(_cuda_op.mma_fill_legacy) + setattr(self, "__shfl_sync", self._shfl_sync) + setattr(self, "__shfl_up_sync", self._shfl_up_sync) + setattr(self, "__shfl_down_sync", self._shfl_down_sync) + setattr(self, "__shfl_xor_sync", self._shfl_xor_sync) + setattr(self, "__activemask", self._activemask) + + @staticmethod + def _shfl_sync(mask, var, lane, width): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.cuda.__shfl_sync", mask, var, lane, width) + + @staticmethod + def _shfl_up_sync(mask, var, delta, width): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.cuda.__shfl_up_sync", mask, var, delta, width) + + @staticmethod + def _shfl_down_sync(mask, var, delta, width): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.cuda.__shfl_down_sync", mask, var, delta, width) + + @staticmethod + def _shfl_xor_sync(mask, var, lane_mask, width): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin( + var.dtype, "tirx.cuda.__shfl_xor_sync", mask, var, lane_mask, width + ) + + @staticmethod + def _activemask(): + return _tir_op.call_intrin("uint32", "tirx.cuda.__activemask") + + +class NVSHMEMNamespace: + """The NVSHMEM intrinsics submodule.""" + + def __init__(self): + self.my_pe = _op_wrapper(_cuda_op.nvshmem_my_pe) + self.n_pes = _op_wrapper(_cuda_op.nvshmem_n_pes) + self.signal_op = _op_wrapper(_cuda_op.nvshmem_signal_op) + self.wait_until = _op_wrapper(_cuda_op.nvshmem_wait_until) + self.quiet = _op_wrapper(_cuda_op.nvshmem_quiet) + self.fence = _op_wrapper(_cuda_op.nvshmem_fence) + self.barrier_all = _op_wrapper(_cuda_op.nvshmem_barrier_all) + self.getmem_nbi = NVSHMEMGetMemNBINamespace() + self.putmem_nbi = NVSHMEMPutMemNBINamespace() + self.putmem_signal_nbi = NVSHMEMPutMemSignalNBINamespace() + + +class NVSHMEMGetMemNBINamespace: + """The NVSHMEM GetMemNBI intrinsics submodule.""" + + def __init__(self): + self.warp = _op_wrapper(_cuda_op.nvshmem_getmem_nbi_warp) + self.block = _op_wrapper(_cuda_op.nvshmem_getmem_nbi_block) + + def __call__(self, *args, **kwds): + return _op_wrapper(_cuda_op.nvshmem_getmem_nbi)(*args, **kwds) + + # __call__ corresponds to nvshmem_getmem_nbi + __tir_call_op_name__ = "nvshmem_getmem_nbi" + + +class NVSHMEMPutMemNBINamespace: + """The NVSHMEM PutMemNBI intrinsics submodule.""" + + def __init__(self): + self.warp = _op_wrapper(_cuda_op.nvshmem_putmem_nbi_warp) + self.block = _op_wrapper(_cuda_op.nvshmem_putmem_nbi_block) + + def __call__(self, *args, **kwds): + return _op_wrapper(_cuda_op.nvshmem_putmem_nbi)(*args, **kwds) + + # __call__ corresponds to nvshmem_putmem_nbi + __tir_call_op_name__ = "nvshmem_putmem_nbi" + + +class NVSHMEMPutMemSignalNBINamespace: + """The NVSHMEM PutMemSignalNBI intrinsics submodule.""" + + def __init__(self): + self.warp = _op_wrapper(_cuda_op.nvshmem_putmem_signal_nbi_warp) + self.block = _op_wrapper(_cuda_op.nvshmem_putmem_signal_nbi_block) + + def __call__(self, *args, **kwds): + return _op_wrapper(_cuda_op.nvshmem_putmem_signal_nbi)(*args, **kwds) + + # __call__ corresponds to nvshmem_putmem_signal_nbi + __tir_call_op_name__ = "nvshmem_putmem_signal_nbi" + + +__all__ = ["CUDANamespace", "NVSHMEMNamespace", "PTXNamespace"] diff --git a/python/tvm/target/tag_registry/cuda.py b/python/tvm/backend/cuda/target_tags.py similarity index 99% rename from python/tvm/target/tag_registry/cuda.py rename to python/tvm/backend/cuda/target_tags.py index 6b1bd9e8a8bd..87c83b378fef 100644 --- a/python/tvm/target/tag_registry/cuda.py +++ b/python/tvm/backend/cuda/target_tags.py @@ -16,7 +16,7 @@ # under the License. """NVIDIA CUDA target tags.""" -from .registry import register_tag +from tvm.target import register_tag def _register_cuda_tag(name, arch, shared_mem=49152, regs=65536, **extra): diff --git a/python/tvm/backend/hexagon/__init__.py b/python/tvm/backend/hexagon/__init__.py new file mode 100644 index 000000000000..3852e36ccc3b --- /dev/null +++ b/python/tvm/backend/hexagon/__init__.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Hexagon-owned backend hooks.""" + +from importlib import import_module + +_LAZY_SUBMODULES = {"target_tags"} + + +def register_backend(): + """Register Hexagon-owned Python semantics.""" + import_module(f"{__name__}.target_tags") + + +def __getattr__(name: str): + if name in _LAZY_SUBMODULES: + return import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = ["register_backend", "target_tags"] diff --git a/python/tvm/target/tag_registry/hexagon.py b/python/tvm/backend/hexagon/target_tags.py similarity index 98% rename from python/tvm/target/tag_registry/hexagon.py rename to python/tvm/backend/hexagon/target_tags.py index a43c3fbd93ef..645567cabc26 100644 --- a/python/tvm/target/tag_registry/hexagon.py +++ b/python/tvm/backend/hexagon/target_tags.py @@ -16,7 +16,7 @@ # under the License. """Qualcomm Hexagon target tags.""" -from .registry import register_tag +from tvm.target import register_tag _ONE_MB = 2**20 diff --git a/python/tvm/backend/metal/__init__.py b/python/tvm/backend/metal/__init__.py new file mode 100644 index 000000000000..d42806433f73 --- /dev/null +++ b/python/tvm/backend/metal/__init__.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Metal-owned TIRx modules.""" + +from importlib import import_module + +_LAZY_SUBMODULES = {"op", "script", "target_tags"} + + +def register_backend(): + """Register Metal-owned Python semantics.""" + from tvm.tirx.script.builder import ir as builder_ir # pylint: disable=import-outside-toplevel + + for name, namespace in script_namespaces().items(): + builder_ir.register_script_namespace(name, namespace) + import_module(f"{__name__}.target_tags") + + +def script_namespaces(**_): + """Return Metal-owned TVMScript namespaces.""" + from .script import MetalNamespace # pylint: disable=import-outside-toplevel + + return {"metal": MetalNamespace()} + + +def script_namespace(**kwargs): + """Return the Metal TVMScript namespace object.""" + return script_namespaces(**kwargs)["metal"] + + +def __getattr__(name: str): + if name in _LAZY_SUBMODULES: + return import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "op", + "register_backend", + "script", + "script_namespace", + "script_namespaces", + "target_tags", +] diff --git a/python/tvm/backend/metal/op.py b/python/tvm/backend/metal/op.py new file mode 100644 index 000000000000..f9485760206c --- /dev/null +++ b/python/tvm/backend/metal/op.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Metal-owned TIR intrinsic builders.""" + +from __future__ import annotations + +from tvm.tirx.op import call_intrin + + +def make_filled_simdgroup_matrix(d, index, value, col=8, row=8): + """Create a filled SIMDGroup matrix.""" + + return call_intrin("handle", "tirx.make_filled_simdgroup_matrix", d, index, value, col, row) + + +def simdgroup_load(d, index, ptr, stride, col=8, row=8, transpose_matrix=False): + """Load data from device or threadgroup memory to simdgroup.""" + + return call_intrin( + "handle", + "tirx.simdgroup_load", + d, + index, + ptr, + stride, + col, + row, + transpose_matrix, + ) + + +def simdgroup_store(d, index, ptr, stride, col=8, row=8, transpose_matrix=False): + """Store data from simdgroup to device or threadgroup memory.""" + + return call_intrin( + "handle", + "tirx.simdgroup_store", + d, + index, + ptr, + stride, + col, + row, + transpose_matrix, + ) + + +def simdgroup_multiply_accumulate(d, index_d, a, index_a, b, index_b, c, index_c): + """Multiply and accumulate two matrices in simdgroup.""" + + return call_intrin( + "handle", + "tirx.simdgroup_multiply_accumulate", + d, + index_d, + a, + index_a, + b, + index_b, + c, + index_c, + ) + + +__all__ = [ + "make_filled_simdgroup_matrix", + "simdgroup_load", + "simdgroup_multiply_accumulate", + "simdgroup_store", +] diff --git a/python/tvm/backend/metal/script.py b/python/tvm/backend/metal/script.py new file mode 100644 index 000000000000..7c5d45564ae9 --- /dev/null +++ b/python/tvm/backend/metal/script.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Metal TVMScript namespace.""" + +from __future__ import annotations + +from tvm.backend.metal import op as _metal_op +from tvm.tirx import Buffer +from tvm.tirx import op as _tir_op +from tvm.tirx.script.builder.ir import _op_wrapper + + +class MetalNamespace: + """The Metal intrinsics submodule.""" + + def __init__(self): + self.make_filled_simdgroup_matrix = _op_wrapper(_metal_op.make_filled_simdgroup_matrix) + self.simdgroup_load = _op_wrapper(_metal_op.simdgroup_load) + self.simdgroup_store = _op_wrapper(_metal_op.simdgroup_store) + self.simdgroup_multiply_accumulate = _op_wrapper(_metal_op.simdgroup_multiply_accumulate) + + @staticmethod + def simd_shuffle(var, lane): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle", var, lane) + + @staticmethod + def simd_shuffle_up(var, delta): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_up", var, delta) + + @staticmethod + def simd_shuffle_down(var, delta): + if isinstance(var, Buffer): + var = var[0] + return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_down", var, delta) + + +__all__ = ["MetalNamespace"] diff --git a/python/tvm/target/tag_registry/metal.py b/python/tvm/backend/metal/target_tags.py similarity index 94% rename from python/tvm/target/tag_registry/metal.py rename to python/tvm/backend/metal/target_tags.py index 6727db7c3046..6ee758216d76 100644 --- a/python/tvm/target/tag_registry/metal.py +++ b/python/tvm/backend/metal/target_tags.py @@ -16,14 +16,14 @@ # under the License. """Apple Metal GPU target tags.""" -from .registry import register_tag +from tvm.target import register_tag _METAL_HOST_TRIPLE = "arm64-apple-macos" def _register_metal_tag(name, max_threads, shared_mem, warp_size, mcpu): try: - from ..codegen import llvm_is_valid_cpu + from tvm.target.codegen import llvm_is_valid_cpu if not llvm_is_valid_cpu(mcpu, _METAL_HOST_TRIPLE): return diff --git a/python/tvm/backend/opencl/__init__.py b/python/tvm/backend/opencl/__init__.py new file mode 100644 index 000000000000..a80696e5900b --- /dev/null +++ b/python/tvm/backend/opencl/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""OpenCL-owned backend hooks.""" + + +def register_backend(): + """Register OpenCL-owned Python semantics.""" + return None + + +__all__ = ["register_backend"] diff --git a/python/tvm/backend/rocm/__init__.py b/python/tvm/backend/rocm/__init__.py new file mode 100644 index 000000000000..d7574e974a30 --- /dev/null +++ b/python/tvm/backend/rocm/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""ROCm-owned TIRx modules.""" + + +def register_backend(): + """Register ROCm-owned Python semantics.""" + return None + + +__all__ = ["register_backend"] diff --git a/python/tvm/backend/trn/__init__.py b/python/tvm/backend/trn/__init__.py new file mode 100644 index 000000000000..7650ea87d548 --- /dev/null +++ b/python/tvm/backend/trn/__init__.py @@ -0,0 +1,68 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Trainium-owned TIRx modules.""" + +from importlib import import_module + +_LAZY_SUBMODULES = {"layout", "op", "operator", "pipeline", "script", "target_tags", "transform"} + + +def register_backend(): + """Register Trainium-owned Python semantics.""" + from tvm.tirx import compilation_pipeline # pylint: disable=import-outside-toplevel + from tvm.tirx.script.builder import ir as builder_ir # pylint: disable=import-outside-toplevel + + for name, namespace in script_namespaces().items(): + builder_ir.register_script_namespace(name, namespace) + + import_module(f"{__name__}.operator.tile_primitive") + trn_pipeline = import_module(f"{__name__}.pipeline") + import_module(f"{__name__}.target_tags") + import_module(f"{__name__}.transform") + compilation_pipeline.register_tir_pipeline("trn", trn_pipeline.trn_pipeline) + + +def script_namespace(op_wrapper=None): + """Return the Trainium TVMScript namespace object.""" + from .script import NKINamespace # pylint: disable=import-outside-toplevel + + return NKINamespace(op_wrapper) + + +def script_namespaces(op_wrapper=None, **_): + """Return Trainium-owned TVMScript namespaces.""" + return {"nki": script_namespace(op_wrapper)} + + +def __getattr__(name: str): + if name in _LAZY_SUBMODULES: + return import_module(f"{__name__}.{name}") + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +__all__ = [ + "layout", + "op", + "operator", + "pipeline", + "register_backend", + "script", + "script_namespace", + "script_namespaces", + "target_tags", + "transform", +] diff --git a/python/tvm/backend/trn/layout.py b/python/tvm/backend/trn/layout.py new file mode 100644 index 000000000000..5d5f08959137 --- /dev/null +++ b/python/tvm/backend/trn/layout.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Trainium-specific TIRx layout helpers.""" + +from __future__ import annotations + +import functools +import operator +import re + +import tvm +from tvm.tirx.expr import PrimExpr +from tvm.tirx.layout import Axis, Iter, Layout, S, TileLayout + +_TRN_MEMORY_AXES = {"F", "P", "Bank"} +_PSUM_MAX_ELEM_PER_BANK = 512 + + +def is_trainium_layout(layout: Layout | None) -> bool: + """Return whether a layout uses only Trainium memory axes.""" + if not isinstance(layout, TileLayout): + return False + return not any( + iter.axis.is_memory() and iter.axis.name not in _TRN_MEMORY_AXES for iter in layout.shard + ) + + +def trainium_layout(annotation: str, shape: tuple[PrimExpr], is_psum: bool = False) -> TileLayout: + """Create a Trainium tile layout from a PF annotation string and logical shape.""" + analyzer = tvm.arith.Analyzer() + assert re.fullmatch(r"[PF]*", annotation), ( + f"annotation {annotation} must be a string of 'P' and 'F'" + ) + assert len(annotation) == len(shape), ( + f"annotation {annotation} and shape {shape} must have the same length" + ) + num_p_dim = annotation.count("P") + if num_p_dim == 1: + p_idx = annotation.index("P") + p_dim = shape[p_idx] + assert analyzer.can_prove(p_dim <= 128 or p_dim % 128 == 0), ( + f"There is only 1 P in the annotation. Partition size {p_dim} must be less than " + "or equal to 128 or a multiple of 128" + ) + if analyzer.can_prove(p_dim > 128): + annotation = "F" + annotation + shape = (p_dim // 128, *shape[:p_idx], 128, *shape[p_idx + 1 :]) + elif num_p_dim > 1: + p_dim_prod = functools.reduce( + operator.mul, [s for s, c in zip(shape, annotation) if c == "P"] + ) + assert analyzer.can_prove(p_dim_prod <= 128), ( + f"There are {num_p_dim} Ps in the annotation. Partition size {p_dim_prod} must be " + "less than or equal to 128" + ) + + f_shape = [s for s, c in zip(shape, annotation) if c == "F"] + p_shape = [s for s, c in zip(shape, annotation) if c == "P"] + f_strides = Layout._get_default_strides(f_shape, 1) # pylint: disable=protected-access + p_strides = Layout._get_default_strides(p_shape, 1) # pylint: disable=protected-access + f_tile_layout = TileLayout(S[tuple(f_shape) : tuple(s @ Axis.F for s in f_strides)]) + p_tile_layout = TileLayout(S[tuple(p_shape) : tuple(s @ Axis.P for s in p_strides)]) + result = [] + f_index = p_index = 0 + + for char in annotation: + if char == "F": + result.append(f_tile_layout.shard[f_index]) + f_index += 1 + else: + result.append(p_tile_layout.shard[p_index]) + p_index += 1 + if num_p_dim == 1 and analyzer.can_prove(p_dim > 128): + higher_p = result[0] + result = result[1:] + result = [*result[:p_idx], higher_p, *result[p_idx:]] + + res = TileLayout.from_iters(result, [], {}) + if is_psum: + res = to_psum_layout(res) + return res + + +def to_psum_layout(layout: TileLayout) -> TileLayout: + """Convert a Trainium sbuf layout to its psum physical-bank layout.""" + analyzer = tvm.arith.Analyzer() + shard = [] + for iter in layout.shard: + if iter.axis.name == "F": + if analyzer.can_prove(iter.stride % _PSUM_MAX_ELEM_PER_BANK == 0): + stride = analyzer.simplify(iter.stride // _PSUM_MAX_ELEM_PER_BANK) + shard.append(Iter(iter.extent, stride, Axis.get("Bank"))) + elif analyzer.can_prove(_PSUM_MAX_ELEM_PER_BANK % iter.stride == 0): + c = analyzer.simplify(_PSUM_MAX_ELEM_PER_BANK // iter.stride) + if analyzer.can_prove(iter.extent < c): + shard.append(iter) + elif analyzer.can_prove(iter.extent % c == 0): + shard.append(Iter(analyzer.simplify(iter.extent // c), 1, Axis.get("Bank"))) + shard.append(Iter(c, iter.stride, Axis.get("F"))) + else: + raise ValueError(f"layout {layout} can not be converted to psum layout") + else: + raise ValueError(f"layout {layout} can not be converted to psum layout") + else: + shard.append(iter) + return TileLayout.from_iters(shard, [], {}) + + +__all__ = ["is_trainium_layout", "to_psum_layout", "trainium_layout"] diff --git a/python/tvm/backend/trn/op.py b/python/tvm/backend/trn/op.py new file mode 100644 index 000000000000..d919e4fb8527 --- /dev/null +++ b/python/tvm/backend/trn/op.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Trainium-owned NKI intrinsic Python wrappers.""" + +from __future__ import annotations + +from tvm.tirx.op import call_intrin + + +def nki_load(res, data): + return call_intrin("", "tirx.nki_load", res, data) + + +def nki_store(res, data): + return call_intrin("", "tirx.nki_store", res, data) + + +def nki_tensor_copy(res, data): + return call_intrin("", "tirx.nki_tensor_copy", res, data) + + +def nki_matmul(res, lhs, rhs, accum=True): + return call_intrin("", "tirx.nki_matmul", res, lhs, rhs, accum) + + +def nki_activation(result, data, opcode, bias=0.0, scale=1.0): + return call_intrin("", "tirx.nki_activation", result, data, opcode, bias, scale) + + +def nki_reciprocal(result, data): + return call_intrin("", "tirx.nki_reciprocal", result, data) + + +def nki_tensorreduce(result, data, opcode, negate, *axes): + return call_intrin("", "tirx.nki_tensorreduce", result, data, opcode, negate, *axes) + + +def nki_tensortensor(result, operand0, operand1, opcode): + return call_intrin("", "tirx.nki_tensortensor", result, operand0, operand1, opcode) + + +def nki_tensorscalar(result, operand0, operand1, opcode, reverse=False): + return call_intrin("", "tirx.nki_tensorscalar", result, operand0, operand1, opcode, reverse) + + +def nki_memset(result, value): + return call_intrin("", "tirx.nki_memset", result, value) + + +def nki_activation_reduce(reduce_res, act_res, data, opcode, reduce_opcode, bias=0.0, scale=1.0): + return call_intrin( + "", + "tirx.nki_activation_reduce", + reduce_res, + act_res, + data, + opcode, + reduce_opcode, + bias, + scale, + ) + + +def nki_tensorscalar_reduce( + reduce_res, tensorscalar_res, operand0, operand1, opcode, reduce_opcode, reverse=False +): + return call_intrin( + "", + "tirx.nki_tensorscalar_reduce", + reduce_res, + tensorscalar_res, + operand0, + operand1, + opcode, + reduce_opcode, + reverse, + ) + + +def nki_identity(result, size): + return call_intrin("", "tirx.nki_identity", result, size) + + +def nki_scalar_tensor_tensor( + result, data, operand0, operand1, opcode0, opcode1, reverse0=False, reverse1=False +): + return call_intrin( + "", + "tirx.nki_scalar_tensor_tensor", + result, + data, + operand0, + operand1, + opcode0, + opcode1, + reverse0, + reverse1, + ) + + +def nki_scalar_tensor_scalar( + result, data, operand0, operand1, opcode0, opcode1, reverse0=False, reverse1=False +): + return call_intrin( + "", + "tirx.nki_scalar_tensor_scalar", + result, + data, + operand0, + operand1, + opcode0, + opcode1, + reverse0, + reverse1, + ) + + +def nki_affine_select(result, pred, true_value, false_value): + return call_intrin("", "tirx.nki_affine_select", result, pred, true_value, false_value) + + +__all__ = [ + "nki_activation", + "nki_activation_reduce", + "nki_affine_select", + "nki_identity", + "nki_load", + "nki_matmul", + "nki_memset", + "nki_reciprocal", + "nki_scalar_tensor_scalar", + "nki_scalar_tensor_tensor", + "nki_store", + "nki_tensor_copy", + "nki_tensorreduce", + "nki_tensorscalar", + "nki_tensorscalar_reduce", + "nki_tensortensor", +] diff --git a/python/tvm/backend/trn/operator/__init__.py b/python/tvm/backend/trn/operator/__init__.py new file mode 100644 index 000000000000..2d9b93d99852 --- /dev/null +++ b/python/tvm/backend/trn/operator/__init__.py @@ -0,0 +1,22 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Trainium backend operator package. + +Loaded by the Trainium backend registration hook. +""" + +__all__ = ["tile_primitive"] diff --git a/python/tvm/tirx/operator/tile_primitive/trn/__init__.py b/python/tvm/backend/trn/operator/tile_primitive/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/__init__.py rename to python/tvm/backend/trn/operator/tile_primitive/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/binary/__init__.py b/python/tvm/backend/trn/operator/tile_primitive/binary/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/binary/__init__.py rename to python/tvm/backend/trn/operator/tile_primitive/binary/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/binary/default.py b/python/tvm/backend/trn/operator/tile_primitive/binary/default.py similarity index 97% rename from python/tvm/tirx/operator/tile_primitive/trn/binary/default.py rename to python/tvm/backend/trn/operator/tile_primitive/binary/default.py index 3fa565b1f41d..9546c2f6e43f 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/binary/default.py +++ b/python/tvm/backend/trn/operator/tile_primitive/binary/default.py @@ -20,9 +20,9 @@ from tvm.script import tirx as T from tvm.tirx import FloatImm, PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext, fail +from tvm.tirx.operator.tile_primitive.common import MapOpType from tvm.tirx.stmt import TilePrimitiveCall -from ...common import MapOpType from ..common import init_analyzer, nki_dim from ..instruction_generator import InstructionGenerator from .utils import InstType, binary_map_ops, try_find_inst_nary @@ -32,7 +32,7 @@ def binary_trn( op: TilePrimitiveCall, binary_op: MapOpType, sctx: DispatchContext ) -> PrimFunc | None: """Generate a binary operation schedule for Trainium.""" - if not (sctx.is_trn() and sctx.scope_kind == "thread"): + if not (sctx.is_target("trn") and sctx.scope_kind == "thread"): fail("requires Trainium target and thread exec_scope") assert binary_op in binary_map_ops, f"Unsupported binary operation {binary_op}" diff --git a/python/tvm/tirx/operator/tile_primitive/trn/binary/utils.py b/python/tvm/backend/trn/operator/tile_primitive/binary/utils.py similarity index 97% rename from python/tvm/tirx/operator/tile_primitive/trn/binary/utils.py rename to python/tvm/backend/trn/operator/tile_primitive/binary/utils.py index 0f0c0e053f34..891c5e9c4875 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/binary/utils.py +++ b/python/tvm/backend/trn/operator/tile_primitive/binary/utils.py @@ -20,9 +20,10 @@ from enum import Enum from tvm.arith.analyzer import Analyzer +from tvm.backend.trn.layout import is_trainium_layout from tvm.tirx import BufferRegion, FloatImm +from tvm.tirx.operator.tile_primitive.common import MapOpType -from ...common import MapOpType from ..dim_utils import get_ewise_dim_map from ..instruction_generator import InstructionGenerator @@ -72,8 +73,8 @@ def try_find_inst_nary( valid_buffers = all( [ dst.layout and all(src.layout for src in srcs if src is not None), - dst.layout.is_trainium(), - all(src.layout.is_trainium() for src in srcs if src is not None), + is_trainium_layout(dst.layout), + all(is_trainium_layout(src.layout) for src in srcs if src is not None), dst.scope() == "trn.sbuf", all(src.scope() in ["trn.sbuf", "trn.psum"] for src in srcs if src is not None), ] diff --git a/python/tvm/tirx/operator/tile_primitive/trn/common.py b/python/tvm/backend/trn/operator/tile_primitive/common.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/common.py rename to python/tvm/backend/trn/operator/tile_primitive/common.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/__init__.py b/python/tvm/backend/trn/operator/tile_primitive/compose_op/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/compose_op/__init__.py rename to python/tvm/backend/trn/operator/tile_primitive/compose_op/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_chain.py b/python/tvm/backend/trn/operator/tile_primitive/compose_op/binary_chain.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_chain.py rename to python/tvm/backend/trn/operator/tile_primitive/compose_op/binary_chain.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_reduce.py b/python/tvm/backend/trn/operator/tile_primitive/compose_op/binary_reduce.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/compose_op/binary_reduce.py rename to python/tvm/backend/trn/operator/tile_primitive/compose_op/binary_reduce.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/compose_op.py b/python/tvm/backend/trn/operator/tile_primitive/compose_op/compose_op.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/compose_op/compose_op.py rename to python/tvm/backend/trn/operator/tile_primitive/compose_op/compose_op.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/reduce_negate.py b/python/tvm/backend/trn/operator/tile_primitive/compose_op/reduce_negate.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/compose_op/reduce_negate.py rename to python/tvm/backend/trn/operator/tile_primitive/compose_op/reduce_negate.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py b/python/tvm/backend/trn/operator/tile_primitive/compose_op/unary_reduce.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/compose_op/unary_reduce.py rename to python/tvm/backend/trn/operator/tile_primitive/compose_op/unary_reduce.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/utils.py b/python/tvm/backend/trn/operator/tile_primitive/compose_op/utils.py similarity index 95% rename from python/tvm/tirx/operator/tile_primitive/trn/compose_op/utils.py rename to python/tvm/backend/trn/operator/tile_primitive/compose_op/utils.py index 9fbaa524fd2e..1509482be8d0 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/compose_op/utils.py +++ b/python/tvm/backend/trn/operator/tile_primitive/compose_op/utils.py @@ -18,8 +18,7 @@ """Shared helpers for compose operator dispatches.""" from tvm.ir import Op - -from ...common import ReduceOpType +from tvm.tirx.operator.tile_primitive.common import ReduceOpType # Operation code mappings opcode_table = { diff --git a/python/tvm/tirx/operator/tile_primitive/trn/copy/__init__.py b/python/tvm/backend/trn/operator/tile_primitive/copy/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/copy/__init__.py rename to python/tvm/backend/trn/operator/tile_primitive/copy/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py b/python/tvm/backend/trn/operator/tile_primitive/copy/default.py similarity index 96% rename from python/tvm/tirx/operator/tile_primitive/trn/copy/default.py rename to python/tvm/backend/trn/operator/tile_primitive/copy/default.py index b1a0b2078681..123698b62164 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/copy/default.py +++ b/python/tvm/backend/trn/operator/tile_primitive/copy/default.py @@ -17,6 +17,7 @@ """Implementation of copy operator dispatchs.""" +from tvm.backend.trn.layout import is_trainium_layout from tvm.script import tirx as T from tvm.tirx import PrimFunc from tvm.tirx.operator.tile_primitive import ( @@ -86,7 +87,7 @@ def transpose_schedule( if "identity" not in op.workspace: assert sctx.alloc_only, ( - "Identity tensor must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 + "Identity tensor must be specified in workspace. Run tvm.tirx.trn.transform.TrnPrivateBufferAlloc first." # noqa: E501 ) identity_tensor = T.buffer( (p_size, rhs_f_size), src_region.buffer.dtype, scope="trn.sbuf", buffer_name="identity" @@ -143,7 +144,7 @@ def transpose_psum_output(): if "acc_psum" not in op.workspace: assert sctx.alloc_only, ( - "Accumulation psum buffer must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 + "Accumulation psum buffer must be specified in workspace. Run tvm.tirx.trn.transform.TrnPrivateBufferAlloc first." # noqa: E501 ) acc_psum = T.buffer( (max_psum_banks, p_size, largest_psum_per_bank), @@ -202,9 +203,9 @@ def copy_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: dst.scope() in ["global", "trn.sbuf", "trn.psum"], src.scope() != "global" or dst.scope() != "global", (src.scope() == "global" and isinstance(src.layout, T.TileLayout)) - or (src.scope() in ["trn.sbuf", "trn.psum"] and src.layout.is_trainium()), + or (src.scope() in ["trn.sbuf", "trn.psum"] and is_trainium_layout(src.layout)), (dst.scope() == "global" and isinstance(dst.layout, T.TileLayout)) - or (dst.scope() in ["trn.sbuf", "trn.psum"] and dst.layout.is_trainium()), + or (dst.scope() in ["trn.sbuf", "trn.psum"] and is_trainium_layout(dst.layout)), ] ) @@ -232,7 +233,7 @@ def copy_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: if not inst_gen.check_partition_dim_match(src_region, dst_region): return transpose_schedule(op, inst_gen, sctx) - if src.layout.is_trainium(): + if is_trainium_layout(src.layout): inst = inst_gen.find_max_inst_size_from_one_region(src_region) inst = inst_gen.fit_inst_tile_to_region(inst, dst_region) src_to_dst = True diff --git a/python/tvm/tirx/operator/tile_primitive/trn/dim_utils.py b/python/tvm/backend/trn/operator/tile_primitive/dim_utils.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/dim_utils.py rename to python/tvm/backend/trn/operator/tile_primitive/dim_utils.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/gemm/__init__.py b/python/tvm/backend/trn/operator/tile_primitive/gemm/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/gemm/__init__.py rename to python/tvm/backend/trn/operator/tile_primitive/gemm/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/gemm/default.py b/python/tvm/backend/trn/operator/tile_primitive/gemm/default.py similarity index 97% rename from python/tvm/tirx/operator/tile_primitive/trn/gemm/default.py rename to python/tvm/backend/trn/operator/tile_primitive/gemm/default.py index ca572ba781da..eba660ab6d6a 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/gemm/default.py +++ b/python/tvm/backend/trn/operator/tile_primitive/gemm/default.py @@ -21,6 +21,7 @@ import operator from tvm.arith.analyzer import Analyzer +from tvm.backend.trn.layout import is_trainium_layout from tvm.ir import assert_structural_equal from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimFunc @@ -110,7 +111,7 @@ def get_pf_dim_from_buffer_region( def matmul_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: """Schedule GEMM operation on Trainium.""" # Basic validation checks - if not (sctx.is_trn() and sctx.scope_kind == "thread"): + if not (sctx.is_target("trn") and sctx.scope_kind == "thread"): fail("requires Trainium target and thread exec_scope") # Extract arguments @@ -147,9 +148,9 @@ def matmul_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: A.dtype == B.dtype, A.scope() == "trn.sbuf" and B.scope() == "trn.sbuf", C.scope() == "trn.psum" or C.scope() == "trn.sbuf", - A.layout.is_trainium(), - B.layout.is_trainium(), - C.layout.is_trainium(), + is_trainium_layout(A.layout), + is_trainium_layout(B.layout), + is_trainium_layout(C.layout), A.layout.size("P") == B.layout.size("P"), ] ), "Invalid buffer layout and scope" @@ -252,7 +253,7 @@ def impl_C_psum(): acc_psum_shape = (max_psum_banks, p_size, largest_psum_per_bank) if "acc_psum" not in op.workspace: - assert sctx.alloc_only, "Accumulation psum buffer must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 + assert sctx.alloc_only, "Accumulation psum buffer must be specified in workspace. Run tvm.tirx.trn.transform.TrnPrivateBufferAlloc first." # noqa: E501 acc_psum = T.buffer( acc_psum_shape, "float32", diff --git a/python/tvm/tirx/operator/tile_primitive/trn/instruction_generator.py b/python/tvm/backend/trn/operator/tile_primitive/instruction_generator.py similarity index 98% rename from python/tvm/tirx/operator/tile_primitive/trn/instruction_generator.py rename to python/tvm/backend/trn/operator/tile_primitive/instruction_generator.py index 58163d4b148a..d501aab82616 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/instruction_generator.py +++ b/python/tvm/backend/trn/operator/tile_primitive/instruction_generator.py @@ -25,6 +25,7 @@ import tvm from tvm.arith.analyzer import Analyzer +from tvm.backend.trn.layout import is_trainium_layout from tvm.ir import Range from tvm.script import tirx as T from tvm.tirx import BufferRegion, PrimExpr, Var @@ -379,7 +380,7 @@ def fill_in_block_dim( for i in reversed(dims): for j in reversed(range(seps[i], seps[i + 1])): it = shards[j] - is_partition = it.axis.name == "P" if layout.is_trainium() else False + is_partition = it.axis.name == "P" if is_trainium_layout(layout) else False logical_iter_dims = bind_iters[i][j - seps[i]] for d in range(-1, len(logical_iter_dims)): next_logical_stride = ( @@ -479,7 +480,9 @@ def _get_inst_logical_iter_list( is_free_dim: bool = True, ) -> LogicalIterList: layout = self.split_layout_views[buffer_region] - assert layout.is_trainium(), " Cannot propagate instruction information from HBM tensor" + assert is_trainium_layout(layout), ( + " Cannot propagate instruction information from HBM tensor" + ) iters = layout.shard seps = self.seps[buffer_region] ret = [[[] for _ in range(seps[i], seps[i + 1])] for i in range(len(buffer_region.region))] @@ -633,7 +636,7 @@ def fit_inst_tile_to_region( mapped_data_iter = to_layout.shard[data_iter_map[i]] if inst_stride_from is None: inst_stride_from = data_iter.stride - if not to_layout.is_trainium() and mapped_data_iter.stride != 1: + if not is_trainium_layout(to_layout) and mapped_data_iter.stride != 1: # dma copy must be contiguous on hbm break inst_stride_to = mapped_data_iter.stride @@ -649,7 +652,7 @@ def check_partition_dim_match( dim_map = self.dim_mapper.get_dim_map(buffer_region_1, buffer_region_2) layout_1 = self.split_layout_views[buffer_region_1] layout_2 = self.split_layout_views[buffer_region_2] - if not layout_1.is_trainium() or not layout_2.is_trainium(): + if not is_trainium_layout(layout_1) or not is_trainium_layout(layout_2): return True seps_1 = self.seps[buffer_region_1] seps_2 = self.seps[buffer_region_2] diff --git a/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py b/python/tvm/backend/trn/operator/tile_primitive/private_alloc.py similarity index 96% rename from python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py rename to python/tvm/backend/trn/operator/tile_primitive/private_alloc.py index fe3f0a54bba1..830ab45f2d11 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/private_alloc.py +++ b/python/tvm/backend/trn/operator/tile_primitive/private_alloc.py @@ -17,6 +17,9 @@ from typing import Any +from tvm.backend.trn.operator.tile_primitive.common import init_analyzer, nki_dim +from tvm.backend.trn.operator.tile_primitive.dim_utils import get_ewise_dim_map +from tvm.backend.trn.operator.tile_primitive.instruction_generator import InstructionGenerator from tvm.script import tirx as T from tvm.tirx import Buffer, FloatImm, Stmt from tvm.tirx.operator.tile_primitive.dispatch_context import DispatchContext @@ -29,9 +32,6 @@ UnaryReduce, ) from tvm.tirx.operator.tile_primitive.registry import f_op_dispatcher -from tvm.tirx.operator.tile_primitive.trn.common import init_analyzer, nki_dim -from tvm.tirx.operator.tile_primitive.trn.dim_utils import get_ewise_dim_map -from tvm.tirx.operator.tile_primitive.trn.instruction_generator import InstructionGenerator from tvm.tirx.stmt import TilePrimitiveCall diff --git a/python/tvm/tirx/operator/tile_primitive/trn/reduction/__init__.py b/python/tvm/backend/trn/operator/tile_primitive/reduction/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/reduction/__init__.py rename to python/tvm/backend/trn/operator/tile_primitive/reduction/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/reduction/default.py b/python/tvm/backend/trn/operator/tile_primitive/reduction/default.py similarity index 94% rename from python/tvm/tirx/operator/tile_primitive/trn/reduction/default.py rename to python/tvm/backend/trn/operator/tile_primitive/reduction/default.py index f7a7b886d0f9..0a1d8b6cbbd9 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/reduction/default.py +++ b/python/tvm/backend/trn/operator/tile_primitive/reduction/default.py @@ -18,8 +18,8 @@ """Reduction dispatch variant registrations.""" from tvm.tirx.operator.tile_primitive import register_dispatch +from tvm.tirx.operator.tile_primitive.common import ReduceOpType -from ...common import ReduceOpType from .utils import reduction_trn for _op_name, _op_type in { diff --git a/python/tvm/tirx/operator/tile_primitive/trn/reduction/utils.py b/python/tvm/backend/trn/operator/tile_primitive/reduction/utils.py similarity index 95% rename from python/tvm/tirx/operator/tile_primitive/trn/reduction/utils.py rename to python/tvm/backend/trn/operator/tile_primitive/reduction/utils.py index 1d9840e5d674..43bd92d56d75 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/reduction/utils.py +++ b/python/tvm/backend/trn/operator/tile_primitive/reduction/utils.py @@ -17,12 +17,13 @@ """Shared helpers for reduction schedules.""" +from tvm.backend.trn.layout import is_trainium_layout from tvm.script import tirx as T from tvm.tirx import PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext, fail +from tvm.tirx.operator.tile_primitive.common import ReduceOpType from tvm.tirx.stmt import TilePrimitiveCall -from ...common import ReduceOpType from ..common import init_analyzer, nki_dim from ..dim_utils import get_reduction_dim_map from ..instruction_generator import InstructionGenerator @@ -46,7 +47,7 @@ def generate_intermediate_buffer( check_workspace_buffer(intermediate_buffer, intermediate_shape, "trn.sbuf") else: assert sctx.alloc_only, ( - "Partial reduce buffer must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 + "Partial reduce buffer must be specified in workspace. Run tvm.tirx.trn.transform.TrnPrivateBufferAlloc first." # noqa: E501 ) intermediate_buffer = T.buffer( intermediate_shape, @@ -73,7 +74,7 @@ def reduction_trn( Returns: Optional[PrimFunc]: The scheduled function, or None if not applicable. """ - if not (sctx.is_trn() and sctx.scope_kind == "thread"): + if not (sctx.is_target("trn") and sctx.scope_kind == "thread"): fail("requires Trainium target and thread exec_scope") dst_buffer_region, src_buffer_region, axes, accum = op.args[:4] @@ -93,8 +94,8 @@ def reduction_trn( src.layout and dst.layout, src.scope() == "trn.sbuf" or src.scope() == "trn.psum", dst.scope() == "trn.sbuf", - src.layout.is_trainium(), - dst.layout.is_trainium(), + is_trainium_layout(src.layout), + is_trainium_layout(dst.layout), src.layout.size("P") == dst.layout.size("P"), ] ), "Invalid layout" diff --git a/python/tvm/tirx/operator/tile_primitive/trn/select/__init__.py b/python/tvm/backend/trn/operator/tile_primitive/select/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/select/__init__.py rename to python/tvm/backend/trn/operator/tile_primitive/select/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/select/default.py b/python/tvm/backend/trn/operator/tile_primitive/select/default.py similarity index 97% rename from python/tvm/tirx/operator/tile_primitive/trn/select/default.py rename to python/tvm/backend/trn/operator/tile_primitive/select/default.py index 27136a3ac342..07486e6ae7ab 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/select/default.py +++ b/python/tvm/backend/trn/operator/tile_primitive/select/default.py @@ -17,6 +17,7 @@ """Implementation of select schedules.""" +from tvm.backend.trn.layout import is_trainium_layout from tvm.script import tirx as T from tvm.tirx import BufferRegion, FloatImm, PrimFunc, TilePrimitiveCall from tvm.tirx.operator.tile_primitive import ( @@ -63,8 +64,8 @@ def select_trn(op: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: buffer_conditions = [ dst.buffer.layout and true_value.buffer.layout, dst.buffer.scope() == "trn.sbuf" and true_value.buffer.scope() == "trn.sbuf", - true_value.buffer.layout.is_trainium(), - dst.buffer.layout.is_trainium(), + is_trainium_layout(true_value.buffer.layout), + is_trainium_layout(dst.buffer.layout), ] if not all(buffer_conditions): diff --git a/python/tvm/tirx/operator/tile_primitive/trn/unary/__init__.py b/python/tvm/backend/trn/operator/tile_primitive/unary/__init__.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/unary/__init__.py rename to python/tvm/backend/trn/operator/tile_primitive/unary/__init__.py diff --git a/python/tvm/tirx/operator/tile_primitive/trn/unary/default.py b/python/tvm/backend/trn/operator/tile_primitive/unary/default.py similarity index 96% rename from python/tvm/tirx/operator/tile_primitive/trn/unary/default.py rename to python/tvm/backend/trn/operator/tile_primitive/unary/default.py index e336daa717a3..69facea7f97b 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/unary/default.py +++ b/python/tvm/backend/trn/operator/tile_primitive/unary/default.py @@ -19,9 +19,9 @@ from tvm.tirx import FloatImm, PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext, fail +from tvm.tirx.operator.tile_primitive.common import MapOpType from tvm.tirx.stmt import TilePrimitiveCall -from ...common import MapOpType from ..common import init_analyzer from ..instruction_generator import InstructionGenerator from .utils import ( @@ -35,7 +35,7 @@ def unary_trn(op: TilePrimitiveCall, unary_op: MapOpType, sctx: DispatchContext) -> PrimFunc | None: """Schedule unary operation on Trainium.""" # Check execution environment - if not (sctx.is_trn() and sctx.scope_kind == "thread"): + if not (sctx.is_target("trn") and sctx.scope_kind == "thread"): fail("requires Trainium target and thread exec_scope") # Extract operation arguments diff --git a/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py b/python/tvm/backend/trn/operator/tile_primitive/unary/utils.py similarity index 96% rename from python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py rename to python/tvm/backend/trn/operator/tile_primitive/unary/utils.py index 7a757609b09a..dc58862bb4f9 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/unary/utils.py +++ b/python/tvm/backend/trn/operator/tile_primitive/unary/utils.py @@ -18,10 +18,11 @@ """Shared helpers, op tables, and validation functions for unary operator dispatches.""" from tvm.arith.analyzer import Analyzer +from tvm.backend.trn.layout import is_trainium_layout from tvm.script import tirx as T from tvm.tirx import BufferRegion, FloatImm +from tvm.tirx.operator.tile_primitive.common import MapOpType -from ...common import MapOpType from ..common import nki_dim from ..dim_utils import get_ewise_dim_map from ..instruction_generator import InstructionGenerator @@ -56,8 +57,8 @@ def try_find_inst_unary( src.layout and dst.layout, src.scope() in ("trn.sbuf", "trn.psum"), dst.scope() == "trn.sbuf", - src.layout.is_trainium(), - dst.layout.is_trainium(), + is_trainium_layout(src.layout), + is_trainium_layout(dst.layout), ] ) @@ -97,7 +98,7 @@ def get_const_bias_tensor(bias, shape, dtype, workspace, sctx): """Create or retrieve a constant bias tensor.""" if "const_bias" not in workspace: assert sctx.alloc_only, ( - "Constant bias tensor must be specified in workspace. Run tvm.tirx.transform.trn.TrnPrivateBufferAlloc first." # noqa: E501 + "Constant bias tensor must be specified in workspace. Run tvm.tirx.trn.transform.TrnPrivateBufferAlloc first." # noqa: E501 ) # Create new bias buffer bias_buffer = T.buffer(shape, dtype, scope="trn.sbuf", buffer_name="const_bias") diff --git a/python/tvm/tirx/operator/tile_primitive/trn/unary/with_bias_scale.py b/python/tvm/backend/trn/operator/tile_primitive/unary/with_bias_scale.py similarity index 96% rename from python/tvm/tirx/operator/tile_primitive/trn/unary/with_bias_scale.py rename to python/tvm/backend/trn/operator/tile_primitive/unary/with_bias_scale.py index 399d8cfa6d11..26fb5670270a 100644 --- a/python/tvm/tirx/operator/tile_primitive/trn/unary/with_bias_scale.py +++ b/python/tvm/backend/trn/operator/tile_primitive/unary/with_bias_scale.py @@ -19,9 +19,9 @@ from tvm.tirx import BufferRegion, PrimFunc from tvm.tirx.operator.tile_primitive import DispatchContext, fail +from tvm.tirx.operator.tile_primitive.common import MapOpType from tvm.tirx.stmt import TilePrimitiveCall -from ...common import MapOpType from ..binary import try_find_inst_nary from ..common import init_analyzer from ..instruction_generator import InstructionGenerator @@ -33,7 +33,7 @@ def unary_with_bias_scale_trn( ) -> PrimFunc | None: """Schedule unary operation with bias and scale on Trainium.""" # Check execution environment - if not (sctx.is_trn() and sctx.scope_kind == "thread"): + if not (sctx.is_target("trn") and sctx.scope_kind == "thread"): fail("requires Trainium target and thread exec_scope") # Extract operation arguments with defaults diff --git a/python/tvm/tirx/operator/tile_primitive/trn/workspace_utils.py b/python/tvm/backend/trn/operator/tile_primitive/workspace_utils.py similarity index 100% rename from python/tvm/tirx/operator/tile_primitive/trn/workspace_utils.py rename to python/tvm/backend/trn/operator/tile_primitive/workspace_utils.py diff --git a/python/tvm/backend/trn/pipeline.py b/python/tvm/backend/trn/pipeline.py new file mode 100644 index 000000000000..d0663218e651 --- /dev/null +++ b/python/tvm/backend/trn/pipeline.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Trainium TIRX pipeline entrypoints.""" + +import tvm +from tvm import tirx +from tvm.tirx.compilation_pipeline import finalize_host_passes + +from . import transform as trn_transform + + +def trn_pipeline(): + """The Trainium pipeline used in tvm.tirx.build.""" + + @tvm.transform.module_pass(opt_level=0) + def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: + """Lower TIRx for the Trainium backend.""" + passes = [ + trn_transform.TrnPrivateBufferAlloc(), + trn_transform.TrnNaiveAllocator(), + tirx.transform.TilePrimitiveDispatch(), + trn_transform.LowerTrainiumLayout(), + tvm.s_tir.transform.DecorateDeviceScope(), + tirx.transform.StmtSimplify(), + tirx.transform.LowerTIRxOpaque(), + tvm.s_tir.transform.LoopPartition(), + tvm.s_tir.transform.HoistIfThenElse(), + tirx.transform.StmtSimplify(), + tirx.transform.RemoveNoOp(), + tirx.transform.AnnotateEntryFunc(), + tirx.transform.SplitHostDevice(), + tirx.transform.MakePackedAPI(), + ] + return tvm.ir.transform.Sequential(passes)(mod) + + return _pipeline, finalize_host_passes, finalize_device_passes_trn + + +def finalize_device_passes_trn(): # pylint: disable=unused-argument + """The finalization passes for the Trainium backend.""" + return tvm.ir.transform.Sequential([tirx.transform.StmtSimplify()]) + + +__all__ = ["finalize_device_passes_trn", "trn_pipeline"] diff --git a/python/tvm/backend/trn/script.py b/python/tvm/backend/trn/script.py new file mode 100644 index 000000000000..5d507f462bd0 --- /dev/null +++ b/python/tvm/backend/trn/script.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Trainium TVMScript namespaces.""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from . import op as _trn_op + +OpWrapper = Callable[[Callable[..., Any]], Callable[..., Any]] + + +def _default_op_wrapper() -> OpWrapper: + from tvm.tirx.script.builder.ir import _op_wrapper # pylint: disable=import-outside-toplevel + + return _op_wrapper + + +class NKINamespace: + """The NKI instructions submodule.""" + + def __init__(self, op_wrapper: OpWrapper | None = None): + wrap = op_wrapper or _default_op_wrapper() + self.load = wrap(_trn_op.nki_load) + self.store = wrap(_trn_op.nki_store) + self.tensor_copy = wrap(_trn_op.nki_tensor_copy) + self.matmul = wrap(_trn_op.nki_matmul) + self.activation = wrap(_trn_op.nki_activation) + self.activation_reduce = wrap(_trn_op.nki_activation_reduce) + self.reciprocal = wrap(_trn_op.nki_reciprocal) + self.tensorreduce = wrap(_trn_op.nki_tensorreduce) + self.tensortensor = wrap(_trn_op.nki_tensortensor) + self.tensorscalar = wrap(_trn_op.nki_tensorscalar) + self.tensorscalar_reduce = wrap(_trn_op.nki_tensorscalar_reduce) + self.scalar_tensor_tensor = wrap(_trn_op.nki_scalar_tensor_tensor) + self.scalar_tensor_scalar = wrap(_trn_op.nki_scalar_tensor_scalar) + self.memset = wrap(_trn_op.nki_memset) + self.identity = wrap(_trn_op.nki_identity) + self.affine_select = wrap(_trn_op.nki_affine_select) + + +__all__ = ["NKINamespace", "OpWrapper"] diff --git a/python/tvm/backend/trn/target_tags.py b/python/tvm/backend/trn/target_tags.py new file mode 100644 index 000000000000..54c63d722600 --- /dev/null +++ b/python/tvm/backend/trn/target_tags.py @@ -0,0 +1,36 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""AWS Trainium target tags.""" + +from tvm.target import register_tag + + +def _register_aws_trn1_tag(name, cores): + register_tag( + name, + { + "kind": "trn", + "num-cores": cores, + "partition_size": 128, + "max_sbuf_size_per_partition": 196608, + "max_psum_size_per_partition": 16384, + }, + ) + + +_register_aws_trn1_tag("aws/trn1/trn1.2xlarge", 2) +_register_aws_trn1_tag("aws/trn1/trn1.32xlarge", 32) diff --git a/python/tvm/tirx/transform/trn/__init__.py b/python/tvm/backend/trn/transform/__init__.py similarity index 68% rename from python/tvm/tirx/transform/trn/__init__.py rename to python/tvm/backend/trn/transform/__init__.py index 0aaf3062c8f3..84dac89cec8b 100644 --- a/python/tvm/tirx/transform/trn/__init__.py +++ b/python/tvm/backend/trn/transform/__init__.py @@ -1,4 +1,3 @@ -# isort: skip_file # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -18,17 +17,30 @@ """Trainium-specific TIRX transformations.""" # pylint: disable=invalid-name -# Fork-only TIRX-specific passes. They decorate their pass body with -# `@prim_func_pass(...)` at module-load time, which triggers an FFI call to -# construct PassInfo -- not runtime-safe. Loading them lazily preserves -# apache's discipline that `import tvm.tirx.transform.trn` performs no -# compiler-side FFI calls (required for `TVM_USE_RUNTIME_LIB=1`). +from tvm_ffi import get_global_func + +import tvm +from tvm import tirx + _LAZY_TRANSFORMS = { "TrnNaiveAllocator": ".naive_allocator", "TrnPrivateBufferAlloc": ".private_buffer_alloc", } +def LowerTrainiumLayout(): + """Lower Trainium layouts to backend physical buffer shapes and indices.""" + return get_global_func("tirx.backend.trn.transform.LowerTrainiumLayout")() + + +def LowerTIRx(): + """Lower TIRx tile primitive calls for the Trainium backend.""" + return tvm.ir.transform.Sequential( + [tirx.transform.TilePrimitiveDispatch(), LowerTrainiumLayout()], + name="tirx.backend.trn.LowerTIRx", + ) + + def __getattr__(name): target = _LAZY_TRANSFORMS.get(name) if target is None: @@ -36,3 +48,6 @@ def __getattr__(name): from importlib import import_module # pylint: disable=import-outside-toplevel return getattr(import_module(target, __name__), name) + + +__all__ = ["LowerTIRx", "LowerTrainiumLayout", "TrnNaiveAllocator", "TrnPrivateBufferAlloc"] diff --git a/python/tvm/tirx/transform/trn/naive_allocator.py b/python/tvm/backend/trn/transform/naive_allocator.py similarity index 98% rename from python/tvm/tirx/transform/trn/naive_allocator.py rename to python/tvm/backend/trn/transform/naive_allocator.py index 1720a32d6938..ccab7add398c 100644 --- a/python/tvm/tirx/transform/trn/naive_allocator.py +++ b/python/tvm/backend/trn/transform/naive_allocator.py @@ -21,10 +21,9 @@ from tvm.tirx import AllocBuffer, IntImm from tvm.tirx.buffer import Buffer from tvm.tirx.stmt_functor import StmtVisitor +from tvm.tirx.transform.common import BufferReplacer from tvm.tirx.transform.function_pass import prim_func_pass -from ..common import BufferReplacer - def is_const_shape(buffer: Buffer) -> bool: for i in buffer.shape: diff --git a/python/tvm/tirx/transform/trn/private_buffer_alloc.py b/python/tvm/backend/trn/transform/private_buffer_alloc.py similarity index 100% rename from python/tvm/tirx/transform/trn/private_buffer_alloc.py rename to python/tvm/backend/trn/transform/private_buffer_alloc.py diff --git a/python/tvm/backend/vulkan/__init__.py b/python/tvm/backend/vulkan/__init__.py new file mode 100644 index 000000000000..343875aa8a02 --- /dev/null +++ b/python/tvm/backend/vulkan/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Vulkan-owned backend hooks.""" + + +def register_backend(): + """Register Vulkan-owned Python semantics.""" + return None + + +__all__ = ["register_backend"] diff --git a/python/tvm/backend/webgpu/__init__.py b/python/tvm/backend/webgpu/__init__.py new file mode 100644 index 000000000000..cb682a56d228 --- /dev/null +++ b/python/tvm/backend/webgpu/__init__.py @@ -0,0 +1,25 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""WebGPU-owned backend hooks.""" + + +def register_backend(): + """Register WebGPU-owned Python semantics.""" + return None + + +__all__ = ["register_backend"] diff --git a/python/tvm/s_tir/tensor_intrin/cuda.py b/python/tvm/s_tir/tensor_intrin/cuda.py index 0e2047af327f..fcd88fdae37e 100644 --- a/python/tvm/s_tir/tensor_intrin/cuda.py +++ b/python/tvm/s_tir/tensor_intrin/cuda.py @@ -580,7 +580,7 @@ def mma_fill_impl(a: T.handle) -> None: for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): T.evaluate( - T.mma_fill_legacy(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype) + T.cuda.mma_fill_legacy(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype) ) return mma_fill_desc, mma_fill_impl @@ -637,7 +637,7 @@ def mma_store_impl(a: T.handle, c: T.handle) -> None: for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"): T.evaluate( - T.mma_store_legacy( + T.cuda.mma_store_legacy( M_DIM, N_DIM, C.access_ptr("w"), diff --git a/python/tvm/s_tir/tensor_intrin/metal.py b/python/tvm/s_tir/tensor_intrin/metal.py index a789581d4b0e..894aeea65615 100644 --- a/python/tvm/s_tir/tensor_intrin/metal.py +++ b/python/tvm/s_tir/tensor_intrin/metal.py @@ -60,7 +60,7 @@ def impl(a: T.handle) -> None: with T.sblock("root"): T.reads() T.writes(A[0:col, 0:row]) - T.make_filled_simdgroup_matrix( + T.metal.make_filled_simdgroup_matrix( A.data, index=get_simdgroup_index(A, d1, col, row), value=T.float32(0), @@ -122,7 +122,7 @@ def impl(a: T.handle, c: T.handle) -> None: with T.sblock("root"): T.reads(A[0:col, 0:row]) T.writes(C[0:col, 0:row]) - T.simdgroup_load( + T.metal.simdgroup_load( C.data, index=get_simdgroup_index(C, d1, col, row), ptr=A.access_ptr("r"), @@ -179,7 +179,7 @@ def impl(a: T.handle, c: T.handle) -> None: with T.sblock("root"): T.reads(A[0:col, 0:row]) T.writes(C[0:col, 0:row]) - T.simdgroup_store( + T.metal.simdgroup_store( A.data, index=get_simdgroup_index(A, s1, col, row), ptr=C.access_ptr("w"), @@ -223,7 +223,7 @@ def impl(a: T.handle, b: T.handle, c: T.handle) -> None: with T.sblock("root"): T.reads(C[0:m_dim, 0:n_dim], A[0:m_dim, 0:k_dim], B[0:k_dim, 0:n_dim]) T.writes(C[0:m_dim, 0:n_dim]) - T.simdgroup_multiply_accumulate( + T.metal.simdgroup_multiply_accumulate( C.data, get_simdgroup_index(C, c1, m_dim, n_dim), A.data, diff --git a/python/tvm/target/tag_registry/__init__.py b/python/tvm/target/tag_registry/__init__.py index 5ace43847e8f..114edc448c27 100644 --- a/python/tvm/target/tag_registry/__init__.py +++ b/python/tvm/target/tag_registry/__init__.py @@ -21,13 +21,9 @@ """ from . import registry -from . import cuda from . import arm_cpu from . import riscv_cpu from . import aws_cpu -from . import metal -from . import hexagon -from . import adreno # Validate all tags at import time registry.list_tags() diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py index 4378a9dfbe6c..5e8a2184bdad 100644 --- a/python/tvm/tirx/__init__.py +++ b/python/tvm/tirx/__init__.py @@ -64,16 +64,6 @@ tvm_bmma_sync, tvm_fill_fragment, ) -from .op import ptx_mma, ptx_mma_sp, mma_store, mma_fill -from .op import ptx_mma_legacy, ptx_mma_sp_legacy, mma_store_legacy, mma_fill_legacy -from .op import ptx_ldmatrix, ptx_cp_async, ptx_cp_async_bulk, ptx_cp_async_bulk_shared_to_cluster -from .op import ptx_ldmatrix_legacy, ptx_cp_async_legacy -from .op import ( - make_filled_simdgroup_matrix, - simdgroup_load, - simdgroup_multiply_accumulate, - simdgroup_store, -) from .op import vectorlow, vectorhigh, vectorcombine from .op import infinity, reinterpret from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz @@ -113,16 +103,11 @@ from tvm.base import _RUNTIME_ONLY as _RUNTIME_ONLY_TIRX # pylint: disable=wrong-import-position if not _RUNTIME_ONLY_TIRX: - # CUDA codegen registration. Each family module registers codegen via - # @register_codegen (hand-written ops) and ptx_intrinsic / - # cuda_helper_intrinsic (schema-declared ops); the schema declarations - # also inject Python wrappers into `tvm.tirx.op`. Must come before - # anything downstream that looks up wrappers or the codegen registry. - from .operator.intrinsics import cuda as _intrinsics_cuda from .build import build from .compilation_pipeline import ( get_tir_pipeline, get_default_tir_pipeline, + register_tir_pipeline, ) import tvm.script diff --git a/python/tvm/tirx/backend/__init__.py b/python/tvm/tirx/backend/__init__.py index 862bed83d291..6f5333629f51 100644 --- a/python/tvm/tirx/backend/__init__.py +++ b/python/tvm/tirx/backend/__init__.py @@ -14,6 +14,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""The TIR Adreno backend passes""" - -from . import adreno +"""TIRx backend compatibility package.""" diff --git a/python/tvm/tirx/bench.py b/python/tvm/tirx/bench.py index d12ff2e3d04d..aa0676375905 100644 --- a/python/tvm/tirx/bench.py +++ b/python/tvm/tirx/bench.py @@ -613,7 +613,7 @@ def _leader(self, leader: None | tvm.tirx.PrimExpr | bool): @T.inline def init(self, group_id: tvm.tirx.PrimExpr): if self.profiler_enabled: - T.timer_init_cuda( + T.cuda.timer_init( self.buffer.data, self.profiler_tag.data, self.profiler_write_offset.data, @@ -624,7 +624,7 @@ def init(self, group_id: tvm.tirx.PrimExpr): @T.inline def start(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None): if self.profiler_enabled: - T.timer_start_cuda( + T.cuda.timer_start( event_type, self.buffer.data, self.profiler_tag.data, @@ -636,7 +636,7 @@ def start(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None @T.inline def end(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None): if self.profiler_enabled: - T.timer_end_cuda( + T.cuda.timer_end( event_type, self.buffer.data, self.profiler_tag.data, @@ -648,7 +648,7 @@ def end(self, event_type: Enum, leader: None | tvm.tirx.PrimExpr | bool = None): @T.inline def finalize(self, leader: None | tvm.tirx.PrimExpr | bool = None): if self.profiler_enabled: - T.timer_finalize_cuda( + T.cuda.timer_finalize( self.buffer.data, self.profiler_tag.data, self.profiler_write_offset.data, diff --git a/python/tvm/tirx/compilation_pipeline.py b/python/tvm/tirx/compilation_pipeline.py index f79af3493f28..d2847332b4a7 100644 --- a/python/tvm/tirx/compilation_pipeline.py +++ b/python/tvm/tirx/compilation_pipeline.py @@ -99,33 +99,6 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I return _pipeline, finalize_host_passes, finalize_device_passes -def trn_pipeline(): - """The Trainium pipeline used in tvm.tirx.build""" - - @tvm.transform.module_pass(opt_level=0) - def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule: - """The default lowering passes for TRN backend.""" - tvm.transform.PassContext.current() - passes = [ - tirx.transform.trn.TrnPrivateBufferAlloc(), - tirx.transform.trn.TrnNaiveAllocator(), - tirx.transform.LowerTIRx(), - tvm.s_tir.transform.DecorateDeviceScope(), - tirx.transform.StmtSimplify(), - tirx.transform.LowerTIRxOpaque(), - tvm.s_tir.transform.LoopPartition(), - tvm.s_tir.transform.HoistIfThenElse(), - tirx.transform.StmtSimplify(), - tirx.transform.RemoveNoOp(), - tirx.transform.AnnotateEntryFunc(), - tirx.transform.SplitHostDevice(), - tirx.transform.MakePackedAPI(), - ] - return tvm.ir.transform.Sequential(passes)(mod) - - return _pipeline, finalize_host_passes, finalize_device_passes_trn - - def finalize_host_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" host_pass_list = [ @@ -153,14 +126,14 @@ def finalize_device_passes_tirx(): # pylint: disable=unused-argument return tvm.ir.transform.Sequential(device_pass_list) -def finalize_device_passes_trn(): # pylint: disable=unused-argument - """The default finalization passes for TRN backend.""" - device_pass_list = [tirx.transform.StmtSimplify()] - return tvm.ir.transform.Sequential(device_pass_list) +# global map of pre-built pipelines +PIPELINE_MAP = {"default": default_tir_pipeline, "tirx": tirx_pipeline} -# global map of pre-built pipelines -PIPELINE_MAP = {"default": default_tir_pipeline, "tirx": tirx_pipeline, "trn": trn_pipeline} +def register_tir_pipeline(name: str, pipeline_factory) -> None: + """Register a named TIR pipeline factory.""" + + PIPELINE_MAP[name] = pipeline_factory def get_tir_pipeline(name: str | None = None, **kwargs) -> tvm.transform.Pass: diff --git a/python/tvm/tirx/lang/alloc_pool.py b/python/tvm/tirx/lang/alloc_pool.py index 48bb9929c618..b45bcef60cfb 100644 --- a/python/tvm/tirx/lang/alloc_pool.py +++ b/python/tvm/tirx/lang/alloc_pool.py @@ -2,528 +2,7 @@ # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""SMEM and TMEM bump-allocator pools for TIRX kernels.""" +# to you under the Apache License, Version 2.0. +"""Compatibility redirect for CUDA allocation pool helpers.""" -from __future__ import annotations - -import functools -import operator - -from tvm import DataType -from tvm.tirx.layout import S, TCol, TileLayout, TLane - -# --------------------------------------------------------------------------- -# ir_builder helpers — imported lazily to avoid circular deps at module level -# --------------------------------------------------------------------------- - -_ir = None - - -def _get_ir(): - global _ir - if _ir is None: - from tvm.tirx.script.builder import ir as _mod - - _ir = _mod - return _ir - - -def _get_frame(): - from tvm.tirx.script.builder import frame - - return frame - - -# --------------------------------------------------------------------------- -# Shared utilities -# --------------------------------------------------------------------------- - -_POOL_UNSET = object() - - -def _default_tmem_layout(rows, cols): - return TileLayout(S[(rows, cols) : (1 @ TLane, 1 @ TCol)]) - - -def _emit_stmt(expr): - ir = _get_ir() - ir.add_to_parent(ir.evaluate(expr)) - - -def _shape_product(shape): - return functools.reduce(operator.mul, shape, 1) - - -def _auto_swizzle_mode(dtype): - """Select the default MMA swizzle mode for a shared-memory allocation.""" - from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode - - del dtype - return SwizzleMode.SWIZZLE_128B_ATOM - - -def _swizzle_atom_bytes(swizzle_mode): - """Return the row width (in bytes) of one swizzle atom for *swizzle_mode*.""" - from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode - - return { - SwizzleMode.SWIZZLE_NONE: 0, - SwizzleMode.SWIZZLE_32B_ATOM: 32, - SwizzleMode.SWIZZLE_64B_ATOM: 64, - SwizzleMode.SWIZZLE_128B_ATOM: 128, - }[swizzle_mode] - - -def _suggest_swizzle_for_row_bytes(row_bytes): - """Pick the largest valid swizzle mode whose atom row fits within *row_bytes*.""" - - for atom_bytes, mode in ( - (128, "SWIZZLE_128B_ATOM"), - (64, "SWIZZLE_64B_ATOM"), - (32, "SWIZZLE_32B_ATOM"), - ): - if row_bytes >= atom_bytes and row_bytes % atom_bytes == 0: - return mode - return "SWIZZLE_NONE" - - -def _validate_mma_alloc_shape(shape, dtype, swizzle_mode): - """Validate that *shape* / *dtype* / *swizzle_mode* are mutually compatible. - - ``mma_shared_layout`` tiles a swizzle atom of shape ``[8, swizzle_bytes / dtype_bytes]`` - over the last two logical dimensions of *shape*. If the row width or row count of - the request is smaller than (or not a multiple of) the atom, the underlying - ``Layout.tile_to`` lowers to a ``floordiv``/``floormod`` by zero and raises an - opaque internal "Divide by zero" diagnostic from ``tile_tile_ops.cc``. Catch the - misconfiguration here so callers see *what* is wrong and *how* to fix it. - - Validation skipped when *swizzle_mode* is ``SWIZZLE_NONE`` (no atom). - """ - from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode - - if swizzle_mode == SwizzleMode.SWIZZLE_NONE: - return - - if len(shape) < 2: - raise ValueError( - f"alloc_mma shape={tuple(shape)} has fewer than 2 dimensions; " - f"swizzled MMA layouts tile over the last two dims (rows, cols). " - f"Use swizzle_mode='none' for 1-D allocations." - ) - - # Only validate concrete int dims; symbolic dims fall through (the analyzer - # in C++ will still ICHECK on them, but at least we don't false-positive). - rows = shape[-2] - cols = shape[-1] - if not (isinstance(rows, int) and isinstance(cols, int)): - return - - dtype_bytes = DataType(dtype).bits // 8 - if dtype_bytes == 0: - # Sub-byte dtype (e.g. float4); ``cols`` is already in element units, so - # use a fractional check expressed via bits. - col_bits = cols * DataType(dtype).bits - atom_bits = _swizzle_atom_bytes(swizzle_mode) * 8 - if col_bits < atom_bits or col_bits % atom_bits != 0: - row_bytes = col_bits // 8 if col_bits % 8 == 0 else col_bits / 8 - atom_bytes = _swizzle_atom_bytes(swizzle_mode) - suggestion = _suggest_swizzle_for_row_bytes(col_bits // 8 if col_bits >= 8 else 0) - raise ValueError( - f"alloc_mma shape={tuple(shape)} with dtype={dtype!r} produces " - f"{row_bytes}B rows, which is incompatible with the {atom_bytes}B " - f"swizzle atom selected by {swizzle_mode.name}. " - f"Use swizzle_mode=SwizzleMode.{suggestion}, or widen shape[-1] " - f"to a multiple of " - f"{(atom_bits + DataType(dtype).bits - 1) // DataType(dtype).bits} elements." - ) - else: - row_bytes = cols * dtype_bytes - atom_bytes = _swizzle_atom_bytes(swizzle_mode) - if row_bytes < atom_bytes or row_bytes % atom_bytes != 0: - suggestion = _suggest_swizzle_for_row_bytes(row_bytes) - min_cols = atom_bytes // dtype_bytes - raise ValueError( - f"alloc_mma shape={tuple(shape)} with dtype={dtype!r} produces " - f"{row_bytes}B rows, which is incompatible with the {atom_bytes}B " - f"swizzle atom selected by {swizzle_mode.name}. " - f"Use swizzle_mode=SwizzleMode.{suggestion}, or widen shape[-1] " - f"to a multiple of {min_cols} elements (>= {atom_bytes}B at {dtype})." - ) - - # Atom rows is always 8 (see ``mma_atom_shape`` in tma_utils.py). - atom_rows = 8 - if rows < atom_rows or rows % atom_rows != 0: - raise ValueError( - f"alloc_mma shape={tuple(shape)} has shape[-2]={rows}, but the " - f"{swizzle_mode.name} atom requires shape[-2] to be a positive " - f"multiple of {atom_rows}. Use swizzle_mode='none', or widen shape[-2] " - f"to a multiple of {atom_rows}." - ) - - -# --------------------------------------------------------------------------- -# TMEMStages -# --------------------------------------------------------------------------- - - -def _meta_class(cls): - """Apply @meta_class decorator from ir_builder.""" - return _get_ir().meta_class(cls) - - -@_meta_class -class TMEMStages: - """Parse-time staged view over a TMEM buffer. - - Parameters - ---------- - buf : Buffer - The underlying TMEM buffer (e.g. f32 or f16 view). - col_start : int - First column of stage 0 in *buf*'s column space. - width : int - Number of columns per stage. - stages : int - Number of pipeline stages (default 1). - stride : int or None - Column distance between consecutive stages. When *None* (default), - equals *width* (stages are packed back-to-back). - """ - - def __init__(self, buf, col_start, width, stages=1, stride=None): - self.buf = buf - self.col_start = col_start - self.width = width - self.stages = stages - self.stride = width if stride is None else stride - - def _stage_base(self, stage): - return self.col_start + stage * self.stride - - def __getitem__(self, item): - if isinstance(item, tuple): - assert len(item) == 2, "TMEMStages expects region[stage] or region[stage, start:stop]" - stage, col_slice = item - assert isinstance(col_slice, slice), "TMEMStages tuple indexing requires a slice" - base = self._stage_base(stage) - start = 0 if col_slice.start is None else col_slice.start - stop = self.width if col_slice.stop is None else col_slice.stop - return self.buf[:, base + start : base + stop : col_slice.step] - base = self._stage_base(item) - return self.buf[:, base : base + self.width] - - -# --------------------------------------------------------------------------- -# TMEMPool -# --------------------------------------------------------------------------- - - -@_meta_class -class TMEMPool: - """Bump allocator over TMEM columns.""" - - def __init__( - self, - pool, - total_cols=512, - *, - cta_group=1, - alloc_warp=0, - dealloc_warp=None, - tmem_addr=None, - sync_after_alloc=True, - ): - # tcgen05 alloc/dealloc are warp-uniform PTX instructions: every lane - # in the chosen warp must participate, and exactly one warp in the - # CTA must execute them. The pool emits its own - # ``if warp_id() == target_warp: tcgen05.alloc(...)`` - # guard, using the cta->warp scope id ``T.warp_id()``. - # NOTE: synccheck currently false-deadlocks on kernels that declare a - # second warp-scope id (cpusim binds only one warp var); the generated - # CUDA is equivalent to ``thread_rank() // 32 == target_warp``. - self.pool = pool - self.total_cols = total_cols - self.cta_group = cta_group - self.alloc_warp = alloc_warp - self.dealloc_warp = alloc_warp if dealloc_warp is None else dealloc_warp - self.sync_after_alloc = sync_after_alloc - self.offset = 0 - self.max_offset = 0 - self._committed = False - self._deallocated = False - self._addr_buf = pool.alloc([1], "uint32", align=4) if tmem_addr is None else tmem_addr - - def _addr_slot(self): - try: - return self._addr_buf[0] - except TypeError: - return self._addr_buf - - @property - def addr(self): - return self._addr_slot() - - def _emit_warp_guard(self, target_warp, emit): - from tvm.script import tirx as T - - warp_id = T.warp_id() - with T.If(warp_id == target_warp): - with T.Then(): - emit() - - def _resolve_cols(self, shape, dtype, cols, layout=None): - if cols is not None: - return cols - bits = DataType(dtype).bits - if layout is not None: - # span("TCol") is in *element* (buffer dtype) units; one TMEM cell - # holds 32 bits regardless of the element type. - tcol_elems = int(layout.span("TCol")) - tcol_bits = tcol_elems * bits - assert tcol_bits % 32 == 0, ( - f"layout TCol span={tcol_elems} elems x {bits}b is not 32-bit aligned" - ) - return tcol_bits // 32 - assert len(shape) == 2, "TMEMPool.alloc() requires cols= for non-2D TMEM buffers" - total_bits = _shape_product(shape) * bits - rows = shape[0] - assert total_bits % (32 * rows) == 0, ( - f"Cannot infer TMEM columns from shape={shape}, dtype={dtype!r}; " - "please pass cols= explicitly" - ) - return total_bits // (32 * rows) - - def alloc(self, shape, dtype="float32", *, layout=None, cols=None, datapath=None): - """Allocate a TMEM buffer. - - Parameters - ---------- - shape, dtype, cols - Standard buffer shape / dtype / column count. - layout - Explicit ``TileLayout``. Mutually exclusive with ``datapath``. - datapath : str | None - Optional tcgen05 datapath letter (``"D"`` for M=128 full datapath, - ``"F"`` for M=64 non-``.ws`` scattered). When provided, the buffer's - layout is derived from ``tmem_datapath_layout(datapath, *shape)`` - so the row index reflects the *physical* TMEM lane occupation - (PTX ISA §9.7.16.10.5). The downstream ``.16x*b`` / ``.32x32b`` - dispatches structurally check this layout to catch mismatched - atoms (e.g. a ``.16x*b`` M=128 read against a Layout F buffer). - Defaults to ``None``, which means Layout D's identity row→lane - mapping — keep this for shape ``(128, X)`` buffers that hold - an M=128 MMA accumulator. - """ - from tvm.tirx.layout import tmem_datapath_layout - - if layout is not None and datapath is not None: - raise ValueError("TMEMPool.alloc: pass at most one of layout= and datapath=") - if datapath is not None: - assert len(shape) == 2, "TMEMPool.alloc: datapath= requires a 2-D shape" - layout = tmem_datapath_layout(datapath, shape[0], shape[1]) - - ir = _get_ir() - cols = self._resolve_cols(shape, dtype, cols, layout) - col_start = self.offset - col_end = col_start + cols - assert col_end <= self.total_cols, f"TMEM overflow: {col_end} > {self.total_cols}" - if layout is None: - assert len(shape) == 2, "TMEMPool.alloc() requires layout= for non-2D TMEM buffers" - layout = _default_tmem_layout(shape[0], shape[1]) - res = ir.decl_buffer(shape, dtype, scope="tmem", allocated_addr=col_start, layout=layout) - self.offset = col_end - self.max_offset = max(self.max_offset, self.offset) - return res - - def alloc_sf(self, shape, dtype, *, sf_per_mma, sf_reuse=1): - """Allocate a tcgen05 block-scaled SF TMEM buffer with an inferred layout. - - ``shape`` last two dims are ``(rows, SF_K * sf_reuse)`` (the last dim is - what gemm dispatch iterates over). When ``shape`` has 3 dims, the first - is treated as a pipe-depth outer. - """ - from tvm.tirx.operator.tile_primitive.cuda.gemm_async.tcgen05 import sf_tmem_layout - - if len(shape) == 2: - pipe_depth, rows, last = None, shape[0], shape[1] - elif len(shape) == 3: - pipe_depth, rows, last = shape[0], shape[1], shape[2] - else: - raise ValueError( - f"alloc_sf expects 2D (rows, SF_K*sf_reuse) or 3D " - f"(pipe_depth, rows, SF_K*sf_reuse); got shape={shape}" - ) - assert last % sf_reuse == 0, ( - f"alloc_sf: shape last dim {last} must be divisible by sf_reuse={sf_reuse}" - ) - SF_K = last // sf_reuse - layout = sf_tmem_layout( - rows=rows, SF_K=SF_K, sf_per_mma=sf_per_mma, sf_reuse=sf_reuse, pipe_depth=pipe_depth - ) - return self.alloc(shape, dtype, layout=layout) - - def move_base_to(self, col): - self.offset = col - self.max_offset = max(self.max_offset, self.offset) - - def commit(self): - assert not self._committed, "TMEMPool.commit() can only be called once" - from tvm.script import tirx as T - - def emit_alloc(): - _emit_stmt( - T.ptx.tcgen05.alloc( - T.address_of(self.addr), n_cols=self.total_cols, cta_group=self.cta_group - ) - ) - if self.sync_after_alloc: - _emit_stmt(T.cuda.warp_sync()) - - self._emit_warp_guard(self.alloc_warp, emit_alloc) - self._committed = True - - def dealloc(self): - assert self._committed, "TMEMPool.dealloc() called before commit()" - assert not self._deallocated, "TMEMPool.dealloc() can only be called once" - self._deallocated = True - from tvm.script import tirx as T - - def emit_dealloc(): - _emit_stmt(T.ptx.tcgen05.relinquish_alloc_permit(cta_group=self.cta_group)) - _emit_stmt( - T.ptx.tcgen05.dealloc(self.addr, n_cols=self.total_cols, cta_group=self.cta_group) - ) - - self._emit_warp_guard(self.dealloc_warp, emit_dealloc) - - -# --------------------------------------------------------------------------- -# SMEMPool -# --------------------------------------------------------------------------- - - -@_meta_class -class SMEMPool: - """Bump allocator over a contiguous shared memory region. - - Parameters - ---------- - ptr : Var or None, optional - If omitted, an ``alloc_buffer([0], "uint8", scope="shared.dyn")`` is - created automatically and ``commit()`` must be called after all - allocations to emit the size annotation. - If a ``Var`` is provided, the caller manages the backing buffer and - ``commit()`` is a no-op. - """ - - def __init__(self, ptr=_POOL_UNSET): - ir = _get_ir() - if ptr is _POOL_UNSET: - self.buf = ir.alloc_buffer([0], "uint8", scope="shared.dyn") - self.ptr = self.buf.data - self._owns_buffer = True - else: - self.buf = None - self.ptr = ptr - self._owns_buffer = False - self.offset = 0 - self.max_offset = 0 - - def alloc( - self, - shape, - dtype="float32", - strides=None, - scope="shared.dyn", - align=0, - buffer_type="", - axis_separators=None, - layout="default", - ): - ir = _get_ir() - if align > 0: - self.offset = (self.offset + align - 1) // align * align - res = ir.decl_buffer( - shape, - dtype, - data=self.ptr, - strides=strides, - byte_offset=self.offset, - scope=scope, - align=align, - buffer_type=buffer_type, - axis_separators=axis_separators, - layout=layout, - ) - # Advance in bits then round up to bytes so sub-byte dtypes (e.g. - # float4_e2m1fn = 4 bits) still bump the cursor instead of leaving it - # at 0 (bits // 8) and silently overlapping the next allocation. - self.offset += (_shape_product(shape) * DataType(dtype).bits + 7) // 8 - if self._owns_buffer: - self.max_offset = max(self.max_offset, self.offset) - return res - - def alloc_mma(self, shape, dtype="float16", swizzle_mode="auto", align=1024): - """Allocate MMA-compatible shared memory with an inferred swizzle layout.""" - from tvm.tirx.operator.tile_primitive.cuda.tma_utils import ( - SwizzleMode, - mma_shared_layout, - ) - - if isinstance(swizzle_mode, str): - if swizzle_mode == "auto": - swizzle_mode = _auto_swizzle_mode(dtype) - elif swizzle_mode == "none": - swizzle_mode = SwizzleMode.SWIZZLE_NONE - else: - raise ValueError( - f"Unsupported swizzle_mode={swizzle_mode!r}; expected 'auto', 'none', " - "or SwizzleMode" - ) - _validate_mma_alloc_shape(shape, dtype, swizzle_mode) - layout = mma_shared_layout(dtype, swizzle_mode, shape) - return self.alloc(shape, dtype, align=align, layout=layout) - - def move_base_to(self, offset): - self.offset = offset - if self._owns_buffer: - self.max_offset = max(self.max_offset, self.offset) - - def commit(self, size=None): - """Emit pool size annotation into the IR. - - Must be called after all ``alloc()`` / ``move_base_to()`` calls. - - Parameters - ---------- - size : int, optional - Explicit shared memory size in bytes. When *None* (the default), - the high-water mark ``max_offset`` tracked by the allocator is used. - """ - if not self._owns_buffer: - return - ir = _get_ir() - frame_mod = _get_frame() - resolved = size if size is not None else self.max_offset - assert resolved >= self.max_offset, ( - f"Specified smem size ({resolved}) is smaller than " - f"the pool high-water mark ({self.max_offset})" - ) - attr_frame = ir.attr(self.ptr, "tirx.pool_max_bytes", resolved) - if isinstance(attr_frame, frame_mod.AttrFrame): - from functools import partial - - attr_frame.add_callback(partial(attr_frame.__exit__, None, None, None)) - attr_frame.__enter__() +from tvm.backend.cuda.lang.alloc_pool import * # noqa: F403 # pylint: disable=wildcard-import diff --git a/python/tvm/tirx/lang/pipeline.py b/python/tvm/tirx/lang/pipeline.py index ee86090398e9..ee6380f81806 100644 --- a/python/tvm/tirx/lang/pipeline.py +++ b/python/tvm/tirx/lang/pipeline.py @@ -2,243 +2,7 @@ # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Reusable pipeline state and mbarrier helpers for SM100 kernels. +# to you under the Apache License, Version 2.0. +"""Compatibility redirect for CUDA pipeline helpers.""" -These classes emit TIR via @T.inline. Decorate with @T.meta_class so that -instances are automatically treated as meta values inside @T.prim_func. -""" - -from tvm.script import tirx as T - - -@T.meta_class -class PipelineState: - """Tracks stage and phase for a software-pipelined ring buffer. - - This class does not know anything about full/empty barriers. Use it when - the kernel manually waits/signals barriers, or when the stage/phase drives - a ring not wrapped in a ``Pipeline``. - - Parameters - ---------- - depth : int - Number of stages in the ring. - phase : int, optional - Initial phase. Omit when initialization should happen later. - """ - - def __init__(self, depth: int, phase=None): - self.stage = T.local_scalar("int32") - self.phase = T.local_scalar("int32") - self.depth = depth - if phase is not None: - self.init(phase) - - @T.inline - def init(self, phase): - self.stage = 0 - self.phase = phase - - @T.inline - def advance(self): - if self.depth > 1: - self.stage = self.stage + 1 - if self.stage == self.depth: - self.stage = 0 - self.phase = self.phase ^ 1 - else: - self.phase = self.phase ^ 1 - - -@T.meta_class -class MBarrier: - """Mbarrier wrapper with regular ``mbarrier.arrive``. - - Parameters - ---------- - pool : SMEMPool - Shared memory pool allocator. - depth : int - Number of barrier slots (one per pipeline stage). - phase_offset : int - XORed into the phase bit on every ``wait`` / ``arrive``. - leader : PrimExpr, optional - Boolean predicate selecting the single thread that runs - ``mbarrier.init``. Defaults to ``T.cuda.thread_rank() == 0`` -- - thread 0 of the enclosing CTA, which always picks exactly one - thread regardless of which scope_id vars the caller declared. - Override only when you want a different CTA-local thread to do - the init. - - Note: the default deliberately avoids ``T.warp_id()`` / - ``T.lane_id()``. Those introduce deferred ``cta->warp`` / - ``warp->thread`` ScopeIdDefs that the verifier cannot pin down - unless the kernel header declares the full warp/lane chain (e.g. a - single-CTA DSMEM kernel that only declares ``thread_id``). It also - avoids the synccheck false-deadlock on kernels that declare a - second warp-scope id. The generated CUDA is equivalent. - """ - - def __init__(self, pool, depth, phase_offset=0, leader=None): - self.buf = pool.alloc((depth,), "uint64", align=8) - self.depth = depth - self.phase_offset = phase_offset - self.leader = leader if leader is not None else (T.cuda.thread_rank() == 0) - - @T.inline - def init(self, count): - if self.leader: - for i in T.unroll(self.depth): - T.ptx.mbarrier.init(self.buf.ptr_to([i]), count) - - @T.inline - def wait(self, stage, phase): - # Blocks: ``mbarrier.try_wait`` loops internally until the phase flips, - # so this returns only once the barrier has completed. - T.ptx.mbarrier.try_wait(self.buf.ptr_to([stage]), phase ^ self.phase_offset) - - @T.inline - def arrive(self, stage, cta_id=None, pred=None): - # Default: local-CTA arrive — emits the simple - # ``mbarrier.arrive.shared.b64`` form. To arrive on a remote - # CTA's mbarrier in a cluster kernel, callers must pass - # ``cta_id=`` explicitly (e.g. ``bar.arrive(stage, cta_id=0)``) - # or use ``MBarrier.remote_view(rank).arrive(stage)``. Defaulting - # the cross-CTA path was both surprising (``bar.arrive(stage)`` - # silently ``mapa`` ed across the cluster) and a per-call cost - # of ~3 PTX ops on every single-CTA kernel. - if cta_id is None: - T.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) - else: - actual_pred = True if pred is None else pred - T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) - - def ptr_to(self, idx): - return self.buf.ptr_to(idx) - - def remote_view(self, rank): - """Create a view of this barrier mapped to another CTA's shared memory. - - Arrive-only: the returned view is built with ``object.__new__`` and - never copies ``self.leader``, so calling ``.init()`` on it would fail. - Use it solely to ``arrive`` on a remote CTA's mbarrier. - """ - from tvm.ir import PointerType, PrimType - from tvm.tirx import Var as TIRVar - - expr = T.reinterpret("handle", T.ptx.map_shared_rank(self.buf.ptr_to([0]), rank)) - ptr = TIRVar("remote_mbar_ptr", PointerType(PrimType("uint64"))) - T.Bind(expr, var=ptr) - buf = T.decl_buffer([self.depth], "uint64", data=ptr, scope="shared") - remote = object.__new__(type(self)) - remote.buf = buf - remote.depth = self.depth - remote.phase_offset = self.phase_offset - return remote - - -class TMABar(MBarrier): - """Barrier signaled by TMA (mbarrier.arrive.expect_tx). - - When ``tx_count`` is None, falls back to a remote mbarrier.arrive - (matching MBarrier.arrive defaults). - """ - - @T.inline - def arrive(self, stage, tx_count=None, cta_id=None, pred=None): - # NOTE: this arrive() kwarg set intentionally differs from - # MBarrier.arrive (hardware necessity, LSP-incompatible by design). - # ``tx_count``: TMA byte count for ``mbarrier.arrive.expect_tx``. - # ``cta_id`` / ``pred``: forwarded to the underlying - # ``mbarrier.arrive`` (cluster path) when set; otherwise the - # arrive is local-CTA only. See ``MBarrier.arrive`` for the - # full default-local rationale. - if tx_count is not None: - T.ptx.mbarrier.arrive.expect_tx(self.buf.ptr_to([stage]), tx_count) - elif cta_id is None: - T.ptx.mbarrier.arrive(self.buf.ptr_to([stage])) - else: - actual_pred = True if pred is None else pred - T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred) - - -class TCGen05Bar(MBarrier): - """Barrier signaled by ``tcgen05`` commit. - - The caller is responsible for ensuring only one thread issues the - commit, e.g. by wrapping the call in ``if T.ptx.elect_sync():``. - """ - - @T.inline - def arrive(self, stage, cta_group=1, cta_mask=None): - # NOTE: this arrive() kwarg set intentionally differs from - # MBarrier.arrive (hardware necessity, LSP-incompatible by design). - if cta_mask is None and cta_group == 1: - T.ptx.tcgen05.commit(self.buf.ptr_to([stage])) - else: - T.ptx.tcgen05.commit(self.buf.ptr_to([stage]), cta_group=cta_group, cta_mask=cta_mask) - - -# Barrier-type tags accepted by Pipeline's ``full=`` / ``empty=`` arguments. -_BAR_KINDS = {"tma": TMABar, "tcgen05": TCGen05Bar, "mbar": MBarrier} - - -@T.meta_class -class Pipeline: - """A full/empty mbarrier pair for a software-pipelined data flow. - - Pass barrier-type tags and ``Pipeline`` constructs and ``init``\\ s the - barriers itself. Tags: ``"tma"`` (TMABar), ``"tcgen05"`` (TCGen05Bar), - ``"mbar"`` (MBarrier). The barrier type and arrival count of each event - stay explicit at the call site -- e.g. ``Pipeline(pool, n, full="tma", - empty="tcgen05", init_empty=NUM_CONSUMER)``. - - Both signals are required: a ``Pipeline`` is a *pair*. For a one-way event - (a pure "X happened" signal with no slot to recycle) use a bare barrier - (``TMABar``/``TCGen05Bar``/``MBarrier``) directly -- it has no empty side. - - Parameters - ---------- - pool : SMEMPool - Shared memory pool allocator. - stages : int - Number of pipeline stages (barrier slots). - full, empty : str - Barrier-type tag for the full / empty signal (see above). - init_full, init_empty : int - Expected arrival count for the full / empty barrier. - empty_phase_offset : int - XORed into the empty barrier's phase bit on every wait / arrive. - leader : PrimExpr, optional - Propagated to both barriers; defaults to thread 0 of the CTA. - """ - - def __init__( - self, - pool, - stages, - *, - full, - empty, - init_full=1, - init_empty=1, - empty_phase_offset=0, - leader=None, - ): - self.stages = stages - self.full = _BAR_KINDS[full](pool, stages, leader=leader) - self.full.init(init_full) - self.empty = _BAR_KINDS[empty](pool, stages, phase_offset=empty_phase_offset, leader=leader) - self.empty.init(init_empty) +from tvm.backend.cuda.lang.pipeline import * # noqa: F403 # pylint: disable=wildcard-import diff --git a/python/tvm/tirx/lang/smem_desc.py b/python/tvm/tirx/lang/smem_desc.py index c858cb70690c..d4a5f9686dc2 100644 --- a/python/tvm/tirx/lang/smem_desc.py +++ b/python/tvm/tirx/lang/smem_desc.py @@ -2,54 +2,7 @@ # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. +# to you under the Apache License, Version 2.0. +"""Compatibility redirect for CUDA shared-memory descriptors.""" -"""SMEM matrix descriptor helper for tcgen05 / wgmma.""" - -from tvm.script import tirx as T -from tvm.tirx.operator.tile_primitive.cuda.common import smem_desc_add_16B_offset - - -@T.meta_class -class SmemDescriptor: - """Encoded once via :meth:`init`, reused via :meth:`add_16B_offset`.""" - - def __init__(self): - self._buf = T.alloc_local([1], "uint64") - - @property - def desc(self): - return self._buf[0] - - @T.inline - def init(self, smem_ptr, ldo, sdo, swizzle): - T.ptx.tcgen05.encode_matrix_descriptor( - T.address_of(self._buf[0]), smem_ptr, ldo, sdo, swizzle - ) - - def add_16B_offset(self, offset): - return smem_desc_add_16B_offset(self._buf[0], offset) - - def make_lo_uniform(self): - """Broadcast the lower 32 bits to all warp lanes via ``__shfl_sync``.""" - func_name = "smem_desc_make_lo_uniform" - source_code = f""" -__forceinline__ __device__ void {func_name}(uint64_t* desc) {{ - SmemDescriptor* d = reinterpret_cast(desc); - d->lo = __shfl_sync(0xffffffff, d->lo, 0); -}} -""" - return T.cuda.func_call( - func_name, T.address_of(self._buf[0]), source_code=source_code, return_type="void" - ) +from tvm.backend.cuda.lang.smem_desc import * # noqa: F403 # pylint: disable=wildcard-import diff --git a/python/tvm/tirx/lang/tile_scheduler.py b/python/tvm/tirx/lang/tile_scheduler.py index 3fd27f25ee5f..145b17581cbf 100644 --- a/python/tvm/tirx/lang/tile_scheduler.py +++ b/python/tvm/tirx/lang/tile_scheduler.py @@ -2,815 +2,7 @@ # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Reusable tile scheduler helpers for TIR tests/kernels. +# to you under the Apache License, Version 2.0. +"""Compatibility redirect for CUDA tile schedulers.""" -These classes emit TIR via @T.inline. Decorate with @T.meta_class so that -instances are automatically treated as meta values inside @T.prim_func. -""" - -from tvm.script import tirx as T - - -@T.meta_class -class BaseTileScheduler: - """Base class for tile schedulers with common state and macros.""" - - def __init__(self, prefix: str): - self.m_idx = T.local_scalar("int32") - self.n_idx = T.local_scalar("int32") - self.linear_idx = T.local_scalar("int32") - - @T.inline - def update_current_m_n_idx(self, linear_idx): - # To be implemented by subclasses - pass - - @T.inline - def init(self, linear_init): - self.linear_idx = linear_init - self.update_current_m_n_idx(linear_init) - - @T.inline - def next_tile(self, step): - self.linear_idx = self.linear_idx + step - self.update_current_m_n_idx(self.linear_idx) - - def valid(self, total_tiles): - return self.linear_idx < total_tiles - - -class ClusterPersistentScheduler2D(BaseTileScheduler): - """ - Tile scheduler for cluster-based persistent kernels. - - Distributes a 2D tile grid across persistent clusters using group-major ordering - for L2 cache locality. Each cluster starts at its cluster_id and strides by - num_clusters to process tiles. - - Tile Ordering (group-major for L2 locality): - - Tiles are grouped into "L2 groups" of `l2_group_size` rows - - Within a group, tiles are visited in column-major order within the group - - Groups are processed in row-major order - - Example with 4x4 tiles, l2_group_size=2: - Group 0 (rows 0-1): 0 2 4 6 - 1 3 5 7 - Group 1 (rows 2-3): 8 10 12 14 - 9 11 13 15 - - Serpentine Mode (serpentine=True): - - Uses CUTLASS-style 2D block swizzle with serpentine traversal - - Grid is divided into swizzle_size x swizzle_size blocks - - Within each block, tiles are visited in row-major order - - Blocks are traversed in serpentine order (even block-rows forward, odd backward) - - This provides better L2 locality by reusing both A and B tiles - - Example with 4x4 tiles, swizzle_size=2, serpentine=True: - Block layout: - Block(0,0) Block(0,1) - Block(1,0) Block(1,1) - - Tile numbering with serpentine: - n=0 n=1 n=2 n=3 - m=0 0 1 14 15 - m=1 2 3 12 13 - m=2 4 5 10 11 - m=3 6 7 8 9 - - Traversal: Block(0,0) -> Block(1,0) -> Block(1,1) -> Block(0,1) - (serpentine: down in col 0, then up in col 1) - - Parameters - ---------- - prefix : str - Prefix for TIR variable names - num_m_tiles : int | T.ExprLike - Total number of tiles in M dimension (can be runtime expression) - num_n_tiles : int - Total number of tiles in N dimension - num_clusters : int - Number of persistent clusters (determines stride) - l2_group_size : int - Number of M-tile rows per L2 locality group (default: 8) - When serpentine=True, this is used as swizzle_size for 2D blocks - cluster_m : int - Cluster dimension in M for hierarchical scheduling (default: 1) - cluster_n : int - Cluster dimension in N for hierarchical scheduling (default: 1) - serpentine : bool - If True, use CUTLASS-style 2D block swizzle with serpentine traversal (default: False) - - Attributes - ---------- - m_idx : T.local_scalar - Current M tile index (output) - n_idx : T.local_scalar - Current N tile index (output) - work_idx : T.local_scalar - Global work item index for this cluster - tile_count : T.local_scalar - Number of tiles processed by this cluster so far - - Usage - ----- - ```python - scheduler = ClusterPersistentScheduler2D( - "sched", num_m_tiles=M_TILES, num_n_tiles=N_TILES, - num_clusters=NUM_CLUSTERS, l2_group_size=8 - ) - scheduler.init(cluster_id) # cluster_id = cta_idx // CLUSTER_SIZE - - while scheduler.valid(): - m = T.meta_var(scheduler.m_idx) # current M tile - n = T.meta_var(scheduler.n_idx) # current N tile - # ... process tile (m, n) ... - scheduler.next_tile() - ``` - - Examples - -------- - Example 1: Basic persistent kernel - ``` - num_m_tiles=4, num_n_tiles=4, num_clusters=3, l2_group_size=2 - cluster_m=1, cluster_n=1 (default, no tile subdivision) - - Group-major tile numbering (l2_group_size=2): - n=0 n=1 n=2 n=3 - m=0 0 2 4 6 ┐ L2 group 0 - m=1 1 3 5 7 ┘ - m=2 8 10 12 14 ┐ L2 group 1 - m=3 9 11 13 15 ┘ - - Work distribution (cluster starts at cluster_id, strides by num_clusters=3): - cluster 0: work_idx 0,3,6,9,12,15 -> tiles 0,3,6,9,12,15 - cluster 1: work_idx 1,4,7,10,13 -> tiles 1,4,7,10,13 - cluster 2: work_idx 2,5,8,11,14 -> tiles 2,5,8,11,14 - - Tile grid (which cluster handles each tile): - n=0 n=1 n=2 n=3 - m=0 C0 C2 C1 C0 ┐ L2 group 0 - m=1 C1 C0 C2 C1 ┘ - m=2 C2 C1 C0 C2 ┐ L2 group 1 - m=3 C0 C2 C1 C0 ┘ - - Tile sequence per cluster (in execution order): - cluster 0: (0,0)->(1,1)->(0,3)->(2,0)->(2,3)->(3,3) - cluster 1: (1,0)->(0,2)->(1,3)->(2,1)->(3,2) - cluster 2: (0,1)->(1,2)->(2,0)->(3,1)->(2,3) - ``` - - Example 2: 2SM GEMM (typical B200 config) - ``` - M=1024, N=512, CTA_M=128, MMA_N=128, CLUSTER_M=2, CLUSTER_N=1 - => M_TILES=8, N_TILES=4 - => CLUSTER_M_TILES=4, CLUSTER_N_TILES=4 (scheduler at cluster granularity) - - Scheduler params: - num_m_tiles=4, num_n_tiles=4, num_clusters=74, l2_group_size=8 - cluster_m=1, cluster_n=1 - - Key: Scheduler outputs CLUSTER-level tiles. - All CTAs in same cluster get SAME (m_idx, n_idx) from scheduler. - CTAs differentiate via cluster_rank (computed OUTSIDE scheduler): - cluster_rank = cta_idx % CLUSTER_SIZE - cb_m = cluster_rank % CLUSTER_M # 0 or 1 for 2SM - cb_n = cluster_rank // CLUSTER_M # 0 for 2SM - - Final CTA tile: - cta_m = m_idx * CLUSTER_M + cb_m - cta_n = n_idx * CLUSTER_N + cb_n - - Example: cluster 5 gets scheduler tile (1,2) - CTA rank=0 (cb_m=0): actual tile (2,2) - CTA rank=1 (cb_m=1): actual tile (3,2) - ``` - """ - - def __init__( - self, - prefix: str, - num_m_tiles, - num_n_tiles: int, - num_clusters: int, - l2_group_size: int = 8, - cluster_m: int = 1, - cluster_n: int = 1, - serpentine: bool = False, - ): - super().__init__(prefix) - self._num_m_tiles = num_m_tiles - self._num_n_tiles = num_n_tiles - self._num_clusters = num_clusters - self._l2_group_size = l2_group_size - self._cluster_m = cluster_m - self._cluster_n = cluster_n - self._serpentine = serpentine - - # Rename internal state for clarity - self.work_idx = self.linear_idx # alias: global work item index - self.tile_count = T.local_scalar("int32") - self.tile_idx = self.tile_count # alias for backward compatibility - - is_static_m = isinstance(num_m_tiles, int) - - # Number of tile columns after accounting for cluster_n - n_tile_cols = (num_n_tiles + cluster_n - 1) // cluster_n - self._N_TILE_COLS = n_tile_cols - - if is_static_m: - self._M_TILE_ROWS = (num_m_tiles + cluster_m - 1) // cluster_m - self._FULL_GROUPS = self._M_TILE_ROWS // l2_group_size - else: - # Dynamic expressions for runtime M - self._M_TILE_ROWS = T.truncdiv(self._num_m_tiles + self._cluster_m - 1, self._cluster_m) - self._FULL_GROUPS = T.truncdiv(self._M_TILE_ROWS, self._l2_group_size) - - self._TAIL_ROWS = self._M_TILE_ROWS - self._FULL_GROUPS * l2_group_size - self._TOTAL_TILES = self._M_TILE_ROWS * n_tile_cols * cluster_m * cluster_n - - # For serpentine mode: precompute block counts - if serpentine: - self._N_BLOCKS = n_tile_cols // l2_group_size # full blocks in N - self._M_BLOCKS = ( - self._M_TILE_ROWS // l2_group_size - if is_static_m - else T.truncdiv(self._M_TILE_ROWS, l2_group_size) - ) - self._BLOCK_SIZE = l2_group_size * l2_group_size # tiles per block - self._FULL_BLOCK_TILES = self._M_BLOCKS * self._N_BLOCKS * self._BLOCK_SIZE - # Residual tiles (not covered by full blocks) - self._RESIDUAL_N = n_tile_cols - self._N_BLOCKS * l2_group_size - self._RESIDUAL_M = self._M_TILE_ROWS - self._M_BLOCKS * l2_group_size - - # fmt: off - @T.inline - def update_current_m_n_idx(self, work_idx): - """Convert global work index to (m_idx, n_idx) tile coordinates.""" - CLUSTER_M = T.meta_var(self._cluster_m) - CLUSTER_N = T.meta_var(self._cluster_n) - - # Extract hierarchical cluster-local offsets - cluster_m_offset = T.meta_var(work_idx % CLUSTER_M) - t = T.meta_var(work_idx // CLUSTER_M) - cluster_n_offset = T.meta_var(t % CLUSTER_N) - tile_linear = T.meta_var(t // CLUSTER_N) - - @T.inline - def set_tile_coords(tile_row, tile_col): - self.m_idx = tile_row * CLUSTER_M + cluster_m_offset - self.n_idx = tile_col * CLUSTER_N + cluster_n_offset - - if self._serpentine: - self._update_serpentine(tile_linear, set_tile_coords) - else: - self._update_group_major(tile_linear, set_tile_coords) - - def _update_group_major(self, tile_linear, set_tile_coords): - """Group-major ordering with parse-time pruning of statically-dead branches. - - The TIR script parser does not constant-fold ``if False: ...``, so a - Python-literal ``FULL_GROUPS == 0`` would otherwise produce - ``T.bitwise_and(T.bool(False), tile_linear < 0)`` IR plus the dead - then-leg. Branch in plain Python here and only invoke the inline - emitter that can actually fire. - """ - full_zero = isinstance(self._FULL_GROUPS, int) and self._FULL_GROUPS == 0 - tail_zero = isinstance(self._TAIL_ROWS, int) and self._TAIL_ROWS == 0 - if full_zero and tail_zero: - self._gm_emit_zero(set_tile_coords) - elif full_zero: - self._gm_emit_tail_only(tile_linear, set_tile_coords) - elif tail_zero: - self._gm_emit_full_only(tile_linear, set_tile_coords) - else: - self._gm_emit_full_and_tail(tile_linear, set_tile_coords) - - @T.inline - def _gm_emit_zero(self, set_tile_coords): - set_tile_coords(0, 0) - - @T.inline - def _gm_emit_full_only(self, tile_linear, set_tile_coords): - FULL_GROUPS = T.meta_var(self._FULL_GROUPS) - GROUP_SIZE = T.meta_var(self._l2_group_size) - GROUP_SPAN = T.meta_var(self._l2_group_size * self._N_TILE_COLS) - if (FULL_GROUPS > 0) & (tile_linear < FULL_GROUPS * GROUP_SPAN): - group_id: T.let = tile_linear // GROUP_SPAN - within_group: T.let = tile_linear % GROUP_SPAN - tile_row: T.let = group_id * GROUP_SIZE + (within_group % GROUP_SIZE) - tile_col: T.let = within_group // GROUP_SIZE - set_tile_coords(tile_row, tile_col) - else: - set_tile_coords(0, 0) - - @T.inline - def _gm_emit_tail_only(self, tile_linear, set_tile_coords): - FULL_GROUPS = T.meta_var(self._FULL_GROUPS) - TAIL_ROWS = T.meta_var(self._TAIL_ROWS) - GROUP_SIZE = T.meta_var(self._l2_group_size) - GROUP_SPAN = T.meta_var(self._l2_group_size * self._N_TILE_COLS) - if TAIL_ROWS > 0: - rem: T.let = tile_linear - FULL_GROUPS * GROUP_SPAN - tile_row: T.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS) - tile_col: T.let = rem // TAIL_ROWS - set_tile_coords(tile_row, tile_col) - else: - set_tile_coords(0, 0) - - @T.inline - def _gm_emit_full_and_tail(self, tile_linear, set_tile_coords): - FULL_GROUPS = T.meta_var(self._FULL_GROUPS) - TAIL_ROWS = T.meta_var(self._TAIL_ROWS) - GROUP_SIZE = T.meta_var(self._l2_group_size) - GROUP_SPAN = T.meta_var(self._l2_group_size * self._N_TILE_COLS) - if (FULL_GROUPS > 0) & (tile_linear < FULL_GROUPS * GROUP_SPAN): - group_id: T.let = tile_linear // GROUP_SPAN - within_group: T.let = tile_linear % GROUP_SPAN - tile_row: T.let = group_id * GROUP_SIZE + (within_group % GROUP_SIZE) - tile_col: T.let = within_group // GROUP_SIZE - set_tile_coords(tile_row, tile_col) - elif TAIL_ROWS > 0: - rem: T.let = tile_linear - FULL_GROUPS * GROUP_SPAN - tile_row: T.let = FULL_GROUPS * GROUP_SIZE + (rem % TAIL_ROWS) - tile_col: T.let = rem // TAIL_ROWS - set_tile_coords(tile_row, tile_col) - else: - set_tile_coords(0, 0) - - @T.inline - def _update_serpentine(self, tile_linear, set_tile_coords): - """CUTLASS-style 2D block swizzle with serpentine traversal. - - Algorithm: - 1. Divide grid into swizzle_size x swizzle_size blocks - 2. Within each block, visit tiles in row-major order - 3. Blocks are traversed column by column (along N) - 4. Within each column of blocks, use serpentine: - - Even columns: top to bottom - - Odd columns: bottom to top - - This maximizes L2 reuse for both A and B matrices. - """ - S = T.meta_var(self._l2_group_size) # swizzle_size - M_BLOCKS = T.meta_var(self._M_BLOCKS) - N_BLOCKS = T.meta_var(self._N_BLOCKS) - BLOCK_SIZE = T.meta_var(self._BLOCK_SIZE) # S * S - FULL_BLOCK_TILES = T.meta_var(self._FULL_BLOCK_TILES) - M_TILE_ROWS = T.meta_var(self._M_TILE_ROWS) - T.meta_var(self._N_TILE_COLS) - RESIDUAL_N = T.meta_var(self._RESIDUAL_N) - RESIDUAL_M = T.meta_var(self._RESIDUAL_M) - - # Check if we're in the full block region - if (M_BLOCKS > 0) & (N_BLOCKS > 0) & (tile_linear < FULL_BLOCK_TILES): - # Which block (in linear order along columns of blocks) - block_linear: T.let = tile_linear // BLOCK_SIZE - within_block: T.let = tile_linear % BLOCK_SIZE - - # Block column and row - block_col: T.let = block_linear // M_BLOCKS - block_row_raw: T.let = block_linear % M_BLOCKS - - # Serpentine: odd columns go bottom-to-top - block_row: T.let = T.Select( - block_col % 2 == 0, - block_row_raw, - M_BLOCKS - 1 - block_row_raw - ) - - # Position within block (row-major within block) - local_row: T.let = within_block // S - local_col: T.let = within_block % S - - tile_row: T.let = block_row * S + local_row - tile_col: T.let = block_col * S + local_col - set_tile_coords(tile_row, tile_col) - - elif RESIDUAL_N > 0: - # Residual tiles in the rightmost partial column of blocks - # These are tiles where n >= N_BLOCKS * S - rem: T.let = tile_linear - FULL_BLOCK_TILES - - # First handle the right residual strip (full M height, partial N width) - right_strip_tiles: T.let = M_TILE_ROWS * RESIDUAL_N - if rem < right_strip_tiles: - # Row-major within the right strip - tile_row: T.let = rem // RESIDUAL_N - tile_col: T.let = N_BLOCKS * S + (rem % RESIDUAL_N) - set_tile_coords(tile_row, tile_col) - elif RESIDUAL_M > 0: - # Bottom residual strip (already covered in right strip overlap) - # This handles corner case - shouldn't normally reach here - # as right strip already covers full M height - set_tile_coords(0, 0) - else: - set_tile_coords(0, 0) - - elif RESIDUAL_M > 0: - # Bottom residual strip only (no right residual) - rem: T.let = tile_linear - FULL_BLOCK_TILES - bottom_strip_tiles: T.let = RESIDUAL_M * (N_BLOCKS * S) - if rem < bottom_strip_tiles: - tile_row: T.let = M_BLOCKS * S + (rem % RESIDUAL_M) - tile_col: T.let = rem // RESIDUAL_M - set_tile_coords(tile_row, tile_col) - else: - set_tile_coords(0, 0) - else: - # Fallback - set_tile_coords(0, 0) - - @T.inline - def init(self, cluster_id): - """Initialize scheduler for a given cluster. - - Parameters - ---------- - cluster_id : int - The cluster's index (typically cta_idx // CLUSTER_SIZE) - """ - self.linear_idx = cluster_id - self.tile_count = 0 - self.update_current_m_n_idx(cluster_id) - - @T.inline - def next_tile(self): - """Advance to the next tile for this cluster.""" - self.linear_idx = self.linear_idx + self._num_clusters - self.tile_count = self.tile_count + 1 - self.update_current_m_n_idx(self.linear_idx) - - @T.inline - def next_tile_stride(self, stride: int): - """Advance by a custom stride (for non-standard scheduling).""" - self.linear_idx = self.linear_idx + stride - self.tile_count = self.tile_count + 1 - self.update_current_m_n_idx(self.linear_idx) - # fmt: on - - def valid(self): - """Check if this cluster has more tiles to process.""" - return self.linear_idx < self._TOTAL_TILES - - -class GroupMajor3D(BaseTileScheduler): - """ - 3D grouped-row scheduler (M,N,K) with tail handling on M. - - Args - ---- - prefix: str - m_tiles: int | T PrimExpr # tiles along M (static or runtime) - n_tiles: int # tiles along N (static) - k_tiles: int # tiles along K (static) - group_rows: int # rows per group along M - step: int = 1 # default stride for next_tile() - """ - - def __init__( - self, prefix: str, m_tiles, n_tiles: int, k_tiles: int, group_rows: int, step: int = 1 - ): - super().__init__(prefix) - self._step = step - self.tile_idx = T.local_scalar("int32") - self.k_idx = T.local_scalar("int32") - - # ---- constants / primexprs baked once ---- - self._G = group_rows - self._N = n_tiles - self._K = k_tiles - - if isinstance(m_tiles, int): - self._GROUPS = m_tiles // group_rows - self._FINAL_ROWS = m_tiles - self._GROUPS * group_rows - self._SAFE_FINAL_ROWS = max(self._FINAL_ROWS, 1) - self._GROUP_SIZE = group_rows * n_tiles * k_tiles - self._TOTAL = m_tiles * n_tiles * k_tiles - else: - self._GROUPS = T.truncdiv(m_tiles, group_rows) - self._FINAL_ROWS = m_tiles - self._GROUPS * group_rows - self._SAFE_FINAL_ROWS = T.max(self._FINAL_ROWS, 1) - self._GROUP_SIZE = self._G * self._N * self._K - self._TOTAL = m_tiles * n_tiles * k_tiles - - # handy composites used in macro - self._FULL_BOUND = self._GROUPS * self._GROUP_SIZE - self._HAS_FULL = self._GROUPS > 0 - self._HAS_TAIL = self._FINAL_ROWS > 0 - - # fmt: off - @T.inline - def update_current_m_n_idx(self, linear_idx): - # full-group formulas - full_m: T.let = T.floordiv(linear_idx, self._GROUP_SIZE) * self._G + T.floormod( - linear_idx, self._G - ) - full_n: T.let = T.floormod(T.floordiv(linear_idx, self._G), self._N) - full_k: T.let = T.floordiv(T.floormod(linear_idx, self._GROUP_SIZE), self._G * self._N) - - # tail formulas (relative to FULL_BOUND) - # Use _SAFE_FINAL_ROWS (max(FINAL_ROWS, 1)) to avoid divide-by-zero when there is no tail - rem: T.let = linear_idx - self._FULL_BOUND - tail_m: T.let = self._GROUPS * self._G + T.floormod(rem, self._SAFE_FINAL_ROWS) - tail_n: T.let = T.floordiv(rem, self._SAFE_FINAL_ROWS) % self._N - tail_k: T.let = T.floordiv(rem, self._SAFE_FINAL_ROWS * self._N) - - # choose phase - if self._HAS_FULL & (linear_idx < self._FULL_BOUND): - self.m_idx = full_m - self.n_idx = full_n - self.k_idx = full_k - elif self._HAS_TAIL: - self.m_idx = tail_m - self.n_idx = tail_n - self.k_idx = tail_k - else: - self.m_idx = 0 - self.n_idx = 0 - self.k_idx = 0 - - @T.inline - def init(self, linear_init): - self.linear_idx = linear_init - self.tile_idx = 0 - self.update_current_m_n_idx(linear_init) - - @T.inline - def next_tile(self): - self.linear_idx = self.linear_idx + self._step - self.tile_idx = self.tile_idx + 1 - self.update_current_m_n_idx(self.linear_idx) - - @T.inline - def next_tile_stride(self, stride: int): - self.linear_idx = self.linear_idx + stride - self.tile_idx = self.tile_idx + 1 - self.update_current_m_n_idx(self.linear_idx) - # fmt: on - - def valid(self): - return self.linear_idx < self._TOTAL - - -class RankAwareGroupMajorTileScheduler(BaseTileScheduler): - """ - Group-major scheduler that applies a rank-aware remapping (remote rows first). - Kept as a thin adapter because it depends on NVSHMEM rank at device-side. - """ - - def __init__( - self, prefix: str, m_clusters: int, n_clusters: int, group_size: int, world_size: int - ): - super().__init__(prefix) - self._m_clusters = m_clusters - self._n_clusters = n_clusters - self._group_size = group_size - self._world_size = world_size - - @T.inline - def update_current_m_n_idx(self, linear_idx): - my_rank: T.let = T.nvshmem.my_pe() - remote_m_clusters: T.let = self._m_clusters - self._m_clusters // self._world_size - group_rows: T.let = (remote_m_clusters // self._group_size) * self._group_size - final_rows: T.let = remote_m_clusters - group_rows - group_repeat: T.let = self._group_size * self._n_clusters - if linear_idx < group_rows * self._n_clusters and group_rows > 0: - self.m_idx = ( - (linear_idx // group_repeat) * self._group_size - + (linear_idx % self._group_size) - + (my_rank + 1) * self._m_clusters // self._world_size - ) % self._m_clusters - self.n_idx = (linear_idx % group_repeat) // self._group_size - elif linear_idx < remote_m_clusters * self._n_clusters: - remainder_idx: T.let = linear_idx - group_rows * self._n_clusters - self.m_idx = ( - group_rows - + remainder_idx % final_rows - + (my_rank + 1) * self._m_clusters // self._world_size - ) % self._m_clusters - self.n_idx = remainder_idx // final_rows - else: - remainder_idx: T.let = linear_idx - remote_m_clusters * self._n_clusters - self.m_idx = ( - remote_m_clusters - + remainder_idx % (self._m_clusters // self._world_size) - + (my_rank + 1) * self._m_clusters // self._world_size - ) % self._m_clusters - self.n_idx = remainder_idx // (self._m_clusters // self._world_size) - - @T.inline - def next_tile(self, stride: int): - self.linear_idx = self.linear_idx + stride - self.update_current_m_n_idx(self.linear_idx) - - def valid(self): - return self.linear_idx < self._m_clusters * self._n_clusters - - -class IndexedTripleTileScheduler(BaseTileScheduler): - """Scheduler that maps linear_idx to (b_idx, h_idx, q_idx) via index lists.""" - - def __init__(self, prefix: str, b_indices, h_indices, q_indices, tiles_indptr): - super().__init__(prefix) - self.b_indices = b_indices - self.h_indices = h_indices - self.q_indices = q_indices - self.tiles_indptr = tiles_indptr - self.q_idx = T.local_scalar("int32") - self.h_idx = T.local_scalar("int32") - self.b_idx = T.local_scalar("int32") - self.linear_lim = T.local_scalar("int32") - - @T.inline - def _load(self): - self.q_idx = self.q_indices[self.linear_idx] - self.h_idx = self.h_indices[self.linear_idx] - self.b_idx = self.b_indices[self.linear_idx] - - @T.inline - def init(self, sm): - self.linear_idx = self.tiles_indptr[sm] - self.linear_lim = self.tiles_indptr[sm + 1] - self._load() - - @T.inline - def next_tile(self): - self.linear_idx = self.linear_idx + 1 - self._load() - - def valid(self): - return self.linear_idx < self.linear_lim - - -class FlashAttentionLinearScheduler(BaseTileScheduler): - """Linear 3D scheduler for flash attention (batch, head, m_block). - - Used for non-causal attention with simple linear decomposition. - Maps linear_idx -> (batch_idx, head_idx, m_block_idx) using: - batch = linear_idx // (num_heads * num_m_blocks) - head = (linear_idx % (num_heads * num_m_blocks)) // num_m_blocks - m_block = linear_idx % num_m_blocks - - Parameters - ---------- - prefix : str - Prefix for TIR variable names - num_batches : int - Number of batches - num_heads : int - Number of KV heads - num_m_blocks : int - Number of Q blocks (M dimension tiles) - num_ctas : int - Number of CTAs for persistent kernel stride - """ - - def __init__( - self, prefix: str, num_batches: int, num_heads: int, num_m_blocks: int, num_ctas: int - ): - super().__init__(prefix) - self._num_batches = num_batches - self._num_heads = num_heads - self._num_m_blocks = num_m_blocks - self._num_ctas = num_ctas - self._total_tasks = num_batches * num_heads * num_m_blocks - - # Output indices - self.batch_idx = T.local_scalar("int32") - self.head_idx = T.local_scalar("int32") - self.m_block_idx = T.local_scalar("int32") - - # fmt: off - @T.inline - def update_current_m_n_idx(self, linear_idx): - """Convert linear index to (batch, head, m_block) coordinates.""" - NUM_HEADS = T.meta_var(self._num_heads) - NUM_M_BLOCKS = T.meta_var(self._num_m_blocks) - HEAD_M_PRODUCT = T.meta_var(NUM_HEADS * NUM_M_BLOCKS) - - self.batch_idx = linear_idx // HEAD_M_PRODUCT - self.head_idx = (linear_idx % HEAD_M_PRODUCT) // NUM_M_BLOCKS - self.m_block_idx = linear_idx % NUM_M_BLOCKS - - @T.inline - def init(self, cta_id): - """Initialize scheduler with CTA ID.""" - self.linear_idx = cta_id - self.update_current_m_n_idx(cta_id) - - @T.inline - def next_tile(self): - """Advance to next tile by striding by num_ctas.""" - self.linear_idx = self.linear_idx + self._num_ctas - self.update_current_m_n_idx(self.linear_idx) - # fmt: on - - def valid(self): - """Check if there are more tiles to process.""" - return self.linear_idx < self._total_tasks - - -class FlashAttentionLPTScheduler(BaseTileScheduler): - """LPT scheduler with L2 swizzle for causal flash attention. - - Processes high-work Q blocks (with more KV blocks to attend to) first using - Longest Processing Time (LPT) scheduling. Also applies L2 cache swizzle - for better cache locality across batch*head dimensions. - - The LPT aspect comes from reversing m_block order: lower Q blocks have more - KV blocks to process due to causal masking, so processing them first balances load. - - The scheduler is only applied to non-persistent kernels. - - L2 Swizzle: Groups consecutive batch*head indices together for L2 locality. - - Parameters - ---------- - prefix : str - Prefix for TIR variable names - num_batches : int - Number of batches - num_heads : int - Number of KV heads - num_m_blocks : int - Number of Q blocks (M dimension tiles) - num_ctas : int - Number of CTAs (should equal total_tasks for causal) - l2_swizzle : int - L2 swizzle factor for cache locality - """ - - def __init__( - self, prefix: str, num_batches: int, num_heads: int, num_m_blocks: int, l2_swizzle: int - ): - super().__init__(prefix) - self._num_batches = num_batches - self._num_heads = num_heads - self._num_m_blocks = num_m_blocks - self._l2_swizzle = l2_swizzle - self._total_tasks = num_batches * num_heads * num_m_blocks - - # Derived constants for L2 swizzle - self._num_hb = num_batches * num_heads - self._l2_major = l2_swizzle * num_m_blocks - self._num_hb_quotient = self._num_hb // l2_swizzle - - # Output indices - self.batch_idx = T.local_scalar("int32") - self.head_idx = T.local_scalar("int32") - self.m_block_idx = T.local_scalar("int32") - - # fmt: off - @T.inline - def update_current_m_n_idx(self, linear_idx): - """Convert linear index to (batch, head, m_block) with LPT + L2 swizzle.""" - L2_SWIZZLE = T.meta_var(self._l2_swizzle) - L2_MAJOR = T.meta_var(self._l2_major) - NUM_HB_QUOTIENT = T.meta_var(self._num_hb_quotient) - NUM_HB = T.meta_var(self._num_hb) - NUM_HEADS = T.meta_var(self._num_heads) - NUM_M_BLOCKS = T.meta_var(self._num_m_blocks) - - # L2 swizzle decomposition - bidhb: T.let = linear_idx // L2_MAJOR - l2_mod: T.let = linear_idx % L2_MAJOR - - # Handle residual section (last partial swizzle group) - num_hb_remainder: T.let = T.max(NUM_HB % L2_SWIZZLE, 1) - m_block_raw: T.let = T.Select(bidhb < NUM_HB_QUOTIENT, l2_mod // L2_SWIZZLE, l2_mod // num_hb_remainder) # noqa: E501 - bidhb_residual: T.let = T.Select(bidhb < NUM_HB_QUOTIENT, l2_mod % L2_SWIZZLE, l2_mod % num_hb_remainder) # noqa: E501 - bidhb_actual: T.let = bidhb * L2_SWIZZLE + bidhb_residual - - self.batch_idx = bidhb_actual // NUM_HEADS - self.head_idx = bidhb_actual % NUM_HEADS - - # LPT: Reverse block order so high-work blocks are processed first - self.m_block_idx = (NUM_M_BLOCKS - 1) - m_block_raw - - @T.inline - def init(self, cta_id): - """Initialize scheduler with CTA ID.""" - self.linear_idx = cta_id - self.update_current_m_n_idx(cta_id) - - @T.inline - def next_tile(self): - """Advance to next tile by striding by num_ctas.""" - self.linear_idx = self._total_tasks - # fmt: on - - def valid(self): - """Check if there are more tiles to process.""" - return self.linear_idx < self._total_tasks +from tvm.backend.cuda.lang.tile_scheduler import * # noqa: F403 # pylint: disable=wildcard-import diff --git a/python/tvm/tirx/lang/warp_role.py b/python/tvm/tirx/lang/warp_role.py index 0258013bab1a..bb5927fded19 100644 --- a/python/tvm/tirx/lang/warp_role.py +++ b/python/tvm/tirx/lang/warp_role.py @@ -2,143 +2,7 @@ # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Warp role helpers for SM100 kernels. +# to you under the Apache License, Version 2.0. +"""Compatibility redirect for CUDA warp role helpers.""" -Simplifies the common pattern of dispatching warps to named roles -with register budgets. - -Example:: - - # Declare roles - tma_warp = WarpRole(warp_id, 1, regs=48) - store_warp = WarpRole(warp_id, 2, regs=48) - mma_warp = WarpRole(warp_id, 0, regs=232, increase=True) - - # Use with context manager - with tma_warp: - # TMA load code - with store_warp: - # TMA store code - with mma_warp: - # MMA compute code -""" - -from tvm.script import tirx as T - - -class WarpRole: - """A warp-level role that guards a block of code by warp_id comparison - with optional register budget. - - Generates:: - - if == : - T.ptx.setmaxnreg(, ) # if regs specified - - - The ``if`` guard narrows the active set to the single warp; individual - tile-primitive calls inside ```` carry their own exec scope via - a scope-namespace prefix (e.g. ``Tx.warp.copy(...)``). - - Parameters - ---------- - warp_id_var : Var - The warp_id variable (from ``T.warp_id(...)``). - warp_id_val : int - Which warp index this role corresponds to. - regs : int, optional - Register budget (passed to ``T.ptx.setmaxnreg``). - If None, no setmaxnreg is emitted. - increase : bool - Direction for ``setmaxnreg`` (default False = decrease). - """ - - def __init__(self, warp_id_var, warp_id_val, regs=None, increase=False): - self.warp_id_var = warp_id_var - self.warp_id_val = warp_id_val - self.regs = regs - self.increase = increase - - def __enter__(self): - self._if_frame = T.If(self.warp_id_var == self.warp_id_val) - self._if_frame.__enter__() - self._then_frame = T.Then() - self._then_frame.__enter__() - if self.regs is not None: - T.evaluate(T.ptx.setmaxnreg(self.increase, self.regs)) - return self - - def __exit__(self, *exc): - self._then_frame.__exit__(*exc) - self._if_frame.__exit__(*exc) - return False - - -class WarpgroupRole: - """A warpgroup-level role that guards by wg_id comparison, - with optional register budget. - - Generates (single wg_id):: - - if == : - T.ptx.setmaxnreg(, ) # if regs specified - - - Generates (range of wg_ids, e.g. ``wg_id_val=(0, 2)``):: - - if 0 <= and < 2: - T.ptx.setmaxnreg(, ) - - - The ``if`` guard narrows the active set to the target warpgroup(s); - individual tile-primitive calls inside ```` carry their own exec - scope via a scope-namespace prefix (e.g. ``Tx.wg.copy(...)``). - - Parameters - ---------- - wg_id_var : Var - The warpgroup_id variable (from ``T.warpgroup_id(...)``). - wg_id_val : int or tuple[int, int] - Which warpgroup index (int) or range ``(start, stop)`` this role - corresponds to. - regs : int, optional - Register budget. - increase : bool - Direction for ``setmaxnreg`` (default False = decrease). - """ - - def __init__(self, wg_id_var, wg_id_val, regs=None, increase=False): - self.wg_id_var = wg_id_var - self.wg_id_val = wg_id_val - self.regs = regs - self.increase = increase - - def __enter__(self): - if isinstance(self.wg_id_val, tuple): - start, stop = self.wg_id_val - self._if_frame = T.If(start <= self.wg_id_var and self.wg_id_var < stop) - else: - self._if_frame = T.If(self.wg_id_var == self.wg_id_val) - self._if_frame.__enter__() - self._then_frame = T.Then() - self._then_frame.__enter__() - if self.regs is not None: - T.evaluate(T.ptx.setmaxnreg(self.increase, self.regs)) - return self - - def __exit__(self, *exc): - self._then_frame.__exit__(*exc) - self._if_frame.__exit__(*exc) - return False +from tvm.backend.cuda.lang.warp_role import * # noqa: F403 # pylint: disable=wildcard-import diff --git a/python/tvm/tirx/op.py b/python/tvm/tirx/op.py index 4d80ac378e5f..aeeb68ca2f19 100644 --- a/python/tvm/tirx/op.py +++ b/python/tvm/tirx/op.py @@ -33,30 +33,6 @@ from .buffer import Buffer from .expr import BufferLoad, Call, CommReducer, IntImm, PrimExprWithOp, Var -# Choice / IntAttr value tables — single source of truth in -# tvm.tirx.operator.intrinsics._common. Re-exported here under their -# underscored names so the existing _choice(name, value, _FOO) call sites -# below keep working without changes. -from .operator.intrinsics._common import CLUSTER_BARRIER_SEM as _CLUSTER_BARRIER_SEM -from .operator.intrinsics._common import CP_ASYNC_BULK_CACHE_HINT as _CP_ASYNC_BULK_CACHE_HINT -from .operator.intrinsics._common import CP_ASYNC_BULK_RED_OP as _CP_ASYNC_BULK_RED_OP -from .operator.intrinsics._common import CP_ASYNC_CACHE_HINT as _CP_ASYNC_CACHE_HINT -from .operator.intrinsics._common import CP_ASYNC_FILL_MODE as _CP_ASYNC_FILL_MODE -from .operator.intrinsics._common import CP_ASYNC_PREFETCH_SIZE as _CP_ASYNC_PREFETCH_SIZE -from .operator.intrinsics._common import F32X2_ROUND as _F32X2_ROUND -from .operator.intrinsics._common import FENCE_PROXY_ASYNC_SPACE as _FENCE_PROXY_ASYNC_SPACE -from .operator.intrinsics._common import FENCE_SCOPE as _FENCE_SCOPE -from .operator.intrinsics._common import FENCE_SEM as _FENCE_SEM -from .operator.intrinsics._common import LDMATRIX_DTYPE as _LDMATRIX_DTYPE -from .operator.intrinsics._common import LDMATRIX_NUM as _LDMATRIX_NUM -from .operator.intrinsics._common import NVSHMEM_CMP as _NVSHMEM_CMP -from .operator.intrinsics._common import NVSHMEM_SIG_OP as _NVSHMEM_SIG_OP -from .operator.intrinsics._common import TCGEN05_CP_DECOMPRESS as _TCGEN05_CP_DECOMPRESS -from .operator.intrinsics._common import TCGEN05_CP_MULTICAST as _TCGEN05_CP_MULTICAST -from .operator.intrinsics._common import TCGEN05_CP_SHAPES as _TCGEN05_CP_SHAPES -from .operator.intrinsics._common import TCGEN05_CTA_GROUP as _TCGEN05_CTA_GROUP -from .operator.intrinsics._common import TCGEN05_LDST_SHAPES as _TCGEN05_LDST_SHAPES - tir = tirx # alias for backward compat with upstream tir.convert() calls _DEVICE_INTRIN_PREFIX_TO_NAMESPACE = { @@ -931,200 +907,14 @@ def tvm_throw_last_error(): return call_intrin("handle", "tirx.tvm_throw_last_error") -def make_filled_simdgroup_matrix( - d: Var, - index: PrimExpr, - value: PrimExpr, - col: int = 8, - row: int = 8, -): - """Create a filled SIMDGroup matrix - - Parameters - ---------- - d : var - The simdgroup var - - index : PrimExpr - The index of the matrix. - - value : PrimExpr - The value to fill. - - col : int - The number of columns. - - row : int - The number of rows. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("handle", "tirx.make_filled_simdgroup_matrix", d, index, value, col, row) - - -def simdgroup_load( - d: Var, - index: PrimExpr, - ptr: PrimExpr, - stride: PrimExpr, - col: int = 8, - row: int = 8, - transpose_matrix: bool = False, -): - """Load data from device memory or threadgroup memory to simdgroup - - Parameters - ---------- - d : var - The simdgroup var - - index : PrimExpr - The index of the matrix. - - ptr : PrimExpr - The pointer. - - stride : PrimExpr - The stride. - - col : int - The number of columns. - - row : int - The number of rows. - - transpose_matrix : bool - Whether to transpose the matrix. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin( - "handle", - "tirx.simdgroup_load", - d, - index, - ptr, - stride, - col, - row, - transpose_matrix, - ) - - -def simdgroup_store( - d: PrimExpr, - index: PrimExpr, - ptr: PrimExpr, - stride: PrimExpr, - col: int = 8, - row: int = 8, - transpose_matrix: bool = False, -): - """Store data from simdgroup to device memory or threadgroup memory - - Parameters - ---------- - d : PrimExpr - The SIMDGroup. - - index : PrimExpr - The index of the matrix. - - ptr : PrimExpr - The pointer. - - stride : PrimExpr - The stride. - - col : int - The number of columns. - - row : int - The number of rows. - - - transpose_matrix : bool - Whether to transpose the matrix. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin( - "handle", - "tirx.simdgroup_store", - d, - index, - ptr, - stride, - col, - row, - transpose_matrix, - ) - - -def simdgroup_multiply_accumulate( - d: Var, - index_d: PrimExpr, - a: Var, - index_a: PrimExpr, - b: Var, - index_b: PrimExpr, - c: Var, - index_c: PrimExpr, -): - """Multiply and accumulate two matrices in simdgroup - i.e. d = a * b + c - - Parameters - ---------- - d : Var - The destination matrix. - - index_d : PrimExpr - The index of the destination matrix. - - a : Var - The first matrix. - - index_a : PrimExpr - The index of the first matrix. - - b : Var - The second matrix. - - index_b : PrimExpr - The index of the second matrix. - - c : Var - The third matrix. - - index_c : PrimExpr - The index of the third matrix. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin( - "handle", - "tirx.simdgroup_multiply_accumulate", - d, - index_d, - a, - index_a, - b, - index_b, - c, - index_c, +def print_buffer(buffer_var, dtype, is_string, is_scalar, dim_num, *shape): + """Print out buffer memory during runtime.""" + if len(shape) == 1 and isinstance(shape[0], tuple | list | tvm.ir.Array): + final_shape_args = list(shape[0]) + else: + final_shape_args = list(shape) + return _ffi_api.print_buffer( + buffer_var, dtype, is_string, is_scalar, dim_num, *final_shape_args ) @@ -3159,358 +2949,178 @@ def ignore_loop_partition(predicate) -> PrimExpr: max = comm_reducer(lambda x, y: _ffi_api._OpMax(x, y, None), min_value, name="max") # type: ignore -######################################################## -# CUDA native builtins -######################################################## - - -def cuda_func_call(func_name, *args, source_code, return_type="void"): - """TVM intrinsic to call a CUDA function. Source code is provided as a string. +def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): + """TVM intrinsic for tensor core load operators Parameters ---------- - func_name: str - The name of the CUDA function. - - args: PrimExpr - The arguments to the CUDA function. - - source_code: str - The source code of the CUDA function. + fragment : Var + The wmma fragment. - return_type: str - The return type of the CUDA function. - """ - return call_intrin(return_type, "tirx.cuda_func_call", func_name, *args, source_code) + m : UIntImm + The shape of wmma fragment. + n : UIntImm + The shape of wmma fragment. -def cuda_warp_reduce(value, op, width=32): - """Warp-level butterfly shuffle-XOR reduction. + k : UIntImm + The shape of wmma fragment. - Reduces ``value`` across ``width`` adjacent lanes using the specified - operation. Codegen emits ``log2(width)`` steps of - ``__shfl_xor_sync(0xFFFFFFFF, val, mask)`` with descending XOR masks. + index : Expr + The fragment index. - Parameters - ---------- - value : PrimExpr - The per-thread scalar value to reduce. + buffer_ptr : Expr + The fragment buffer pointer. - op : str - Reduction operation: ``"sum"``, ``"max"``, or ``"min"``. + stride : Expr + The fragment stride. - width : int - Number of lanes participating in each reduction group. - Must be a power of two in [2, 32]. Defaults to 32 (full warp). + layout : Literal["row_major", "column_major"] + The fragment layout. Returns ------- call : PrimExpr - The reduced value (same dtype as *value*). + The call expression. """ - return call_intrin(value.dtype, "tirx.cuda_warp_reduce", value, op, width) - - -def cuda_warp_sum(value, width=32): - """Convenience wrapper: ``cuda_warp_reduce(value, "sum", width)``.""" - return cuda_warp_reduce(value, "sum", width) - + return call_intrin( + "handle", "tirx.tvm_load_matrix_sync", fragment, m, n, k, index, buffer_ptr, stride, layout + ) -def cuda_warp_max(value, width=32): - """Convenience wrapper: ``cuda_warp_reduce(value, "max", width)``.""" - return cuda_warp_reduce(value, "max", width) +def tvm_mma_sync( + fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c +): + """TVM intrinsic for tensor core mma_sync operators -def cuda_warp_min(value, width=32): - """Convenience wrapper: ``cuda_warp_reduce(value, "min", width)``.""" - return cuda_warp_reduce(value, "min", width) + Parameters + ---------- + fragment_d : Var + The wmma fragment_d. + index_d : Expr + The fragment_d index. -def cuda_cta_reduce(value, op, num_warps, scratch): - """CTA-wide reduction via warp shuffle + shared memory. + fragment_a : Var + The wmma fragment_a. - Two-step reduction: (1) intra-warp shuffle reduction, (2) warp-0 - collects per-warp partials from ``scratch``, reduces, broadcasts via - ``__syncthreads()``. All CTA threads must participate. + index_a : Expr + The fragment_a index. - Parameters - ---------- - value : PrimExpr - Per-thread scalar value to reduce. + fragment_b : Var + The wmma fragment_b. - op : str - Reduction operation: ``"sum"``, ``"max"``, or ``"min"``. + index_b : Expr + The fragment_b index. - num_warps : int - Number of warps in the CTA. Must be a power of two in [1, 32]. + fragment_c : Var + The wmma fragment_c. - scratch : Var - Data pointer to shared-memory scratch space (>= num_warps elements). + index_c : Expr + The fragment_c index. Returns ------- call : PrimExpr - The reduced value broadcast to all threads (same dtype as *value*). + The call expression. """ - return call_intrin(value.dtype, "tirx.cuda_cta_reduce", value, op, num_warps, scratch) - - -def cuda_cta_sum(value, num_warps, scratch): - """Convenience wrapper: ``cuda_cta_reduce(value, "sum", num_warps, scratch)``.""" - return cuda_cta_reduce(value, "sum", num_warps, scratch) + return call_intrin( + "handle", + "tirx.tvm_mma_sync", + fragment_d, + index_d, + fragment_a, + index_a, + fragment_b, + index_b, + fragment_c, + index_c, + ) -def cuda_cta_max(value, num_warps, scratch): - """Convenience wrapper: ``cuda_cta_reduce(value, "max", num_warps, scratch)``.""" - return cuda_cta_reduce(value, "max", num_warps, scratch) +def tvm_bmma_sync( + fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c +): + """TVM intrinsic for tensor core bmma_sync operators + Parameters + ---------- + fragment_d : Var + The bwmma fragment_d. -def cuda_cta_min(value, num_warps, scratch): - """Convenience wrapper: ``cuda_cta_reduce(value, "min", num_warps, scratch)``.""" - return cuda_cta_reduce(value, "min", num_warps, scratch) + index_d : Expr + The fragment_d index. + fragment_a : Var + The bwmma fragment_a. -def cuda_copy_bytes(dst, src, num_bytes): - """Typed load/store copy of ``num_bytes`` bytes. + index_a : Expr + The fragment_a index. - Copies ``num_bytes`` bytes from ``src`` to ``dst`` using a single - typed load/store instruction. Codegen selects the appropriate C++ - vector type (``uint4``, ``uint2``, ``unsigned int``, etc.). + fragment_b : Var + The bwmma fragment_b. - Parameters - ---------- - dst : Var - Destination pointer. + index_b : Expr + The fragment_b index. - src : Var - Source pointer. + fragment_c : Var + The bwmma fragment_c. - num_bytes : int - Number of bytes to copy. Must be one of {1, 2, 4, 8, 16}. + index_c : Expr + The fragment_c index. Returns ------- call : PrimExpr - A void call expression. + The call expression. """ - return call_intrin("void", "tirx.cuda_copy_bytes", dst, src, num_bytes) - - -def cuda_copy_128b(dst, src): - """Convenience wrapper: ``cuda_copy_bytes(dst, src, 16)`` — copies 128 bits.""" - return cuda_copy_bytes(dst, src, 16) - + return call_intrin( + "handle", + "tirx.tvm_bmma_sync", + fragment_d, + index_d, + fragment_a, + index_a, + fragment_b, + index_b, + fragment_c, + index_c, + ) -def cuda_copy_64b(dst, src): - """Convenience wrapper: ``cuda_copy_bytes(dst, src, 8)`` — copies 64 bits.""" - return cuda_copy_bytes(dst, src, 8) +def tvm_fill_fragment(fragment, m, n, k, index, value): + """TVM intrinsic for tensor core fill_fragment operators -def cuda_copy_32b(dst, src): - """Convenience wrapper: ``cuda_copy_bytes(dst, src, 4)`` — copies 32 bits.""" - return cuda_copy_bytes(dst, src, 4) + Parameters + ---------- + fragment : Var + The wmma fragment + m : UIntImm + The shape of wmma fragment. -def cuda_copy_16b(dst, src): - """Convenience wrapper: ``cuda_copy_bytes(dst, src, 2)`` — copies 16 bits.""" - return cuda_copy_bytes(dst, src, 2) + n : UIntImm + The shape of wmma fragment. + k : UIntImm + The shape of wmma fragment. -def cuda_copy_8b(dst, src): - """Convenience wrapper: ``cuda_copy_bytes(dst, src, 1)`` — copies 8 bits.""" - return cuda_copy_bytes(dst, src, 1) + index : Expr + The fragment index. - -def cuda_warp_sync(): - """TVM intrinsic to synchronize threads within the current warp. - - This lowers to a CUDA `__syncwarp()` call. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_warp_sync") - - -def cuda_cta_sync(): - """TVM intrinsic to call CUDA syncthreads (block-wide barrier) - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_cta_sync") - - -def cuda_grid_sync(): - """TVM intrinsic to call CUDA grid-wide sync (cooperative groups) - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_grid_sync") - - -def cuda_cluster_sync(): - """TVM intrinsic to call CUDA cluster-wide barrier sync - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_cluster_sync") - - -def cuda_thread_rank(): - """TVM intrinsic that returns ``cooperative_groups::thread_rank()`` - for the enclosing CTA -- the linear thread index within the block. - - Useful for building "single thread of CTA" predicates without - referencing user-declared scope_id vars. For example, the idiomatic - mbarrier.init leader predicate is:: - - T.cuda.thread_rank() == 0 - - Returns - ------- - call : PrimExpr - The call expression (``int32``). - """ - return call_intrin("int32", "tirx.cuda_thread_rank") - - -def cuda_half2float(src): - """TVM intrinsic to convert half to float - - Parameters - ---------- - src : PrimExpr - Source pointer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("float32", "tirx.cuda_half2float", src) - - -def cuda_bfloat162float(src): - """TVM intrinsic to convert bfloat16 to float - - Parameters - ---------- - src : PrimExpr - Source pointer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("float32", "tirx.cuda_bfloat162float", src) - - -def cuda_float22half2(dst, src): - """TVM intrinsic to convert float2 to half2 with rounding - - Parameters - ---------- - dst : PrimExpr - Destination pointer. - - src : PrimExpr - Source pointer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_float22half2", dst, src) - - -def cuda_trap_when_assert_failed(cond): - """TVM intrinsic to trap when assertion failed (cond == false) - - Parameters - ---------- - cond : PrimExpr - Condition to check. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_trap_when_assert_failed", cond) - - -def cuda_runtime_instr_desc(desc, sf_id): - """TVM intrinsic to update runtime instruction descriptor - - Parameters - ---------- - desc : PrimExpr - Pointer to the descriptor (uint32*). - - sf_id : PrimExpr - The subfragment id. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_runtime_instr_desc", desc, sf_id) - - -def cuda_half8tofloat8(src_addr, dst_addr): - """TVM intrinsic to convert 8 half2s to 8 float2s - - Parameters - ---------- - src_addr : PrimExpr - Source pointer. - - dst_addr : PrimExpr - Destination pointer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_half8tofloat8", src_addr, dst_addr) - - -def cuda_float8tohalf8(src_addr, dst_addr): - """TVM intrinsic to convert 8 float2s to 8 half2s - - Parameters - ---------- - src_addr : PrimExpr - Source pointer. - - dst_addr : PrimExpr - Destination pointer. + value : Expr + The value to be filled in fragment. Returns ------- call : PrimExpr The call expression. """ - return call_intrin("", "tirx.cuda_float8tohalf8", src_addr, dst_addr) + return call_intrin("handle", "tirx.tvm_fill_fragment", fragment, m, n, k, index, value) -def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): - """TVM intrinsic for tensor core load operators +def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): + """TVM intrinsic for tensor core store operators Parameters ---------- @@ -3544,4669 +3154,50 @@ def tvm_load_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): The call expression. """ return call_intrin( - "handle", "tirx.tvm_load_matrix_sync", fragment, m, n, k, index, buffer_ptr, stride, layout + "handle", "tirx.tvm_store_matrix_sync", fragment, m, n, k, index, buffer_ptr, stride, layout ) -def tvm_mma_sync( - fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c -): - """TVM intrinsic for tensor core mma_sync operators - - Parameters - ---------- - fragment_d : Var - The wmma fragment_d. - - index_d : Expr - The fragment_d index. - - fragment_a : Var - The wmma fragment_a. - - index_a : Expr - The fragment_a index. - - fragment_b : Var - The wmma fragment_b. - - index_b : Expr - The fragment_b index. - - fragment_c : Var - The wmma fragment_c. - - index_c : Expr - The fragment_c index. +def thread_return(): + """TVM intrinsic to call thread_return() Returns ------- call : PrimExpr The call expression. """ - return call_intrin( - "handle", - "tirx.tvm_mma_sync", - fragment_d, - index_d, - fragment_a, - index_a, - fragment_b, - index_b, - fragment_c, - index_c, - ) + return call_intrin("", "tirx.thread_return") -def tvm_bmma_sync( - fragment_d, index_d, fragment_a, index_a, fragment_b, index_b, fragment_c, index_c -): - """TVM intrinsic for tensor core bmma_sync operators +def continue_loop(span=None): + """Create a tir intrinsic call to represent continue expression Parameters ---------- - fragment_d : Var - The bwmma fragment_d. - - index_d : Expr - The fragment_d index. - - fragment_a : Var - The bwmma fragment_a. - - index_a : Expr - The fragment_a index. - - fragment_b : Var - The bwmma fragment_b. - - index_b : Expr - The fragment_b index. - - fragment_c : Var - The bwmma fragment_c. - - index_c : Expr - The fragment_c index. + span : Optional[Span] + The location of this operator in the source code. Returns ------- - call : PrimExpr - The call expression. + ret : PrimExpr + The continue expression """ - return call_intrin( - "handle", - "tirx.tvm_bmma_sync", - fragment_d, - index_d, - fragment_a, - index_a, - fragment_b, - index_b, - fragment_c, - index_c, - ) + return _ffi_api.continue_loop(span) -def tvm_fill_fragment(fragment, m, n, k, index, value): - """TVM intrinsic for tensor core fill_fragment operators + +def break_loop(span=None): + """Create a tir intrinsic call to represent break expression Parameters ---------- - fragment : Var - The wmma fragment - - m : UIntImm - The shape of wmma fragment. - - n : UIntImm - The shape of wmma fragment. - - k : UIntImm - The shape of wmma fragment. - - index : Expr - The fragment index. - - value : Expr - The value to be filled in fragment. + span : Optional[Span] + The location of this operator in the source code. Returns ------- - call : PrimExpr - The call expression. + ret : PrimExpr + The break expression """ - return call_intrin("handle", "tirx.tvm_fill_fragment", fragment, m, n, k, index, value) - - -def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout): - """TVM intrinsic for tensor core store operators - - Parameters - ---------- - fragment : Var - The wmma fragment. - m : UIntImm - The shape of wmma fragment. - - n : UIntImm - The shape of wmma fragment. - - k : UIntImm - The shape of wmma fragment. - - index : Expr - The fragment index. - - buffer_ptr : Expr - The fragment buffer pointer. - - stride : Expr - The fragment stride. - - layout : Literal["row_major", "column_major"] - The fragment layout. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin( - "handle", "tirx.tvm_store_matrix_sync", fragment, m, n, k, index, buffer_ptr, stride, layout - ) - - -def ptx_mma_sp( - dtype, - shape, - A_layout, - B_layout, - A_dtype, - B_dtype, - C_dtype, - multiplicand_a, - a_index, - multiplicand_b, - b_index, - accumulator, - c_index, - metadata, - meta_index, - sparse_selector, - saturate, -): - """TVM intrinsic for sparse tensor core ptx instructions - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-sparse-mma - - Parameters - ---------- - dtype : str - The data type of the result. - - shape : str - The shape of mma fragment. - - A_layout : Literal["row", "col"] - The layout of multiplicand fragment A. - - B_layout : Literal["row", "col"] - The layout of multiplicand fragment B. - - A_dtype : str - The data type of multiplicand fragment A. - - B_dtype : str - The data type of multiplicand fragment B. - - C_dtype : str - The data type of multiplicand fragment C. - - multiplicand_a : Var - The multiplicand fragment A variable. - - a_index : Expr - The index of multiplicand fragment A. - - multiplicand_b : Var - The multiplicand fragment B variable. - - b_index : Expr - The index of multiplicand fragment B. - - accumulator : Var - The accumulator fragment C variable. - - c_index : Expr - The index of accumulator fragment C. - - metadata : Expr - The metadata of operand. - - meta_index : Expr - The metadata index of operand. - - sparse_selector : Expr - The sparse selector indicating the thread that stores the metadata. - - saturate : bool - The optional saturation at the output. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin( - dtype, - "tirx.ptx_mma_sp", - shape, - A_layout, - B_layout, - A_dtype, - B_dtype, - C_dtype, - multiplicand_a, - a_index, - multiplicand_b, - b_index, - accumulator, - c_index, - metadata, - meta_index, - sparse_selector, - saturate, - ) - - -def mma_store(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): - """TVM intrinsic for storing the result of PTX MMA into a destination pointer - - Parameters - ---------- - dtype : str - The data type of the result. - - m : IntImm - The shape of mma fragment. - - n : IntImm - The shape of mma fragment. - - dst_ptr : Var - The destination pointer variable. - - src_ptr : Var - The source pointer variable. - - src_offset : Expr - The source offset. - - dst_stride : Var - The destination stride. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin(dtype, "tirx.mma_store", m, n, dst_ptr, src_ptr, src_offset, dst_stride) - - -def mma_store_legacy(dtype, m, n, dst_ptr, src_ptr, src_offset, dst_stride): - """mma_store with apache-style signature. - - ``dst_ptr`` is typically a ``tvm_access_ptr`` Call (so the caller can - encode the destination's element dtype + base offset), and - ``src_ptr + src_offset`` is the raw warp accumulator + element offset. - Codegen does ``ptr + offset`` C pointer arithmetic; lower_warp_memory - rewrites src_offset's group component to a thread-local index.""" - return call_intrin( - dtype, - "tirx.mma_store_legacy", - m, - n, - dst_ptr, - src_ptr, - src_offset, - dst_stride, - ) - - -def mma_fill(dtype, local_size, local_ptr, offset): - """TVM intrinsic for zero-initalizing an MMA accumulation registor - - Parameters - ---------- - dtype : str - The data type of the result. - - local_size : IntImm - The number of elements. - - local_ptr : Var - The destination pointer variable. - - offset : Expr - The destination offset. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin(dtype, "tirx.mma_fill", local_size, local_ptr, offset) - - -def mma_fill_legacy(dtype, local_size, local_ptr, offset): - """mma_fill with (ptr_var, offset). Codegen emits ``ptr + offset`` - C pointer arithmetic; lower_warp_memory rewrites the offset's group - component to a thread-local index.""" - return call_intrin(dtype, "tirx.mma_fill_legacy", local_size, local_ptr, offset) - - -def ptx_cp_async_bulk( - dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes, barrier_id -): - """TVM intrinsic for ptx async copy from global to shared memory using cp.async.bulk - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk - - Parameters - ---------- - dtype : str - The data type of the result. - - shared_ptr : Var - The shared memory pointer variable. - - shared_offset : Expr - The offset of shared memory pointer. - - global_ptr : Var - The global memory pointer variable. - - global_offset : Expr - The offset of global memory pointer. - - bytes : int - The data size to copy. - - barrier_id : int - The ID of the barrier shared memory pointer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin( - dtype, - "tirx.ptx_cp_async_bulk", - shared_ptr, - shared_offset, - global_ptr, - global_offset, - bytes, - barrier_id, - ) - - -def ptx_cp_async_bulk_shared_to_cluster(dst_ptr, src_ptr, size, mbar): - """PTX cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes - - Asynchronous bulk copy from executing CTA's shared memory to a remote - CTA's shared memory within the same cluster. - - Parameters - ---------- - dst_ptr : PrimExpr - Destination pointer in shared::cluster address space (remote CTA). - - src_ptr : PrimExpr - Source pointer in shared::cta address space (local CTA). - - size : PrimExpr - Number of bytes to copy (must be multiple of 16). - - mbar : PrimExpr - Mbarrier address in shared::cluster space for completion signaling, - usually produced by ``T.ptx.map_shared_rank``. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_cp_async_bulk_shared_to_cluster", dst_ptr, src_ptr, size, mbar) - - -def ptx_cp_async_mbarrier_arrive(barrier_id): - """TVM intrinsic for ptx async copy barrier using cp.async.mbarrier.arrive - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive - - Parameters - ---------- - barrier_id : int - The ID of the barrier shared memory pointer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_cp_async_mbarrier_arrive", barrier_id) - - -def ptx_fence(sem: str, scope: str): - """TVM intrinsic for PTX fence instruction. - - Generates: fence.{sem}.{scope}; - - Parameters - ---------- - sem : str - The semantics of the fence. One of "sc", "acq_rel". - scope : str - The scope of the fence. One of "cta", "cluster", "gpu", "sys". - - Returns - ------- - call : PrimExpr - The call expression. - """ - _choice("sem", sem, _FENCE_SEM) - _choice("scope", scope, _FENCE_SCOPE) - return call_intrin("", "tirx.ptx_fence", sem, scope) - - -def ptx_fence_proxy_async(space: str = ""): - """TVM intrinsic for PTX fence.proxy.async instruction. - - Generates: fence.proxy.async[.{space}]; - - Parameters - ---------- - space : str - The address space qualifier. One of "", "global", "shared::cta", "shared::cluster". - Empty string means no qualifier. - - Returns - ------- - call : PrimExpr - The call expression. - """ - _choice("space", space, _FENCE_PROXY_ASYNC_SPACE) - return call_intrin("", "tirx.ptx_fence_proxy_async", space) - - -def ptx_mbarrier_init(bar, thread_count): - """TVM intrinsic to call mbarrier.init.shared::cta.b64 - - Parameters - ---------- - bar : Var - The pointer to barrier variable. - - thread_count : int - The number of threads expected to arrive at the barrier. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_mbarrier_init", bar, thread_count) - - -def ptx_mbarrier_arrive(bar, cta_id=None, pred=None): - """TVM intrinsic to call - mbarrier.arrive.shared::cta.b64 - or - @p mapa.shared::cluster.u32 - @p mbarrier.arrive.shared::cluster.b64 - - Parameters - ---------- - bar : Var - The pointer to barrier variable. - - cta_id : Optional[PrimExpr] - The cta id. - - pred : Optional[PrimExpr] - The predicate to guard the operation. - """ - if cta_id is None and pred is None: - return call_intrin("", "tirx.ptx_mbarrier_arrive", bar) - assert cta_id is not None and pred is not None - return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred) - - -def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None): - """TVM intrinsic to call - mbarrier.arrive_expect_tx.shared::cta.b64 - or - @p mapa.shared::cluster.u32 - @p mbarrier.arrive_expect_tx.shared::cluster.b64 - - Parameters - ---------- - bar : Var - The pointer to barrier variable. - - byte_count : int - Increases the tx count of the mbarrier object to track completion of - addtional async transactions. - - cta_id : Optional[PrimExpr] - The cta id. - - pred : Optional[PrimExpr] - The predicate to guard the operation. - - Returns - ------- - call : PrimExpr - The call expression. - """ - if cta_id is None and pred is None: - return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count) - assert cta_id is not None and pred is not None - return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar, byte_count, cta_id, pred) - - -def ptx_mbarrier_try_wait(bar, phase): - """TVM intrinsic to call mbarrier.try_wait.parity repeatedly until it returns true - - Parameters - ---------- - bar : Var - The pointer to barrier variable. - - phase : int - The phase of the barrier. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_mbarrier_try_wait", bar, phase) - - -def ptx_mbarrier_try_wait_once(bar, phase, ticks): - """TVM intrinsic for one-shot non-blocking ``mbarrier.try_wait.parity``. - - Returns ``1`` if the requested parity has been reached and ``0`` otherwise. - This is intended for bounded debug waits; production waits should use - :func:`ptx_mbarrier_try_wait`. - """ - return call_intrin("uint32", "tirx.ptx_mbarrier_try_wait_once", bar, phase, ticks) - - -def ptx_bar_arrive(name_bar_id, thread_count): - """TVM intrinsic to call bar.arrive a, b - - Parameters - ---------- - name_bar_id : int - The ID of the named barrier. - - thread_count : int - The number of threads expected to arrive at the barrier. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_bar_arrive", name_bar_id, thread_count) - - -def ptx_bar_sync(name_bar_id, thread_count): - """TVM intrinsic to call bar.sync a, {b} - - Parameters - ---------- - name_bar_id : int - The ID of the named barrier. - - thread_count : int - The number of threads expected to arrive at the barrier. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_bar_sync", name_bar_id, thread_count) - - -def ptx_cp_async( - dst_ptr, - src_ptr, - cp_size, - *, - cache_hint="", - cache_policy=None, - prefetch_size=-1, - predicate=-1, - fill_mode="", -): - """TVM intrinsic for ptx async copy from global to shared memory using cp.async - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async - - Dispatches to one of three PTX-form-aligned ops: - - * ``ptx_cp_async_src_size`` for ``fill_mode == "zero"`` (zero-fill via - ``src_size = pred ? cp_size : 0``). - * ``ptx_cp_async_ignore_src`` for a non-empty ``predicate`` with no - fill_mode (``setp+@p`` guards the asm). - * ``ptx_cp_async_plain`` for the no-predicate / no-fill_mode case. - - Parameters - ---------- - shared_ptr : PrimExpr - The pointer to the shared memory. - - global_ptr : PrimExpr - The pointer to the global memory. - - cp_size : int - The data size to copy. - - cache_hint : str["evict_last", "evict_first", "evict_normal", ""] - The cache hint. - - prefetch_size : int[-1, 64, 128, 256] - The prefetch size. - - predicate : PrimExpr - The predicate to guard the operation. - - fill_mode : str["zero", ""] - The fill mode. - - Returns - ------- - call : PrimExpr - The call expression. - """ - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - _choice("prefetch_size", prefetch_size, _CP_ASYNC_PREFETCH_SIZE) - _choice("fill_mode", fill_mode, _CP_ASYNC_FILL_MODE) - return call_intrin( - "", - "tirx.ptx_cp_async", - dst_ptr, - src_ptr, - cp_size, - cache_policy, - int(has_cache_policy), - prefetch_size, - predicate, - fill_mode, - ) - - -def ptx_cp_async_legacy(*all_args): - """Legacy ``ptx_cp_async`` API taking explicit src/dst offsets. - - Signature: ``(dst_ptr, dst_offset, src_ptr, src_offset, cp_size)``. - Offsets are folded into the pointers via ``tvm_access_ptr`` then - dispatched to fork-native :func:`ptx_cp_async`. - - ``T.ptx.cp_async_legacy`` runs through ``_dtype_forward`` which - prepends a ``dtype=`` kwarg as a leading positional. The dtype names - the *element* type of the buffer (offsets are in elements of that - dtype, not bytes), so this function accepts either 5 or 6 positional - args. - """ - args = list(all_args) - elem_dtype = "int8" - if len(args) == 6: - # Leading positional is the buffer element dtype, used to scale - # offsets correctly when folding via ``tvm_access_ptr``. - elem_dtype = args.pop(0) - if len(args) != 5: - raise ValueError( - f"ptx_cp_async_legacy expects 5 args (or 6 with dtype= kwarg " - f"prepended); got {len(all_args)}" - ) - dst_ptr, dst_offset, src_ptr, src_offset, cp_size = args - dst_ptr = tvm_access_ptr(elem_dtype, dst_ptr, dst_offset, 1, 1) - src_ptr = tvm_access_ptr(elem_dtype, src_ptr, src_offset, 1, 1) - return ptx_cp_async(dst_ptr, src_ptr, cp_size) - - -def ptx_cp_async_commit_group(): - """TVM intrinsic for ptx async copy commit - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_cp_async_commit_group") - - -def ptx_cp_async_wait_group(num=0): - """TVM intrinsic for ptx async copy wait - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group - - Parameters - ---------- - num : int, optional - The number of the most recent uncommitted pending cp.async groups to wait. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_cp_async_wait_group", num) - - -def ptx_cp_async_bulk_tensor_global_to_cluster( - dim, dst_ptr, bar, tensormap_addr, cta_mask, cta_group, cache_hint, *coords, cache_policy=None -): - """TVM intrinsic to call cp.async.bulk.tensor.dim.shared::cluster.global.tile.mbarrier::complete_tx::bytes - - Parameters - ---------- - dim : int - The dimension of the source tensor. - - dst_ptr : PrimExpr - The destination pointer to the shared memory. - - bar : PrimExpr - The pointer to mbarrier variable. - - tensormap_addr : PrimExpr - The generic address of the tensor map object. - - cta_mask : int - The mask of the cta for multicast. - - cta_group : int - Must be either 1 or 2. - If set to 1, mbarrier must be in the shared memory of the same CTA as the shared memory destination - If set to 2, mbarrier can be in shared memory of either the same CTA as the shared memory destination - or the shared memory of the peer CTA. - - cache_hint : str - The cache hint. - - coords : List[PrimExpr] - specifies the starting coordinates in the tensor data in the global memory - - Returns - ------- - call : PrimExpr - The call expression. - """ # noqa: E501 - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - if isinstance(cache_hint, PrimExpr): - has_cache_policy, *coords = coords - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", - dim, - dst_ptr, - bar, - tensormap_addr, - cta_mask, - cta_group, - cache_hint, - has_cache_policy, - *coords, - ) - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", - dim, - dst_ptr, - bar, - tensormap_addr, - cta_mask, - cta_group, - cache_policy, - int(has_cache_policy), - *coords, - ) - - -def ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster( - dim, dst_ptr, bar, tensormap_addr, cta_mask, cta_group, cache_hint, *coords, cache_policy=None -): - """TVM intrinsic to call - cp.async.bulk.tensor.dim.shared::cluster.global.tile::gather4.mbarrier::complete_tx::bytes - - Parameters - ---------- - dim : int - The dimension of the source tensor. - - dst_ptr : PrimExpr - The destination pointer to the shared memory. - - bar : PrimExpr - The pointer to mbarrier variable. - - tensormap_addr : PrimExpr - The generic address of the tensor map object. - - cta_mask : int - The mask of the cta for multicast. - - cta_group : int - Must be either 1 or 2. - - cache_hint : str - The cache hint. - - coords : List[PrimExpr] - The TMA coordinates followed by the 4 gather row indices. - - Returns - ------- - call : PrimExpr - The call expression. - """ - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - if isinstance(cache_hint, PrimExpr): - has_cache_policy, *coords = coords - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster", - dim, - dst_ptr, - bar, - tensormap_addr, - cta_mask, - cta_group, - cache_hint, - has_cache_policy, - *coords, - ) - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster", - dim, - dst_ptr, - bar, - tensormap_addr, - cta_mask, - cta_group, - cache_policy, - int(has_cache_policy), - *coords, - ) - - -def ptx_cp_async_bulk_tensor_shared_to_global( - dim, src_ptr, tensormap_addr, cache_hint, *coords, cache_policy=None -): - """TVM intrinsic to call cp.async.bulk.tensor.dim.global.shared::cta.tile.bulk_group - - Parameters - ---------- - dim : int - The dimension of the copy tensor. - - src_ptr : PrimExpr - The source pointer to the shared memory. - - tensormap_addr : PrimExpr - The generic address of the tensor map object. - - cache_hint : str - The cache hint. - - coords : List[PrimExpr] - specifies the starting coordinates in the tensor data in the global memory - - Returns - ------- - call : PrimExpr - The call expression. - """ - if isinstance(cache_hint, PrimExpr): - has_cache_policy, *coords = coords - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global", - dim, - src_ptr, - tensormap_addr, - cache_hint, - has_cache_policy, - *coords, - ) - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global", - dim, - src_ptr, - tensormap_addr, - cache_policy, - int(has_cache_policy), - *coords, - ) - - -def ptx_cp_async_bulk_tensor_global_to_cluster_prefetch( - dim, tensormap_addr, cache_hint, *coords, cache_policy=None -): - """TVM intrinsic to call cp.async.bulk.prefetch.tensor.dim.L2.global.tile - - Parameters - ---------- - dim : int - The dimension of the source tensor. - - tensormap_addr : PrimExpr - The generic address of the tensor map object. - - cache_hint : str - The cache hint. - - coords : List[PrimExpr] - specifies the starting coordinates in the tensor data in the global memory - - Returns - ------- - call : PrimExpr - The call expression. - """ - if isinstance(cache_hint, PrimExpr): - has_cache_policy, *coords = coords - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch", - dim, - tensormap_addr, - cache_hint, - has_cache_policy, - *coords, - ) - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch", - dim, - tensormap_addr, - cache_policy, - int(has_cache_policy), - *coords, - ) - - -def ptx_cp_async_bulk_tensor_shared_to_global_reduce( - dim, src_ptr, tensormap_addr, cache_hint, red_op, *coords, cache_policy=None -): - """TVM intrinsic to call cp.reduce.async.bulk.tensor.dim.dst.src.redOp - - Parameters - ---------- - dim : int - The dimension of the copy tensor. - - src_ptr : PrimExpr - The source pointer to the shared memory. - - tensormap_addr : PrimExpr - The generic address of the tensor map object. - - cache_hint: str - The cache hint. - - red_op: str - The reduction operator. - - coords: List[PrimExpr] - The coordinates of the tensor. - - Returns - ------- - call : PrimExpr - The call expression. - """ - if isinstance(cache_hint, PrimExpr): - has_cache_policy = red_op - red_op, *coords = coords - _choice("red_op", red_op, _CP_ASYNC_BULK_RED_OP) - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global_reduce", - dim, - src_ptr, - tensormap_addr, - cache_hint, - has_cache_policy, - red_op, - *coords, - ) - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - _choice("red_op", red_op, _CP_ASYNC_BULK_RED_OP) - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_shared_to_global_reduce", - dim, - src_ptr, - tensormap_addr, - cache_policy, - int(has_cache_policy), - red_op, - *coords, - ) - - -def ptx_cp_async_bulk_commit_group(): - """TVM intrinsic to call cp.async.bulk.tensor.commit_group - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_cp_async_bulk_commit_group") - - -def ptx_cp_async_bulk_wait_group(n=0, read=True): - """TVM intrinsic to call cp.async.bulk.tensor.wait_group - - Parameters - ---------- - n : int - The number of the most recent uncommitted pending cp.async groups to wait. - - read : bool - Whether the wait is for read. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_cp_async_bulk_wait_group", n, read) - - -def ptx_barrier_cluster_arrive(sem="", aligned=True): - """TVM intrinsic to call barrier.cluster.arrive{.sem}{.aligned} - - Parameters - ---------- - sem : str - Either release or relaxed or empty string. - - aligned : bool - Whether all threads in the warp must execute the same instruction. - """ - _choice("sem", sem, _CLUSTER_BARRIER_SEM) - return call_intrin("", "tirx.ptx_barrier_cluster_arrive", sem, aligned) - - -def ptx_barrier_cluster_wait(acquire=False, aligned=True): - """TVM intrinsic to call barrier.cluster.wait{.acquire}{.aligned} - - Parameters - ---------- - acquire : bool - The memory synchronization - - aligned : bool - Whether all threads in the warp must execute the same instruction. - """ - return call_intrin("", "tirx.ptx_barrier_cluster_wait", acquire, aligned) - - -def ptx_elect_sync(): - """TVM intrinsic to call elect.sync""" - return call_intrin("uint32", "tirx.ptx_elect_sync") - - -def ptx_fence_mbarrier_init(): - """TVM intrinsic for PTX fence.mbarrier_init.release.cluster instruction. - - Generates: fence.mbarrier_init.release.cluster; - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_fence_mbarrier_init") - - -def ptx_fetch_register(bits, reg_name): - """TVM intrinsic to tvm instrinsics to fetch PTX pre-defined registers - - Parameters - ---------- - bits : int - The number of bits of the register. - - reg_name : str - The name of the register. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("int" + str(bits), "tirx.ptx_fetch_register", bits, reg_name) - - -def ptx_mma( - shape, - a_layout, - b_layout, - d_type, - a_type, - b_type, - c_type, - d_ptrs, - a_ptrs, - b_ptrs, - c_ptrs=None, - saturate=False, - bit_op=None, -): - """TVM intrinsic for ptx tensor core mma instructions. - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-for-mma - - Each per-thread register of every operand is addressed by its OWN pointer - (one ``void*`` per b32/f32 register), so the register fragments need not be - contiguous in the register file. ``d_ptrs`` / ``a_ptrs`` / ``b_ptrs`` / - ``c_ptrs`` are lists of one pointer per 32-bit register (b32 for - fp16/bf16/tf32/int8 multiplicands, f32/f64 for the accumulator), enumerated - in the fixed PTX register order (see the gemm dispatch / - ``tests/python/tirx-base/test_tir_ptx_mma.py``). - - Within one b32 register the packed elements (e.g. 2 fp16 along k_pack) - must stay contiguous (stride 1); only the b32 registers themselves may be - scattered. - - Parameters - ---------- - shape : str - The shape of mma fragment. - - a_layout : Literal["row", "col"] - The layout of multiplicand fragment A. - - b_layout : Literal["row", "col"] - The layout of multiplicand fragment B. - - d_type : str - The data type of result fragment D. - - a_type : str - The data type of multiplicand fragment A. - - b_type : str - The data type of multiplicand fragment B. - - c_type : str - The data type of accumulator fragment C. - - d_ptrs : List[PrimExpr] - One pointer per result-fragment D register, in PTX order. - - a_ptrs : List[PrimExpr] - One pointer per multiplicand-A register, in PTX order. - - b_ptrs : List[PrimExpr] - One pointer per multiplicand-B register, in PTX order. - - c_ptrs : Optional[List[PrimExpr]] - One pointer per accumulator-C register, in PTX order. ``None`` (the - default) means the accumulator is not used (beta == 0): codegen feeds - a literal 0 for each C slot. - - saturate : bool - The optional saturation at the output. - - bit_op : Optional[Literal["xor", "and"]] - The 1-bit operator (for the b1 subbyte form). ``None`` means unused. - - Returns - ------- - call : PrimExpr - The call expression. - """ - d_ptrs = list(d_ptrs) - a_ptrs = list(a_ptrs) - b_ptrs = list(b_ptrs) - has_c = c_ptrs is not None - c_ptrs = list(c_ptrs) if has_c else [] - - # Encode group register counts as leading attrs so codegen can slice the - # flat pointer tail. ``no_c_ptr`` mirrors the legacy IntImm(0) sentinel. - no_c_ptr = not has_c - # Flattened pointer list: D regs, A regs, B regs, then C regs (if any). - ptrs = [*d_ptrs, *a_ptrs, *b_ptrs, *c_ptrs] - - base = [ - "", - "tirx.ptx_mma", - shape, - a_layout, - b_layout, - d_type, - a_type, - b_type, - c_type, - len(d_ptrs), - len(a_ptrs), - len(b_ptrs), - len(c_ptrs), - no_c_ptr, - *ptrs, - saturate, - ] - if bit_op is None: - return call_intrin(*base) - return call_intrin(*base, bit_op) - - -def ptx_mma_legacy(*all_args, operator=None): - """Legacy ``ptx_mma`` API. - - Signature: ``(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, - multiplicand_a, a_index, multiplicand_b, b_index, accumulator, - c_index, saturate, operator=None)``. The accumulator is reused as - both input and output (no separate ``d``/``c`` slot), unlike - fork-native :func:`ptx_mma` which distinguishes them. Translation: - - * ``a_dtype, b_dtype, c_dtype`` → fork ``a_type, b_type, c_type`` - (and reuse ``c_dtype`` as fork ``d_type`` since the accumulator - dtype is the output dtype here). - * ``(a_ptr, a_offset)`` and ``(b_ptr, b_offset)`` → folded via - :func:`tvm_access_ptr`. - * ``(accumulator, c_index)`` → folded; passed for both ``d_ptr`` and - ``c_ptr`` since the accumulator is reused as the output. - - ``T.ptx.mma.legacy`` runs through ``_dtype_forward`` which prepends a - ``dtype=`` kwarg as a leading positional, so this function accepts - either 13 or 14 positional args. - """ - args = list(all_args) - # ``T.ptx.mma.legacy(..., dtype="...")`` has the dtype prepended by - # ``_dtype_forward``; strip it here. - if len(args) in (14, 15): - _ = args.pop(0) - if len(args) == 14: - # operator passed positionally as the trailing arg. - operator = args.pop() - if len(args) != 13: - raise ValueError( - f"ptx_mma_legacy expects 13-15 positional args (with optional " - f"leading ``call_dtype`` from dtype= kwarg and optional trailing " - f"``operator``); got {len(all_args)}" - ) - ( - shape, - a_layout, - b_layout, - a_dtype, - b_dtype, - c_dtype, - a_ptr, - a_offset, - b_ptr, - b_offset, - acc_ptr, - c_offset, - saturate, - ) = args - # Emit tirx.ptx_mma_legacy directly with separate (ptr_var, offset) - # pairs. codegen_cuda.cc uses C pointer arithmetic ``ptr + offset`` - # so element offsets stay element-accurate, and lower_warp_memory - # rewrites the offset's group component to a thread-local index. - call_args = [ - shape, - a_layout, - b_layout, - a_dtype, - b_dtype, - c_dtype, - a_ptr, - a_offset, - b_ptr, - b_offset, - acc_ptr, - c_offset, - saturate, - ] - if operator is not None: - call_args.append(operator) - return call_intrin("", "tirx.ptx_mma_legacy", *call_args) - - -def ptx_mma_sp_legacy(*all_args): - """Legacy ``ptx_mma_sp`` API. - - Signature: ``(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, - multiplicand_a, a_index, multiplicand_b, b_index, accumulator, - c_index, metadata, meta_index, sparse_selector, saturate)``. - - ``T.ptx.mma_sp.legacy`` runs through ``_dtype_forward`` which prepends - a ``dtype=`` kwarg as a leading positional, so this function accepts - either 16 or 17 positional args. - """ - args = list(all_args) - if len(args) == 17: - _ = args.pop(0) - if len(args) != 16: - raise ValueError( - f"ptx_mma_sp_legacy expects 16 args (or 17 with dtype= kwarg " - f"prepended); got {len(all_args)}" - ) - ( - shape, - a_layout, - b_layout, - a_dtype, - b_dtype, - c_dtype, - a_ptr, - a_offset, - b_ptr, - b_offset, - acc_ptr, - c_offset, - meta_ptr, - meta_offset, - sparse_selector, - saturate, - ) = args - return ptx_mma_sp( - c_dtype, - shape, - a_layout, - b_layout, - a_dtype, - b_dtype, - c_dtype, - a_ptr, - a_offset, - b_ptr, - b_offset, - acc_ptr, - c_offset, - meta_ptr, - meta_offset, - sparse_selector, - saturate, - ) - - -def ptx_ldmatrix(trans, num, dtype, smem_ptr, *dst_handles): - """TVM intrinsic for ldmatrix.sync.aligned.m8n8.x{num}{.trans}.shared.{dtype}. - - Mirrors the PTX ISA destination form: each output register is a separate - operand. Pass ``T.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for - each destination — the slots may be non-contiguous. - - Parameters - ---------- - trans : bool - Apply the ``.trans`` modifier. - num : int - One of 1, 2, 4 — number of m8n8 fragments. - dtype : str - ``"b16"`` (4 bytes per fragment register) or ``"b8"`` (2 bytes per). - smem_ptr : PrimExpr - Generic pointer to source shared memory. - *dst_handles : PrimExpr - N pointer-to-uint32 destinations, where - ``N = num if dtype == "b16" else num // 2``. - - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix - """ - _choice("num", num, _LDMATRIX_NUM) - _choice("dtype", dtype, _LDMATRIX_DTYPE) - # _LDMATRIX_DTYPE entries carry leading dot (".b16" / ".b8"). - dtype_bare = dtype.lstrip(".") if isinstance(dtype, str) else dtype - n_regs = int(num) if dtype_bare == "b16" else int(num) // 2 - if len(dst_handles) != n_regs: - raise ValueError( - f"ldmatrix .x{int(num)}.{dtype_bare} expects {n_regs} destination " - f"handles, got {len(dst_handles)}" - ) - return call_intrin("", "tirx.ptx_ldmatrix", trans, num, dtype, smem_ptr, *dst_handles) - - -_PTX_TO_NUMPY_DTYPE = { - "fp16": "float16", - "fp32": "float32", - "fp64": "float64", - "bf16": "bfloat16", - "tf32": "float32", - "s8": "int8", - "u8": "uint8", - "s32": "int32", - "s4": "int4", - "u4": "uint4", - "b1": "int1", - "b16": "uint16", - "e4m3": "float8_e4m3fn", - "e5m2": "float8_e5m2", -} - - -def _ptx_to_numpy_dtype(dtype_str): - """Map a PTX-abbreviation or numpy dtype string to a numpy dtype string - suitable for ``tvm_access_ptr`` (which scales the offset by the element - bit width). Unknown strings pass through unchanged so a caller may also - pass an already-numpy dtype.""" - s = dtype_str if isinstance(dtype_str, str) else str(dtype_str) - return _PTX_TO_NUMPY_DTYPE.get(s, s) - - -def _wrap_or_fold_access_ptr(ptr, offset, elem_dtype): - """Wrap ``ptr`` with ``tvm_access_ptr`` unless it already is one. - - Several s_tir tensor intrinsics already pass ``buffer.access_ptr(...)`` - (an ``tvm_access_ptr`` Call) for the pointer argument. Naively wrapping - that again yields a nested ``tvm_access_ptr(... access_ptr(...) ...)`` - whose ``args[1]`` is a Call rather than a Var, which crashes the - lowering rule (Downcast at intrin_rule.cc) and several s_tir - passes that assume a raw buffer var. Detect that case and fold the - outer offset into the inner one. - """ - from tvm.ir import Op # local import to avoid cycles - - is_access_ptr_call = ( - isinstance(ptr, Call) and isinstance(ptr.op, Op) and ptr.op.name == "tirx.tvm_access_ptr" - ) - if is_access_ptr_call: - # Inner Call already wraps the buffer var. Reuse its inner var and - # inner element dtype (the marker type_annotation), and add the - # outer offset (which is in `elem_dtype` units, same convention as - # the inner since both come from the same buffer). - inner_args = ptr.args - inner_marker = inner_args[0] - inner_var = inner_args[1] - inner_offset = inner_args[2] - rw_mask = inner_args[4] - return call_intrin( - "handle", - "tirx.tvm_access_ptr", - inner_marker, - inner_var, - inner_offset + offset, - 1, - rw_mask, - ) - return tvm_access_ptr(elem_dtype, ptr, offset, 1, 1) - - -def ptx_ldmatrix_legacy(*all_args): - """Legacy ``ptx_ldmatrix`` API taking explicit offsets. - - Signature: ``(trans, num, dtype, local_ptr, local_offset, smem_ptr, - smem_offset)``. Offsets are folded into the pointers via - ``tvm_access_ptr`` and dispatched to the fork-native - :func:`ptx_ldmatrix`. - - ``T.ptx.ldmatrix_legacy`` runs through ``_dtype_forward`` which - prepends a ``dtype=`` kwarg as a leading positional naming the buffer - element type — offsets are in elements of that dtype, not bytes, so - we forward it to ``tvm_access_ptr`` for correct scaling. - """ - if len(all_args) == 8: - elem_dtype, trans, num, dtype, local_ptr, local_offset, smem_ptr, smem_offset = all_args - elif len(all_args) == 7: - trans, num, dtype, local_ptr, local_offset, smem_ptr, smem_offset = all_args - elem_dtype = "int8" - else: - raise ValueError( - f"ptx_ldmatrix_legacy expects 7 args (or 8 with dtype= kwarg " - f"prepended); got {len(all_args)}" - ) - # Call.dtype carries the buffer element type so codegen can pick the - # int8+trans manual-loop fallback (ldmatrix can't transpose int8). - return call_intrin( - elem_dtype, - "tirx.ptx_ldmatrix_legacy", - trans, - num, - dtype, - local_ptr, - local_offset, - smem_ptr, - smem_offset, - ) - - -def ptx_stmatrix(trans, num, dtype, smem_ptr, *src_handles, shape="m8n8", space="shared"): - """TVM intrinsic for ``stmatrix.sync.aligned.shape.x{num}{.trans}.space.{dtype}``. - - Mirrors :func:`ptx_ldmatrix`: each source register is a separate operand. - Pass ``T.address_of(buf[idx])`` (or ``buf.ptr_to([idx])``) for each - source — the slots may be non-contiguous. - - Parameters - ---------- - trans : bool - Apply the ``.trans`` modifier (required for ``shape == "m16n8"``). - num : int - One of 1, 2, 4 — number of m8n8 fragments per warp. - dtype : str - ``".b16"`` (4 bytes per fragment register) or ``".b8"`` (2 bytes per). - smem_ptr : PrimExpr - Destination pointer in shared memory. - *src_handles : PrimExpr - ``num`` pointer-to-uint32 sources. - shape : str, keyword-only, default "m8n8" - ``"m8n8"`` or ``"m16n8"``. - space : str, keyword-only, default "shared" - ``"shared"`` or ``"shared::cta"``. - - https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-stmatrix - """ - _choice("num", num, _LDMATRIX_NUM) - _choice("dtype", dtype, _LDMATRIX_DTYPE) - if shape not in ("m8n8", "m16n8"): - raise ValueError(f"Unsupported stmatrix shape {shape!r}") - if space not in ("shared", "shared::cta"): - raise ValueError(f"Unsupported stmatrix state space {space!r}") - if shape == "m16n8" and not trans: - raise ValueError("stmatrix .m16n8 requires .trans") - n_regs = int(num) - if len(src_handles) != n_regs: - dtype_bare = dtype.lstrip(".") if isinstance(dtype, str) else dtype - raise ValueError( - f"stmatrix .x{int(num)}.{dtype_bare} expects {n_regs} source " - f"handles, got {len(src_handles)}" - ) - return call_intrin( - "", "tirx.ptx_stmatrix", trans, num, dtype, shape, space, smem_ptr, *src_handles - ) - - -def ptx_wgmma_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle): - """TVM intrinsic to create memory descriptor for wgmma instructions - - Parameters - ---------- - desc : PrimExpr - The pointer to the shared memory descriptor. - - addr : PrimExpr - The address of the matrix. - - ldo : PrimExpr - The leading dimension offset. - - sdo : PrimExpr - The stride dimension offset. - - swizzle : int - The swizzle value (CUtensorMapSwizzle_enum). - """ - return call_intrin("", "tirx.ptx_wgmma_encode_matrix_descriptor", desc, addr, ldo, sdo, swizzle) - - -def ptx_wgmma_noop_barrier(reg): - """TVM intrinsic to call "" : "+{format}"(reg)::"memory" - - Parameters - ---------- - reg : PrimExpr - The register to fence. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_wgmma_noop_barrier", reg) - - -def ptx_wgmma_mma_async_ss( - descA, descB, *accums, M, N, K, in_dtype, out_dtype, transA, transB, scaleA, scaleB, scaleD -): - """TVM intrinsic to call wgmma.mma_async.sync.aligned.shape.dtype.atype.btype over 2 smem operators - - Parameters - ---------- - M : int - The number of rows in matrix A and D. - - N : int - The number of columns in matrix B and D. - - K : int - The number of columns in matrix A and rows in matrix B. - - in_dtype : str - The data type of the input matrices. - - out_type : str - The data type of the output matrices. - - transA : bool - True for M/N major, False for K major. - - transB : bool - True for M/N major, False for K major. - - scaleA : float - The scaling factor for matrix A. - - scaleB : float - The scaling factor for matrix B. - - scaleD : PrimExpr - True: D = A * B + D, False: D = A * B. - - descA : PrimExpr - The SMEM descriptor of matrix A - - descB : PrimExpr - The SMEM descriptor of matrix B - - accums : list - The accumulators registers. - """ # noqa: E501 - return call_intrin( - "", - "tirx.ptx_wgmma_mma_async_ss", - M, - N, - K, - in_dtype, - out_dtype, - transA, - transB, - scaleA, - scaleB, - scaleD, - descA, - descB, - *accums, - ) - - -def ptx_wgmma_mma_async_rs( - descB, *reg_list, M, N, K, in_dtype, out_dtype, transA, transB, scaleA, scaleB, scaleD -): - """TVM intrinsic to call wgmma.mma_async.sync.aligned.shape.dtype.atype.btype - When A is in register and B is in shared memory - - Parameters - ---------- - M : int - The number of rows in matrix A and D. - - N : int - The number of columns in matrix B and D. - - K : int - The number of columns in matrix A and rows in matrix B. - - in_dtype : str - The data type of the input matrices. - - out_type : str - The data type of the output matrices. - - transA : bool - True for M/N major, False for K major. - - transB : bool - True for M/N major, False for K major. - - scaleA : float - The scaling factor for matrix A. - - scaleB : float - The scaling factor for matrix B. - - scaleD : PrimExpr - True: D = A * B + D, False: D = A * B. - - descB : PrimExpr - The SMEM descriptor of matrix B - - reg_list : list - The A registers and accumulators registers. - """ - return call_intrin( - "", - "tirx.ptx_wgmma_mma_async_rs", - M, - N, - K, - in_dtype, - out_dtype, - transA, - transB, - scaleA, - scaleB, - scaleD, - descB, - *reg_list, - ) - - -def ptx_wgmma_fence(): - """TVM intrinsic to call wgmma.fence.sync.aligned - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_wgmma_fence") - - -def ptx_wgmma_commit_group(): - """TVM intrinsic to call wgmma.commit_group.sync.aligned - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_wgmma_commit_group") - - -def ptx_wgmma_wait_group(n): - """TVM intrinsic to call wgmma.wait_group.sync.aligned - - Parameters - ---------- - n : int - The number of the most recent uncommitted pending wgmma groups to wait. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_wgmma_wait_group", n) - - -def ptx_setmaxnreg(inc: bool, reg_count): - """TVM intrinsic to call setmaxnreg.action.sync.aligned.u32 imm-reg-count - - Parameters - ---------- - inc : bool - True to increase the register count, False to decrease. - - reg_count : int - The register count. - """ - return call_intrin("", "tirx.ptx_setmaxnreg", inc, reg_count) - - -def ptx_tcgen05_alloc(dst_ptr, n_cols, cta_group=1): - """TVM intrinsic to call tcgen05.alloc.cta_group.sync.aligned - Dynamically allocates the number of cols in tensor memory, and write - the address of allocated memory to shared memory. - - Parameters - ---------- - dst_ptr : Var - The pointer to the destination shared memory. - - n_cols : int - The number of columns to allocate in tensor memory. - Must be a multiple of 32 and a power of 2, and within the range [32, 512]. - - cta_group : int - The number of CTA groups involved in the allocation. - If cta_group=1, one warp from CTA performs the allocation. Else, if cta_group=2, - one warp from each of the peer CTAs perform the allocation. - """ - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin("", "tirx.ptx_tcgen05_alloc", dst_ptr, n_cols, cta_group) - - -def ptx_tcgen05_dealloc(taddr, n_cols, cta_group=1): - """TVM intrinsic to call tcgen05.dealloc.cta_group.sync.aligned - Deallocates the tensor memory specified by the tensor memory address taddr. - - Parameters - ---------- - taddr : PrimExpr - The address of previously allocated tensor memory, should be uint32_t. - - n_cols : int - The number of columns to deallocate in tensor memory. - Must be a multiple of 32 and a power of 2, and within the range [32, 512]. - - cta_group : int - The number of CTA groups involved in the deallocation. - If cta_group=1, one warp from CTA performs the deallocation. Else, if cta_group=2, - one warp from each of the peer CTAs perform the deallocation. - """ - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin("", "tirx.ptx_tcgen05_dealloc", taddr, n_cols, cta_group) - - -def ptx_tcgen05_relinquish_alloc_permit(cta_group=1): - """TVM intrinsic to call tcgen05.relinquish_alloc_permit.cta_group.sync.aligned - The CTA of the executing thread is relinquishing the right to allocate - Tensor Memory after calling this op. - - Parameters - ---------- - cta_group : int - The number of CTA groups involved in relinquishing. - If cta_group=1, one warp from CTA performs the relinquishing. Else, if cta_group=2, - one warp from each of the peer CTAs perform the relinquishing. - """ - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin("", "tirx.ptx_tcgen05_relinquish_alloc_permit", cta_group) - - -def ptx_tcgen05_encode_matrix_descriptor(desc, addr, ldo, sdo, swizzle): - """TVM intrinsic to create memory descriptor for tcgen05 instructions - - Parameters - ---------- - desc : PrimExpr - The pointer to the shared memory descriptor. - - addr : PrimExpr - The address of the matrix. - - ldo : PrimExpr - The leading dimension offset. - - sdo : PrimExpr - The stride dimension offset. - - swizzle : int - The swizzle value (CUtensorMapSwizzle_enum). - """ - return call_intrin( - "", "tirx.ptx_tcgen05_encode_matrix_descriptor", desc, addr, ldo, sdo, swizzle - ) - - -def ptx_tcgen05_encode_instr_descriptor( - desc, - *, - d_dtype, - a_dtype, - b_dtype, - M, - N, - K, - trans_a, - trans_b, - n_cta_groups=1, - neg_a=False, - neg_b=False, - sat_d=False, - is_sparse=False, -): - """TVM intrinsic to create instruction descriptor for tcgen05 MMA without block scaling - - Parameters - ---------- - desc : PrimExpr - The pointer to the instruction descriptor. - - d_dtype : str - The datatype of resultant matrix D. - - a_dtype : str - The datatype of multiplicand matrix A. - - b_dtype : str - The datatype of multiplicand matrix B. - - M : int - The size of non-reduction dimension of Matrix A. - - N : int - The size of non-reduction dimension of Matrix B. - - K : int - The size of reduction dimension of Matrix A/B. - - trans_a : bool - Whether the multiplicand matrix A is transposed. - True for M/N major, False for K major. - - trans_b : bool - Whether the multiplicand matrix B is transposed. - True for M/N major, False for K major. - - n_cta_groups : int - The number of CTA groups involved in the MMA operation. - - neg_a : bool - Whether to negate the multiplicand matrix A. - - neg_b : bool - Whether to negate the multiplicand matrix B. - - sat_d : bool - Whether to saturate the resultant matrix D. - - is_sparse : bool - Whether the MMA operation is sparse. - """ - _choice("n_cta_groups", n_cta_groups, _TCGEN05_CTA_GROUP) - return call_intrin( - "", - "tirx.ptx_tcgen05_encode_instr_descriptor", - desc, - d_dtype, - a_dtype, - b_dtype, - M, - N, - K, - trans_a, - trans_b, - n_cta_groups, - neg_a, - neg_b, - sat_d, - is_sparse, - ) - - -def ptx_tcgen05_encode_instr_descriptor_block_scaled( - desc, - *, - d_dtype, - a_dtype, - b_dtype, - sfa_dtype, - sfb_dtype, - sfa_tmem_addr, - sfb_tmem_addr, - M, - N, - K, - trans_a, - trans_b, - n_cta_groups=1, - neg_a=False, - neg_b=False, - is_sparse=False, -): - """TVM intrinsic to create instruction descriptor for tcgen05 MMA with block scaling - - Parameters - ---------- - desc : PrimExpr - The pointer to the instruction descriptor. - - d_dtype : str - The datatype of resultant matrix D. - - a_dtype : str - The datatype of multiplicand matrix A. - - b_dtype : str - The datatype of multiplicand matrix B. - - sfa_dtype : str - The datatype of scale factor matrix A. - - sfb_dtype : str - The datatype of scale factor matrix B. - - sfa_tmem_addr : PrimExpr - The address of the scale factor matrix A in tensor memory, should be uint32_t. - - sfb_tmem_addr : PrimExpr - The address of the scale factor matrix B in tensor memory, should be uint32_t. - - M : int - The size of non-reduction dimension of Matrix A. - - N : int - The size of non-reduction dimension of Matrix B. - - K : int - The size of reduction dimension of Matrix A/B. - - trans_a : bool - Whether the multiplicand matrix A is transposed. - True for M/N major, False for K major. - - trans_b : bool - Whether the multiplicand matrix B is transposed. - True for M/N major, False for K major. - - n_cta_groups : int - The number of CTA groups involved in the MMA operation. - - neg_a : bool - Whether to negate the multiplicand matrix A. - - neg_b : bool - Whether to negate the multiplicand matrix B. - - is_sparse : bool - Whether the MMA operation is sparse. - """ - _choice("n_cta_groups", n_cta_groups, _TCGEN05_CTA_GROUP) - return call_intrin( - "", - "tirx.ptx_tcgen05_encode_instr_descriptor_block_scaled", - desc, - d_dtype, - a_dtype, - b_dtype, - sfa_dtype, - sfb_dtype, - sfa_tmem_addr, - sfb_tmem_addr, - M, - N, - K, - trans_a, - trans_b, - n_cta_groups, - neg_a, - neg_b, - is_sparse, - ) - - -def ptx_tcgen05_mma( - d_tmem_addr, - a_operand, - b_desc, - i_desc, - *disable_output_lane, - d_dtype, - a_dtype, - b_dtype, - use_a_tmem, - cta_group, - enable_input_d=1, - scale_input_d=0, - pred=None, -): - """TVM intrinsic to call tcgen05.mma.cta_group.kind without block scaling. - - Parameters - ---------- - d_dtype : str - The datatype of resultant matrix D. - - a_dtype : str - The datatype of multiplicand matrix A. - - b_dtype : str - The datatype of multiplicand matrix B. - - d_tmem_addr : PrimExpr - The address of the resultant matrix D in tensor memory, should be uint32_t. - - a_operand : PrimExpr - Either the matrix descriptor of multiplicand matrix A in shared memory, - or the address of the multiplicand matrix A in tensor memory (uint32_t). - - b_desc : PrimExpr - The matrix descriptor of multiplicand matrix B in shared memory. - - i_desc : PrimExpr - The instruction descriptor of the MMA operation. - - use_a_tmem : bool - Whether the multiplicand matrix A is in tensor memory. - - cta_group : int - The number of CTA groups involved in the MMA operation. - - enable_input_d : PrimExpr - Scale operand for the input accumulator C/D. The inline asm tests - `enable_input_d != 0`: zero means D = A*B, non-zero means D = A*B + D. - - scale_input_d : int - The optional scaling factor to scale input matrix D. - D = A*B+D * (2 ^ - scale-input-d) - - disable_output_lane : list - The lanes that should not be updated in the resultant matrix D. - - pred : Optional[PrimExpr] - Runtime ``uint32`` instruction-level predicate. When given, emit - ``@p_issue tcgen05.mma...`` with ``p_issue = (pred != 0)``. Preserves - PTX-level predicate semantics (single predicated SASS instruction). - """ - - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - - # default value for disable_output_lane - if len(disable_output_lane) == 0: - disable_output_lane = [0] * (4 if cta_group == 1 else 8) - - args = [ - d_dtype, - a_dtype, - b_dtype, - d_tmem_addr, - a_operand, - b_desc, - i_desc, - use_a_tmem, - cta_group, - enable_input_d, - scale_input_d, - *disable_output_lane, - ] - if pred is not None: - args.append(pred) - return call_intrin("", "tirx.ptx_tcgen05_mma", *args) - - -def ptx_tcgen05_mma_block_scale( - d_tmem_addr, - a_operand, - b_desc, - sfa_tmem_addr, - sfb_tmem_addr, - i_desc, - *, - d_dtype, - a_dtype, - b_dtype, - sfa_dtype, - sfb_dtype, - use_a_tmem, - cta_group, - enable_input_d=1, -): - """TVM intrinsic to call tcgen05.mma.cta_group.kind.block_scale - Performs matrix multiplication with block scaling: - (A * scale_A) * (B * scale_B) + D - - Parameters - ---------- - d_dtype : str - The datatype of resultant matrix D. - - a_dtype : str - The datatype of multiplicand matrix A. - - b_dtype : str - The datatype of multiplicand matrix B. - - sfa_dtype : str - The datatype of scale factor matrix A. - - sfb_dtype : str - The datatype of scale factor matrix B. - - d_tmem_addr : PrimExpr - The address of the resultant matrix D in tensor memory, should be uint32_t. - - a_operand : PrimExpr - Either the matrix descriptor of multiplicand matrix A in shared memory, - or the address of the multiplicand matrix A in tensor memory (uint32_t). - - b_desc : PrimExpr - The matrix descriptor of multiplicand matrix B in shared memory. - - sfa_tmem_addr : PrimExpr - The address of the scale factor matrix A in tensor memory, should be uint32_t. - - sfb_tmem_addr : PrimExpr - The address of the scale factor matrix B in tensor memory, should be uint32_t. - - i_desc : PrimExpr - The instruction descriptor of the MMA operation. - - use_a_tmem : bool - Whether the multiplicand matrix A is in tensor memory. - - cta_group : int - The number of CTA groups involved in the MMA operation. - - enable_input_d : PrimExpr - Scale operand for the input accumulator C/D. Zero means D = A*B, - non-zero means D = A*B + D. - """ - - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin( - "", - "tirx.ptx_tcgen05_mma_block_scale", - d_dtype, - a_dtype, - b_dtype, - sfa_dtype, - sfb_dtype, - d_tmem_addr, - a_operand, - b_desc, - sfa_tmem_addr, - sfb_tmem_addr, - i_desc, - use_a_tmem, - cta_group, - enable_input_d, - ) - - -def ptx_tcgen05_mma_sp( - d_tmem_addr, - a_operand, - b_desc, - sp_tmem_addr, - i_desc, - *disable_output_lane, - d_dtype, - a_dtype, - b_dtype, - use_a_tmem, - cta_group, - enable_input_d=1, - scale_input_d=0, -): - """TVM intrinsic to call tcgen05.mma.sp.cta_group.kind without block scaling. - - Parameters - ---------- - d_dtype : str - The datatype of resultant matrix D. - - a_dtype : str - The datatype of multiplicand matrix A. - - b_dtype : str - The datatype of multiplicand matrix B. - - d_tmem_addr : PrimExpr - The address of the resultant matrix D in tensor memory, should be uint32_t. - - a_operand : PrimExpr - Either the matrix descriptor of multiplicand matrix A in shared memory, - or the address of the multiplicand matrix A in tensor memory (uint32_t). - - b_desc : PrimExpr - The matrix descriptor of multiplicand matrix B in shared memory. - - sp_tmem_addr : PrimExpr - The address of the metadata of sparse matrix in tensor memory, should be uint32_t. - - i_desc : PrimExpr - The instruction descriptor of the MMA operation. - - use_a_tmem : bool - Whether the multiplicand matrix A is in tensor memory. - - cta_group : int - The number of CTA groups involved in the MMA operation. - - enable_input_d : PrimExpr - Scale operand for the input accumulator C/D. The inline asm tests - `enable_input_d != 0`: zero means D = A*B, non-zero means D = A*B + D. - - scale_input_d : int - The optional scaling factor to scale input matrix D. - D = A*B+D * (2 ^ - scale-input-d) - - disable_output_lane : list - The lanes that should not be updated in the resultant matrix D. - """ - - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - - # default value for disable_output_lane - if len(disable_output_lane) == 0: - disable_output_lane = [0] * (4 if cta_group == 1 else 8) - - return call_intrin( - "", - "tirx.ptx_tcgen05_mma_sp", - d_dtype, - a_dtype, - b_dtype, - d_tmem_addr, - a_operand, - b_desc, - sp_tmem_addr, - i_desc, - use_a_tmem, - cta_group, - enable_input_d, - scale_input_d, - *disable_output_lane, - ) - - -def ptx_tcgen05_mma_sp_block_scale( - d_tmem_addr, - a_operand, - b_desc, - sfa_tmem_addr, - sfb_tmem_addr, - sp_tmem_addr, - i_desc, - *, - d_dtype, - a_dtype, - b_dtype, - sfa_dtype, - sfb_dtype, - use_a_tmem, - cta_group, - enable_input_d=1, -): - """TVM intrinsic to call tcgen05.mma.sp.cta_group.kind.block_scale - Performs sparse matrix multiplication with block scaling: - (A * scale_A) * (B * scale_B) + D - - Parameters - ---------- - d_dtype : str - The datatype of resultant matrix D. - - a_dtype : str - The datatype of multiplicand matrix A. - - b_dtype : str - The datatype of multiplicand matrix B. - - sfa_dtype : str - The datatype of scale factor matrix A. - - sfb_dtype : str - The datatype of scale factor matrix B. - - d_tmem_addr : PrimExpr - The address of the resultant matrix D in tensor memory, should be uint32_t. - - a_operand : PrimExpr - Either the matrix descriptor of multiplicand matrix A in shared memory, - or the address of the multiplicand matrix A in tensor memory (uint32_t). - - b_desc : PrimExpr - The matrix descriptor of multiplicand matrix B in shared memory. - - sfa_tmem_addr : PrimExpr - The address of the scale factor matrix A in tensor memory, should be uint32_t. - - sfb_tmem_addr : PrimExpr - The address of the scale factor matrix B in tensor memory, should be uint32_t. - - sp_tmem_addr : PrimExpr - The address of the metadata of sparse matrix in tensor memory, should be uint32_t. - - i_desc : PrimExpr - The instruction descriptor of the MMA operation. - - use_a_tmem : bool - Whether the multiplicand matrix A is in tensor memory. - - cta_group : int - The number of CTA groups involved in the MMA operation. - - enable_input_d : PrimExpr - Scale operand for the input accumulator C/D. Zero means D = A*B, - non-zero means D = A*B + D. - """ - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin( - "", - "tirx.ptx_tcgen05_mma_sp_block_scale", - d_dtype, - a_dtype, - b_dtype, - sfa_dtype, - sfb_dtype, - d_tmem_addr, - a_operand, - b_desc, - sfa_tmem_addr, - sfb_tmem_addr, - sp_tmem_addr, - i_desc, - use_a_tmem, - cta_group, - enable_input_d, - ) - - -def ptx_tcgen05_fence_before_thread_sync(): - """TVM intrinsic to call tcgen05.fence::before_thread_sync - Orders all prior asynchronous tcgen05 operations relative to subsequent operations. - """ - return call_intrin("", "tirx.ptx_tcgen05_fence_before_thread_sync") - - -def ptx_tcgen05_fence_after_thread_sync(): - """TVM intrinsic to call tcgen05.fence::after_thread_sync - Orders all subsequent asynchronous tcgen05 operations relative to previous operations. - """ - return call_intrin("", "tirx.ptx_tcgen05_fence_after_thread_sync") - - -def _choice(name: str, value, options): - """Validate `value` is one of `options`. Raise a clear ValueError otherwise. - - Symbolic values (Var, non-constant PrimExpr) are accepted without - validation; specialization later replaces them with concrete values - that the C-side intrinsic body re-checks. - """ - # Concrete int / IntImm value: validate. - try: - concrete = int(value) - except (TypeError, ValueError): - return # symbolic; defer check - if concrete not in options: - raise ValueError(f"invalid {name}={concrete!r}; expected one of {tuple(options)}") - - -# See top-of-file imports for `_FENCE_SEM` etc. (re-exported from _common). -# Note: TCGEN05_LDST_SHAPES values must stay in sync with the shape branches -# of codegen_ptx_tcgen05_ld/_st in intrinsics/cuda/tcgen05.py. - - -def ptx_tcgen05_cp( - taddr, src_desc, *, shape, cta_group=1, multicast="", decompress="", row=0, col=0 -): - """TVM intrinsic for the Blackwell `tcgen05.cp` PTX instruction. - - The emitted PTX is:: - - tcgen05.cp.cta_group::{cta_group}.{shape}[.{multicast}][.{decompress}] [taddr], src_desc; - - Each keyword argument maps 1:1 to a PTX token: read the call and you - know what instruction is emitted. - - Parameters - ---------- - taddr : PrimExpr - Destination tensor-memory address (uint32). Callers typically pass - ``tmem_base + column_offset_in_uint32s`` directly. Use the optional - ``row`` / ``col`` keyword arguments only when the address needs - runtime row/col composition via ``get_tmem_addr`` (high 16 bits row, - low 16 bits col). - - src_desc : PrimExpr - The 64-bit shared-memory matrix descriptor. - - shape : str - One of ``"32x128b"``, ``"4x256b"``, ``"128x128b"``, ``"128x256b"``, - ``"64x128b"``. - - cta_group : int - 1 or 2. - - multicast : str - One of ``""``, ``"warpx4"``, ``"warpx2::02_13"``, ``"warpx2::01_23"``. - ``"32x128b"`` requires ``"warpx4"``; ``"64x128b"`` requires one of the - ``warpx2::*`` values; other shapes require ``""``. - - decompress : str - Trailing PTX suffix for fp4/fp6 → fp8 on-the-fly decompression. - One of ``""``, ``"b8x16.b4x16_p64"``, ``"b8x16.b6x16_p32"``. - - row, col : PrimExpr - Optional row/col offsets added to ``taddr`` at runtime. Default 0. - """ - _choice("shape", shape, _TCGEN05_CP_SHAPES) - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - _choice("multicast", multicast, _TCGEN05_CP_MULTICAST) - _choice("decompress", decompress, _TCGEN05_CP_DECOMPRESS) - if shape == "32x128b" and multicast != "warpx4": - raise ValueError(f"shape=32x128b requires multicast='warpx4', got {multicast!r}") - if shape == "64x128b" and multicast not in ("warpx2::02_13", "warpx2::01_23"): - raise ValueError(f"shape=64x128b requires multicast in warpx2::*, got {multicast!r}") - if shape in ("128x128b", "128x256b", "4x256b") and multicast != "": - raise ValueError(f"shape={shape} requires multicast='', got {multicast!r}") - - return call_intrin( - "", - "tirx.ptx_tcgen05_cp", - taddr, - src_desc, - shape, - cta_group, - multicast, - decompress, - row, - col, - ) - - -def ptx_tcgen05_shift(taddr, cta_group=1): - """TVM intrinsic to call tcgen05.shift.cta_group.down - Asynchronously shift down the rows of the matrix in Tensor Memory for a warp. - - Parameters - ---------- - taddr : PrimExpr - The address of matrix in tensor memory, should be uint32_t. - - cta_group : int - The number of CTA groups involved in the shift. - If cta_group=1, shift operation is performed in the Tensor Memory of current CTA. - Else, shift operation is performed in the Tensor Memory of both the current CTA and - the peer CTA. - """ - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - return call_intrin("", "tirx.ptx_tcgen05_shift", taddr, cta_group) - - -def ptx_tcgen05_ld(src_addr, *regs, shape, num, row=0, col=0, pack=False): - """TVM intrinsic for tcgen05.ld.sync.aligned — async collective load from TMEM. - - Emits ``tcgen05.ld.sync.aligned.{shape}.x{num}[.pack::16b].b32 {regs}, [addr];`` - - Parameters - ---------- - src_addr : PrimExpr - Tensor-memory source address (uint32). - - regs : list[PrimExpr] - Destination registers. Count depends on shape x num. - - shape : str - One of ``"16x32bx2"``, ``"16x64b"``, ``"16x128b"``, ``"16x256b"``, ``"32x32b"``. - - num : int - Repeat factor along the columns. Power-of-two in [1, 128]. - - row, col : PrimExpr - Optional TMEM row/col offsets added to ``src_addr`` at runtime (row must be - a multiple of 32). Default 0. - - pack : bool - Pack two 16-bit chunks into a single 32-bit register. - """ - _choice("shape", shape, _TCGEN05_LDST_SHAPES) - return call_intrin("", "tirx.ptx_tcgen05_ld", src_addr, row, col, shape, num, pack, *regs) - - -def ptx_tcgen05_st(dst_addr, *regs, shape, num, row=0, col=0, unpack=False): - """TVM intrinsic for tcgen05.st.sync.aligned — async collective store to TMEM. - - Emits ``tcgen05.st.sync.aligned.{shape}.x{num}[.unpack::16b].b32 [addr], {regs};`` - - Parameters - ---------- - dst_addr : PrimExpr - Tensor-memory destination address (uint32). - - regs : list[PrimExpr] - Source registers. Count depends on shape x num. - - shape : str - One of ``"16x32bx2"``, ``"16x64b"``, ``"16x128b"``, ``"16x256b"``, ``"32x32b"``. - - num : int - Repeat factor along the columns. Power-of-two in [1, 128]. - - row, col : PrimExpr - Optional TMEM row/col offsets added to ``dst_addr`` at runtime (row must be - a multiple of 32). Default 0. - - unpack : bool - Unpack a 32-bit register into two 16-bit chunks. - """ - _choice("shape", shape, _TCGEN05_LDST_SHAPES) - return call_intrin("", "tirx.ptx_tcgen05_st", dst_addr, row, col, shape, num, unpack, *regs) - - -def ptx_tcgen05_wait_ld(): - """TVM intrinsic to call tcgen05.wait::ld.sync.aligned - Wait for the completion of all prior async tcgen05.ld operations. - """ - return call_intrin("", "tirx.ptx_tcgen05_wait_ld") - - -def ptx_tcgen05_wait_st(): - """TVM intrinsic to call tcgen05.wait::st.sync.aligned - Wait for the completion of all prior async tcgen05.st operations. - """ - return call_intrin("", "tirx.ptx_tcgen05_wait_st") - - -def ptx_tcgen05_commit(bar, cta_group=1, cta_mask=0, *, pred=None): - """TVM intrinsic to call tcgen05.commit.cta_group - - Parameters - ---------- - bar : PrimExpr - The pointer to mbarrier variable. - - cta_group: int - The number of CTA groups involved in previous tcgen05 operations. - - cta_mask : int - The mask of the CTAs in the cluster, used for multicast. - - pred : Optional[PrimExpr] - Runtime ``uint32`` predicate. When given, emit - ``@p tcgen05.commit...`` with ``p = (pred != 0)``. This preserves - PTX-level instruction predicate semantics (single predicated - instruction in SASS), distinct from a C-level ``if`` branch. - - Returns - ------- - call : PrimExpr - The call expression. - """ - _choice("cta_group", cta_group, _TCGEN05_CTA_GROUP) - args = [bar, cta_group, cta_mask] - if pred is not None: - args.append(pred) - return call_intrin("", "tirx.ptx_tcgen05_commit", *args) - - -def print_buffer(buffer_var, dtype, is_string, is_scalar, dim_num, *shape): - """Print out buffer memory (tensor, string, or scalar) during runtime on cuda. - This print function allows printing out buffer in tvm during runtime without - dumping all the cuda code. - Parameters - ---------- - buffer_var : Var - The data pointer of the buffer that needs to be printed out. - dtype : DataType - The data type of the buffer. - is_string: Bool - Whether the buffer is a string (dtype is Int8 by default in the backend). - is_scalar: Bool - Whether the buffer is a scalar. - dim_num : Int - The number of dimensions of the buffer - *shape : Tuple - The dimensions of the buffer in order. - Returns - ------- - call : PrimExpr - The call expression. - """ - final_shape_args = [] - if len(shape) == 1 and isinstance(shape[0], tuple | list | tvm.ir.Array): - # Case 1: Called as print_buffer(..., dim, (s1, s2, ...)) - # The user provided a tuple/list as the single shape argument. - final_shape_args = list(shape[0]) - else: - # Case 2: Called as print_buffer(..., dim, s1, s2, ...) - # This is how TVMScript parser will call it. - final_shape_args = list(shape) - - return _ffi_api.print_buffer( - buffer_var, dtype, is_string, is_scalar, dim_num, *final_shape_args - ) - - -def timer_init_cuda(profiler_buffer, profiler_tag, profiler_write_offset, num_groups, group_id): - """TVM intrinsic for initializing the CUDA profiler, and store profiling result in a buffer. - - Parameters - ---------- - profiler_buffer: Var - The buffer to store the profiling result. - - profiler_tag: Var - Buffer of length 1 storing the base tag of the current thread. - - profiler_write_offset: Var - Buffer of length 1 storing the offset in buffer to write the next - profiling result for the current thread. - - num_groups: int - The number of groups in the profiler. - - group_id: PrimExpr - The group id of the current thread. - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin( - "handle", - "tirx.timer_init_cuda", - profiler_buffer, - profiler_tag, - profiler_write_offset, - num_groups, - group_id, - ) - - -def timer_start_cuda( - event_type, - profiler_buffer, - profiler_tag, - profiler_write_offset, - profiler_write_stride, - leader_cond, -): - """TVM intrinsic for starting the timer for profiling a specific event, and storing profiling result in a buffer. - - Parameters - ---------- - event_type: Enum - The event to profile. - - profiler_buffer: Var - The buffer to store the profiling result. - - profiler_tag: Var - Buffer of length 1 storing the base tag of the current thread. - - profiler_write_offset: Var - Buffer of length 1 storing the offset in buffer to write the next - profiling result for the current thread. - - profiler_write_stride: int - The stride to advance in buffer in the next write. - - leader_cond: PrimExpr - The condition to check if the current thread is the leader. - - Returns - ------- - call : PrimExpr - The call expression. - """ # noqa: E501 - - return call_intrin( - "handle", - "tirx.timer_start_cuda", - event_type.value, - profiler_buffer, - profiler_tag, - profiler_write_offset, - profiler_write_stride, - leader_cond, - ) - - -def timer_end_cuda( - event_type, - profiler_buffer, - profiler_tag, - profiler_write_offset, - profiler_write_stride, - leader_cond, -): - """TVM intrinsic for ending the timer for profiling a specific event, and storing profiling result in a buffer. - - Parameters - ---------- - event_type: Enum - The event to profile. - - profiler_buffer: Var - The buffer to store the profiling result. - - profiler_tag: Var - Buffer of length 1 storing the base tag of the current thread. - - profiler_write_offset: Var - Buffer of length 1 storing the offset in buffer to write the next - profiling result for the current thread. - - profiler_write_stride: int - The stride to advance in buffer in the next write. - - leader_cond: PrimExpr - The condition to check if the current thread is the leader. - - Returns - ------- - call : PrimExpr - The call expression. - """ # noqa: E501 - - return call_intrin( - "handle", - "tirx.timer_end_cuda", - event_type.value, - profiler_buffer, - profiler_tag, - profiler_write_offset, - profiler_write_stride, - leader_cond, - ) - - -def timer_finalize_cuda( - profiler_buffer, profiler_tag, profiler_write_offset, profiler_write_stride, leader_cond -): - """TVM intrinsic for finalizing the CUDA profiler, and store profiling result in a buffer. - - Parameters - ---------- - profiler_buffer: Var - The buffer to store the profiling result. - - profiler_tag: Var - Buffer of length 1 storing the base tag of the current thread. - - profiler_write_offset: Var - Buffer of length 1 storing the offset in buffer to write the next - profiling result for the current thread. - - profiler_write_stride: int - The stride to advance in buffer in the next write. - - leader_cond: PrimExpr - The condition to check if the current thread is the leader. - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin( - "handle", - "tirx.timer_finalize_cuda", - profiler_buffer, - profiler_tag, - profiler_write_offset, - profiler_write_stride, - leader_cond, - ) - - -def cuda_atomic_add(res_addr, value): - """TVM intrinsic to call cuda atomic add instruction - - Parameters - ---------- - res_addr : PrimExpr - The result address. - - value: PrimExpr - The value to add. - - Returns - ------- - call : PrimExpr - The call expression. - """ - value = tir.convert(value) - return call_intrin(value.dtype, "tirx.cuda_atomic_add", res_addr, value) - - -def cuda_thread_fence(): - """TVM intrinsic to call cuda thread fence instruction - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_thread_fence") - - -def cuda_warpgroup_sync(bar_no): - """TVM intrinsic to synchronize a CUDA warpgroup via a named barrier. - - Parameters - ---------- - bar_no : PrimExpr - The named barrier id to use for the warpgroup. - - Notes - ----- - Synchronizes 128 threads in a warpgroup using `bar.sync bar_no, 128`. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_warpgroup_sync", bar_no) - - -def cuda_syncthreads_and(cond): - """TVM intrinsic to call cuda syncthreads_and instruction - - Parameters - ---------- - cond: PrimExpr - The condition. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("int64", "tirx.cuda_syncthreads_and", cond) - - -def cuda_syncthreads_or(cond): - """TVM intrinsic to call cuda syncthreads_or instruction - - Parameters - ---------- - cond: PrimExpr - The condition. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("int64", "tirx.cuda_syncthreads_or", cond) - - -def cuda_nano_sleep(time): - """TVM intrinsic to call cuda nano sleep instruction - - Parameters - ---------- - time: PrimExpr - The time to sleep. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_nano_sleep", time) - - -def cuda_printf(fmt, *args): - """TVM intrinsic to call cuda printf instruction - - Parameters - ---------- - fmt: str - The format string. - - *args: list - The arguments to the format string. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.cuda_printf", fmt, *args) - - -def cuda_ldg(addr, dtype): - """TVM intrinsic to call CUDA C++ __ldg() function - - Parameters - ---------- - addr : PrimExpr - The memory address to load. - - dtype : str - The data type of the loaded value. - - Returns - """ - return call_intrin(dtype, "tirx.cuda_ldg", addr, dtype) - - -def cuda_get_tmem_addr(addr, row_offset, col_offset): - """TVM intrinsic to call cuda tmem address calculation - - Parameters - ---------- - addr: PrimExpr - The memory address to calculate. - - row_offset: PrimExpr - The row offset to calculate. - - col_offset: PrimExpr - The column offset to calculate. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("uint32", "tirx.cuda_get_tmem_addr", addr, row_offset, col_offset) - - -def cuda_cvta_generic_to_shared(ptr): - """Convert a generic pointer to a shared-memory address (uint32). - - Wraps ``__cvta_generic_to_shared(ptr)``. Used by op-wrappers that - precompute the shared-memory address at the wrapper layer instead of - inside the asm helper body. - """ - return call_intrin("uint32", "tirx.cuda_cvta_generic_to_shared", ptr) - - -def cuda_smem_addr_from_uint64(cluster_addr): - """Narrow a 64-bit cluster-mapped SMEM address to a 32-bit SMEM address. - - Wraps ``static_cast(cluster_addr)``. Used by - cp.async.bulk.shared::cluster.* op-wrappers. - """ - return call_intrin("uint32", "tirx.cuda_smem_addr_from_uint64", cluster_addr) - - -def cuda_sm100_tma_2sm_mbarrier_addr(bar): - """Compute the SM100 2SM TMA mbarrier shared-address operand.""" - return bitwise_and(cuda_cvta_generic_to_shared(bar), const(0xFEFFFFFF, dtype="uint32")) - - -def ptx_exp2(x): - """TVM intrinsic for PTX fast exp2 approximation (ex2.approx.ftz.f32) - - Parameters - ---------- - x : PrimExpr - The float32 input value. - - Returns - ------- - call : PrimExpr - The call expression returning 2^x (approximate). - """ - return call_intrin("float32", "tirx.ptx_exp2", x) - - -def ptx_rcp(x): - """TVM intrinsic for PTX fast reciprocal approximation (rcp.approx.ftz.f32) - - Parameters - ---------- - x : PrimExpr - The float32 input value. - - Returns - ------- - call : PrimExpr - The call expression returning 1/x (approximate). - """ - return call_intrin("float32", "tirx.ptx_rcp", x) - - -def ptx_any_sync(mask, pred): - """TVM intrinsic for PTX warp-wide any predicate (__any_sync) - - Parameters - ---------- - mask : PrimExpr - The thread mask (uint32). - pred : PrimExpr - The predicate value (int32). - - Returns - ------- - call : PrimExpr - The call expression returning 1 if any thread in mask has pred != 0. - """ - return call_intrin("int32", "tirx.ptx_any_sync", mask, pred) - - -def ptx_reduce3_max_f32(a, b, c): - """TVM intrinsic to call 3-input max.f32 PTX instruction (sm_100a+) - - Parameters - ---------- - a, b, c : PrimExpr - The three float32 values to compare. - - Returns - ------- - call : PrimExpr - The call expression returning max(a, b, c). - """ - return call_intrin("float32", "tirx.ptx_reduce3_max_f32", a, b, c) - - -def ptx_reduce3_min_f32(a, b, c): - """TVM intrinsic to call 3-input min.f32 PTX instruction (sm_100a+) - - Parameters - ---------- - a, b, c : PrimExpr - The three float32 values to compare. - - Returns - ------- - call : PrimExpr - The call expression returning min(a, b, c). - """ - return call_intrin("float32", "tirx.ptx_reduce3_min_f32", a, b, c) - - -def _ptx_binary_arith(op_name, dtype, d, a, b, *, rounding="rn", ftz=False, sat=False): - """Shared helper for add/sub/mul over (f32 | f32x2 | f64), DPS form.""" - _choice("rounding", rounding, _F32X2_ROUND) - if dtype == "f64" and (ftz or sat): - raise ValueError(f"PTX {op_name}.f64 does not accept .ftz or .sat") - if dtype == "f32x2" and sat: - raise ValueError(f"PTX {op_name}.f32x2 does not accept .sat") - return call_intrin( - "", - f"tirx.ptx_{op_name}_{dtype}", - d, - a, - b, - rounding, - int(ftz), - int(sat), - ) - - -def _ptx_fma(dtype, d, a, b, c, *, rounding="rn", ftz=False, sat=False): - """Shared helper for fma over (f32 | f32x2 | f64), DPS form.""" - _choice("rounding", rounding, _F32X2_ROUND) - if dtype == "f64" and (ftz or sat): - raise ValueError("PTX fma.f64 does not accept .ftz or .sat") - if dtype == "f32x2" and sat: - raise ValueError("PTX fma.f32x2 does not accept .sat") - return call_intrin( - "", - f"tirx.ptx_fma_{dtype}", - d, - a, - b, - c, - rounding, - int(ftz), - int(sat), - ) - - -def ptx_add_f32(d_addr, a, b, *, rounding="rn", ftz=False, sat=False): - """PTX ``add{.rnd}{.ftz}{.sat}.f32 [d_addr], a, b`` — DPS form.""" - return _ptx_binary_arith("add", "f32", d_addr, a, b, rounding=rounding, ftz=ftz, sat=sat) - - -def ptx_add_f32x2(d_addr, a, b, *, rounding="rn", ftz=False): - """PTX ``add{.rnd}{.ftz}.f32x2 [d_addr], a, b`` — DPS form. - - a, b are packed-as-uint64 register operands (2 fp32 each). - """ - return _ptx_binary_arith("add", "f32x2", d_addr, a, b, rounding=rounding, ftz=ftz) - - -def ptx_add_f64(d_addr, a, b, *, rounding="rn"): - """PTX ``add{.rnd}.f64 [d_addr], a, b`` — DPS form (no .ftz / .sat).""" - return _ptx_binary_arith("add", "f64", d_addr, a, b, rounding=rounding) - - -def ptx_sub_f32(d_addr, a, b, *, rounding="rn", ftz=False, sat=False): - """PTX ``sub{.rnd}{.ftz}{.sat}.f32 [d_addr], a, b`` — DPS form.""" - return _ptx_binary_arith("sub", "f32", d_addr, a, b, rounding=rounding, ftz=ftz, sat=sat) - - -def ptx_sub_f32x2(d_addr, a, b, *, rounding="rn", ftz=False): - """PTX ``sub{.rnd}{.ftz}.f32x2 [d_addr], a, b`` — DPS form.""" - return _ptx_binary_arith("sub", "f32x2", d_addr, a, b, rounding=rounding, ftz=ftz) - - -def ptx_sub_f64(d_addr, a, b, *, rounding="rn"): - """PTX ``sub{.rnd}.f64 [d_addr], a, b`` — DPS form.""" - return _ptx_binary_arith("sub", "f64", d_addr, a, b, rounding=rounding) - - -def ptx_mul_f32(d_addr, a, b, *, rounding="rn", ftz=False, sat=False): - """PTX ``mul{.rnd}{.ftz}{.sat}.f32 [d_addr], a, b`` — DPS form.""" - return _ptx_binary_arith("mul", "f32", d_addr, a, b, rounding=rounding, ftz=ftz, sat=sat) - - -def ptx_mul_f32x2(d_addr, a, b, *, rounding="rn", ftz=False): - """PTX ``mul{.rnd}{.ftz}.f32x2 [d_addr], a, b`` — DPS form.""" - return _ptx_binary_arith("mul", "f32x2", d_addr, a, b, rounding=rounding, ftz=ftz) - - -def ptx_mul_f64(d_addr, a, b, *, rounding="rn"): - """PTX ``mul{.rnd}.f64 [d_addr], a, b`` — DPS form.""" - return _ptx_binary_arith("mul", "f64", d_addr, a, b, rounding=rounding) - - -def ptx_fma_f32(d_addr, a, b, c, *, rounding="rn", ftz=False, sat=False): - """PTX ``fma{.rnd}{.ftz}{.sat}.f32 [d_addr], a, b, c`` — DPS form.""" - return _ptx_fma("f32", d_addr, a, b, c, rounding=rounding, ftz=ftz, sat=sat) - - -def ptx_fma_f32x2(d_addr, a, b, c, *, rounding="rn", ftz=False): - """PTX ``fma{.rnd}{.ftz}.f32x2 [d_addr], a, b, c`` — DPS form. - - a, b, c are packed-as-uint64 register operands. - """ - return _ptx_fma("f32x2", d_addr, a, b, c, rounding=rounding, ftz=ftz) - - -def ptx_fma_f64(d_addr, a, b, c, *, rounding="rn"): - """PTX ``fma{.rnd}.f64 [d_addr], a, b, c`` — DPS form.""" - return _ptx_fma("f64", d_addr, a, b, c, rounding=rounding) - - -def ptx_max_f32(a, b, *, ftz=False, nan=False): - """TVM intrinsic for PTX ``max{.ftz}{.NaN}.f32 d, a, b``. - - 2-operand form (distinct from :func:`ptx_reduce3_max_f32` which is the - 3-operand SM_100+ form). ``.NaN`` qualifier propagates NaN inputs to - the output; without it, NaN inputs are silently ignored. - - Parameters - ---------- - a, b : PrimExpr - Float32 inputs. - ftz : bool - If True, flush subnormals to zero (``.ftz``). - nan : bool - If True, propagate NaN inputs (``.NaN``). - """ - return call_intrin("float32", "tirx.ptx_max_f32", a, b, int(ftz), int(nan)) - - -def ptx_griddepcontrol_wait(): - """TVM intrinsic for PTX ``griddepcontrol.wait`` (sm_90+). - - Blocks the current grid until prerequisite grids signalled via - :func:`ptx_griddepcontrol_launch_dependents` have finished. Acts as a - full memory barrier. - """ - return call_intrin("", "tirx.ptx_griddepcontrol_wait") - - -def ptx_griddepcontrol_launch_dependents(): - """TVM intrinsic for PTX ``griddepcontrol.launch_dependents`` (sm_90+). - - Signals that the current grid has reached a point where dependent - grids may begin execution. - """ - return call_intrin("", "tirx.ptx_griddepcontrol_launch_dependents") - - -_PTX_LD_SCOPE = {"cta", "cluster", "gpu", "sys"} -_PTX_LD_SPACE = {"global", "shared", "shared::cta", "shared::cluster", "local"} -_PTX_LD_VOLATILE_SPACE = _PTX_LD_SPACE | {"const"} -_PTX_LD_TYPE = {"b32", "u32", "u64", "s32", "f32"} -_PTX_LD_COP = {"", "ca", "cg", "cs", "lu", "cv"} -_PTX_MEM_SCOPE = {"", "cta", "cluster", "gpu", "sys"} -_PTX_MEM_SPACE = {"global", "shared", "shared::cta", "shared::cluster"} -_PTX_SCALAR_TYPE = {"b32", "b64", "u32", "u64", "s32", "s64", "f32", "f64"} -_PTX_RED_OP = {"and", "or", "xor", "add", "inc", "dec", "min", "max"} -_PTX_ATOM_OP = {"and", "or", "xor", "exch", "add", "inc", "dec", "min", "max"} -_PTX_ST_VEC = {"", "v2", "v4", "v8"} -_PTX_ST_COP = {"", "wb", "cg", "cs", "wt"} -_PTX_PREFETCH_TENSORMAP_SPACE = {"", "const", "param"} -_PTX_SCALAR_RETURN_TYPE = { - "b32": "uint32", - "u32": "uint32", - "s32": "int32", - "b64": "uint64", - "u64": "uint64", - "s64": "int64", - "f32": "float32", - "f64": "float64", -} -_PTX_CACHE_POLICY = { - "evict_normal": 0x1000000000000000, - "evict_first": 0x12F0000000000000, - "evict_last": 0x14F0000000000000, -} - - -def _resolve_cache_policy(cache_hint, cache_policy, choices=_CP_ASYNC_BULK_CACHE_HINT): - _choice("cache_hint", cache_hint, choices) - if cache_policy is not None: - return cache_policy, True - if cache_hint: - if cache_hint not in _PTX_CACHE_POLICY: - raise ValueError( - f"Unsupported built-in cache policy {cache_hint!r}; pass cache_policy explicitly" - ) - return const(_PTX_CACHE_POLICY[cache_hint], dtype="uint64"), True - return const(0, dtype="uint64"), False - - -def ptx_ld_acquire(addr, return_type, ptx_type, *, scope="gpu", space="global"): - """TVM intrinsic for scalar PTX ``ld.acquire.scope{.ss}.type`` loads. - - This wrapper covers the scalar no-cache-policy/no-vector instances of the - PTX ISA ``ld.acquire`` form. ``scope``, state ``space``, PTX ``type`` and - TVM ``return_type`` are explicit so callers can request either raw-bit or - typed loads. - - Parameters - ---------- - addr : PrimExpr - The memory address to load. - - return_type : str - TVM dtype returned by the load. - - ptx_type : str - PTX type suffix such as ``"b32"``, ``"u64"``, or ``"s32"``. - - scope : str - PTX memory scope: ``"cta"``, ``"cluster"``, ``"gpu"``, or ``"sys"``. - - space : str - PTX state space suffix. - - Returns - ------- - call : PrimExpr - The loaded value. - """ - _choice("scope", scope, _PTX_LD_SCOPE) - _choice("space", space, _PTX_LD_SPACE) - _choice("ptx_type", ptx_type, _PTX_LD_TYPE) - return call_intrin( - return_type, "tirx.ptx_ld_acquire", addr, return_type, ptx_type, scope, space - ) - - -def ptx_ld( - addr, - return_type, - ptx_type, - *, - weak=False, - space="global", - cop="", - cache_hint="", - cache_policy=None, -): - """TVM intrinsic for scalar PTX ``ld{.weak}{.ss}{.cop}{.level::cache_hint}.type``. - - This wrapper covers scalar no-prefetch/no-vector instances of the weak - generic load form. - """ - _choice("space", space, _PTX_LD_SPACE | {"const", "param::entry", "param::func"}) - _choice("cop", cop, _PTX_LD_COP) - _choice("ptx_type", ptx_type, _PTX_LD_TYPE) - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - return call_intrin( - return_type, - "tirx.ptx_ld", - addr, - cache_policy, - return_type, - int(bool(weak)), - space, - cop, - ptx_type, - int(has_cache_policy), - ) - - -def ptx_ld_volatile(addr, return_type, ptx_type, *, space="global"): - """TVM intrinsic for scalar PTX ``ld.volatile{.ss}.type`` loads. - - This wrapper covers scalar no-prefetch/no-vector instances. - """ - _choice("space", space, _PTX_LD_VOLATILE_SPACE) - _choice("ptx_type", ptx_type, _PTX_LD_TYPE) - return call_intrin(return_type, "tirx.ptx_ld_volatile", addr, return_type, ptx_type, space) - - -def ptx_ld_global_acquire(res, addr): - """TVM intrinsic to call the legacy ptx ld.global.acquire helper. - - Parameters - ---------- - res : PrimExpr - The result of the load. - - addr : PrimExpr - The memory address to load. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.ptx_ld_global_acquire", res, addr) - - -def ptx_red_scalar( - address, - value, - *, - sem="", - scope="", - space="global", - op, - ptx_type, - cache_hint="", - cache_policy=None, -): - _choice("scope", scope, _PTX_MEM_SCOPE) - _choice("space", space, _PTX_MEM_SPACE) - _choice("op", op, _PTX_RED_OP) - _choice("ptx_type", ptx_type, _PTX_SCALAR_TYPE) - cache_policy, has_cache_policy = _resolve_cache_policy( - cache_hint, cache_policy, _CP_ASYNC_CACHE_HINT - ) - if sem not in ("", "relaxed", "release"): - raise ValueError(f"Unsupported PTX red sem {sem!r}") - return call_intrin( - "", - "tirx.ptx_red_scalar", - address, - value, - cache_policy, - sem, - scope, - space, - op, - ptx_type, - int(has_cache_policy), - ) - - -def ptx_atom_scalar( - address, - value, - *, - sem="", - scope="", - space="global", - op, - ptx_type, - cache_hint="", - cache_policy=None, -): - _choice("scope", scope, _PTX_MEM_SCOPE) - _choice("space", space, _PTX_MEM_SPACE) - _choice("op", op, _PTX_ATOM_OP) - _choice("ptx_type", ptx_type, _PTX_SCALAR_TYPE) - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - if sem not in ("", "relaxed", "acquire", "release", "acq_rel"): - raise ValueError(f"Unsupported PTX atom sem {sem!r}") - return call_intrin( - _PTX_SCALAR_RETURN_TYPE[ptx_type], - "tirx.ptx_atom_scalar", - address, - value, - cache_policy, - sem, - scope, - space, - op, - ptx_type, - int(has_cache_policy), - ) - - -def ptx_st( - address, - *values, - weak=False, - space="shared", - cop="", - vec="", - ptx_type, - cache_hint="", - cache_policy=None, -): - _choice("space", space, _PTX_MEM_SPACE | {"local", "param::func"}) - _choice("cop", cop, _PTX_ST_COP) - _choice("vec", vec, _PTX_ST_VEC) - _choice("ptx_type", ptx_type, _PTX_SCALAR_TYPE) - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - return call_intrin( - "", - "tirx.ptx_st", - address, - *values, - cache_policy, - int(bool(weak)), - space, - cop, - vec, - ptx_type, - int(has_cache_policy), - ) - - -def ptx_st_bulk(ptr, num_bytes, *, weak=False, space="shared::cta"): - if space not in ("", "shared::cta"): - raise ValueError(f"Unsupported PTX st.bulk space {space!r}") - return call_intrin("", "tirx.ptx_st_bulk", ptr, num_bytes, int(bool(weak)), space) - - -def ptx_prefetch_tensormap(tensormap_addr, space=""): - _choice("space", space, _PTX_PREFETCH_TENSORMAP_SPACE) - return call_intrin("", "tirx.ptx_prefetch_tensormap", tensormap_addr, space) - - -def ptx_mbarrier_test_wait_parity(barrier, phase, *, sem="", scope="", space="shared::cta"): - if sem not in ("", "acquire", "relaxed"): - raise ValueError(f"Unsupported mbarrier.test_wait.parity sem {sem!r}") - if scope not in ("", "cta", "cluster"): - raise ValueError(f"Unsupported mbarrier.test_wait.parity scope {scope!r}") - if bool(sem) != bool(scope): - raise ValueError("mbarrier.test_wait.parity sem and scope must be set together") - if space not in ("shared", "shared::cta"): - raise ValueError(f"Unsupported mbarrier.test_wait.parity space {space!r}") - return call_intrin( - "uint32", "tirx.ptx_mbarrier_test_wait_parity", barrier, phase, sem, scope, space - ) - - -def ptx_cp_async_bulk_g2s_cta( - dst_ptr, - src_ptr, - num_bytes, - mbarrier_ptr, - *, - cache_hint="", - cache_policy=None, - ignore_oob=False, - ignore_bytes_left=0, - ignore_bytes_right=0, -): - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_g2s_cta", - dst_ptr, - src_ptr, - num_bytes, - ignore_bytes_left, - ignore_bytes_right, - mbarrier_ptr, - cache_policy, - int(has_cache_policy), - int(bool(ignore_oob)), - ) - - -def ptx_cp_async_bulk_g2s_cluster( - dst_ptr, - src_ptr, - num_bytes, - mbarrier_ptr, - *, - cache_hint="", - cache_policy=None, - multicast=False, - cta_mask=0, -): - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_g2s_cluster", - dst_ptr, - src_ptr, - num_bytes, - mbarrier_ptr, - cta_mask, - cache_policy, - int(has_cache_policy), - int(bool(multicast)), - ) - - -def ptx_cp_async_bulk_s2s_cluster(dst_ptr, src_ptr, num_bytes, mbarrier): - return call_intrin( - "", "tirx.ptx_cp_async_bulk_s2s_cluster", dst_ptr, src_ptr, num_bytes, mbarrier - ) - - -def ptx_cp_async_bulk_s2g( - dst_ptr, src_ptr, num_bytes, *, cache_hint="", cache_policy=None, cp_mask=False, byte_mask=0 -): - cache_policy, has_cache_policy = _resolve_cache_policy(cache_hint, cache_policy) - return call_intrin( - "", - "tirx.ptx_cp_async_bulk_s2g", - dst_ptr, - src_ptr, - num_bytes, - byte_mask, - cache_policy, - int(has_cache_policy), - int(bool(cp_mask)), - ) - - -def ptx_fns_b32(mask, base, offset): - return call_intrin("uint32", "tirx.ptx_fns_b32", mask, base, offset) - - -def ptx_add_rn_f32_bf16(acc, x): - return call_intrin("float32", "tirx.ptx_add_rn_f32_bf16", acc, x) - - -def cuda_uint_as_float(bits): - return call_intrin("float32", "tirx.cuda_uint_as_float", bits) - - -def cuda_float_as_uint(x): - return call_intrin("uint32", "tirx.cuda_float_as_uint", x) - - -def cuda_ballot_sync(mask, pred): - return call_intrin("uint32", "tirx.cuda_ballot_sync", mask, pred) - - -def cuda_ffs_u32(value): - return call_intrin("int32", "tirx.cuda_ffs_u32", value) - - -def cuda_reduce_add_sync_u32(mask, value): - return call_intrin("uint32", "tirx.cuda_reduce_add_sync_u32", mask, value) - - -def cuda_reduce_min_sync_u32(mask, value): - return call_intrin("uint32", "tirx.cuda_reduce_min_sync_u32", mask, value) - - -def cuda_clock64(): - return call_intrin("uint64", "tirx.cuda_clock64") - - -def cuda_make_float2(x, y): - return call_intrin("uint64", "tirx.cuda_make_float2", x, y) - - -def cuda_float2_x(packed): - return call_intrin("float32", "tirx.cuda_float2_x", packed) - - -def cuda_float2_y(packed): - return call_intrin("float32", "tirx.cuda_float2_y", packed) - - -def cuda_fmul2_rn(a, b): - return call_intrin("uint64", "tirx.cuda_fmul2_rn", a, b) - - -def cuda_fadd2_rn(a, b): - return call_intrin("uint64", "tirx.cuda_fadd2_rn", a, b) - - -def cuda_float22bfloat162_rn(v0, v1): - return call_intrin("uint32", "tirx.cuda_float22bfloat162_rn", v0, v1) - - -def cuda_float22bfloat162_rn_from_float2(packed): - return call_intrin("uint32", "tirx.cuda_float22bfloat162_rn_from_float2", packed) - - -def cuda_bfloat1622float2(packed): - return call_intrin("uint64", "tirx.cuda_bfloat1622float2", packed) - - -def cuda_hmin2(a, b): - return call_intrin("uint32", "tirx.cuda_hmin2", a, b) - - -def cuda_hmax2(a, b): - return call_intrin("uint32", "tirx.cuda_hmax2", a, b) - - -def cuda_fp8x4_e4m3_from_float4(x, y, z, w): - return call_intrin("uint32", "tirx.cuda_fp8x4_e4m3_from_float4", x, y, z, w) - - -def ptx_map_shared_rank(ptr, rank): - """TVM intrinsic to call ptx map_shared_rank instruction - - Parameters - ---------- - ptr: PrimExpr - The generic pointer to the local shared memory, handle type - - rank: int - The rank of the distributed shared memory. - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return ptx_mapa(ptr, rank, space="", ptx_type="u64", return_type="uint64") - - -def ptx_mapa(ptr, rank, *, space="", ptx_type="u64", return_type="uint64"): - """TVM intrinsic for PTX ``mapa{.space}.type d, a, b``.""" - if space not in ("", "shared::cluster"): - raise ValueError(f"Unsupported mapa space {space!r}") - if ptx_type not in ("u32", "u64"): - raise ValueError(f"Unsupported mapa type {ptx_type!r}") - return call_intrin(return_type, "tirx.ptx_mapa", ptr, rank, space, ptx_type, return_type) - - -def cuda_atomic_cas(ptr, old_val, new_val): - """TVM intrinsic to call cuda atomic cas instruction - - Parameters - ---------- - ptr: PrimExpr - The pointer to the memory location. - - old_val: PrimExpr - The old value. - - new_val: PrimExpr - The new value. - - Returns - ------- - call : PrimExpr - The call expression. - """ - old_val = tir.convert(old_val) - return call_intrin(old_val.dtype, "tirx.cuda_atomic_cas", ptr, old_val, new_val) - - -def thread_return(): - """TVM intrinsic to call thread_return() - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.thread_return") - - -def continue_loop(span=None): - """Create a tir intrinsic call to represent continue expression - - Parameters - ---------- - span : Optional[Span] - The location of this operator in the source code. - - Returns - ------- - ret : PrimExpr - The continue expression - """ - - return _ffi_api.continue_loop(span) - - -def break_loop(span=None): - """Create a tir intrinsic call to represent break expression - - Parameters - ---------- - span : Optional[Span] - The location of this operator in the source code. - - Returns - ------- - ret : PrimExpr - The break expression - """ - - return _ffi_api.break_loop(span) - - -######################################################## -# NVSHMEM builtins -######################################################## - - -def nvshmem_my_pe(): - """TVM intrinsic to call nvshmem_my_pe() - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin("int32", "tirx.nvshmem_my_pe") - - -def nvshmem_n_pes(): - """TVM intrinsic to call nvshmem_n_pes() - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin("int32", "tirx.nvshmem_n_pes") - - -def nvshmem_getmem_nbi(dst, src, nelems, pe): - """TVM intrinsic to call nvshmem_getmem_nbi() - - Parameters - ---------- - dst: PrimExpr - The pointer to the symmetric address or host/device address of the data object to be updated. - - src: PrimExpr - The pointer to the symmetric address of the source data object. - - nelems: int - The number of bytes to get per thread. - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ # noqa: E501 - - return call_intrin("", "tirx.nvshmem_getmem_nbi", dst, src, nelems, pe) - - -def nvshmem_putmem_nbi(dst, src, nelems, pe): - """TVM intrinsic to call nvshmem_putmem_nbi() - - Parameters - ---------- - dst: PrimExpr - The pointer to the symmetric address of the destination data object. - - src: PrimExpr - The pointer to the symmetric address or host/device address of the data object to be copied. - - nelems: int - The number of bytes to put per thread. - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin("", "tirx.nvshmem_putmem_nbi", dst, src, nelems, pe) - - -def nvshmem_getmem_nbi_warp(dst, src, nelems, pe): - """TVM intrinsic to call nvshmem_getmem_nbi_warp() - - Parameters - ---------- - dst: PrimExpr - The pointer to the symmetric address or host/device address of the data object to be updated. - - src: PrimExpr - The pointer to the symmetric address of the source data object. - - nelems: int - The number of bytes to get per warp. - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ # noqa: E501 - - return call_intrin("", "tirx.nvshmem_getmem_nbi_warp", dst, src, nelems, pe) - - -def nvshmem_putmem_nbi_warp(dst, src, nelems, pe): - """TVM intrinsic to call nvshmem_putmem_nbi_warp() - - Parameters - ---------- - dst: PrimExpr - The pointer to the symmetric address of the destination data object. - - src: PrimExpr - The pointer to the symmetric address or host/device address of the data object to be copied. - - nelems: int - The number of bytes to put per warp. - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin("", "tirx.nvshmem_putmem_nbi_warp", dst, src, nelems, pe) - - -def nvshmem_getmem_nbi_block(dst, src, nelems, pe): - """TVM intrinsic to call nvshmem_getmem_nbi_block() - - Parameters - ---------- - dst: PrimExpr - The pointer to the symmetric address or host/device address of the data object to be updated. - - src: PrimExpr - The pointer to the symmetric address of the source data object. - - nelems: int - The number of bytes to get per block. - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ # noqa: E501 - - return call_intrin("", "tirx.nvshmem_getmem_nbi_block", dst, src, nelems, pe) - - -def nvshmem_putmem_nbi_block(dst, src, nelems, pe): - """TVM intrinsic to call nvshmem_putmem_nbi_block() - - Parameters - ---------- - dst: PrimExpr - The pointer to the symmetric address of the destination data object. - - src: PrimExpr - The pointer to the symmetric address or host/device address of the data object to be copied. - - nelems: int - The number of bytes to put per block. - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin("", "tirx.nvshmem_putmem_nbi_block", dst, src, nelems, pe) - - -def nvshmem_signal_op(sig_addr, signal, sig_op, pe): - """TVM intrinsic to call nvshmem_signal_op() - - Parameters - ---------- - sig_addr: PrimExpr - The pointer to the symmetric address of the signal word to be updated, must be uint64_t*. - - signal: uint64_t - The value used to update sig_addr. - - sig_op: str - Operation used to update sig_addr with signal, typical sig_op values are "set" and "add". - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ - - _choice("sig_op", sig_op, _NVSHMEM_SIG_OP) - return call_intrin("", "tirx.nvshmem_signal_op", sig_addr, signal, sig_op, pe) - - -def nvshmem_wait_until(ivar, cmp, cmp_value, type="uint64_t"): - """TVM intrinsic to call nvshmem_wait_until() - - Parameters - ---------- - ivar: PrimExpr - The pointer to the symmetric address of a remotely accessible data object, must be TYPE*. - - cmp: str - The compare operator that compares ivar with cmp_value. - - cmp_value: TYPE - The value to be compared with ivar. - - type: str - The TYPE of ivar and cmp_value. - - Returns - ------- - call : PrimExpr - The call expression. - """ - - _choice("cmp", cmp, _NVSHMEM_CMP) - return call_intrin("", "tirx.nvshmem_wait_until", ivar, cmp, cmp_value, type) - - -def nvshmem_quiet(): - """TVM intrinsic to call nvshmem_quiet() - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin("", "tirx.nvshmem_quiet") - - -def nvshmem_putmem_signal_nbi(dst, src, nelems, sig_addr, signal, sig_op, pe): - """TVM intrinsic to call nvshmem_putmem_signal_nbi() - - Parameters - ---------- - dst: PrimExpr - The pointer to the symmetric address of the data object to be updated on the remote PE. - - src: PrimExpr - The pointer to the symmetric address or host/device address of data object containing the data to be copied. - - nelems: int - The number of bytes to put per thread. - - sig_addr: PrimExpr - The pointer to the symmetric address of the signal data object to be updated on the remote PE as a signal, must be uint64_t*. - - signal: uint64_t - The unsigned 64-bit value that is used for updating the remote sig_addr signal data object. - - sig_op: str - Signal operator that represents the type of update to be performed on the remote sig_addr signal data object. - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ # noqa: E501 - - return call_intrin( - "", "tirx.nvshmem_putmem_signal_nbi", dst, src, nelems, sig_addr, signal, sig_op, pe - ) - - -def nvshmem_putmem_signal_nbi_warp(dst, src, nelems, sig_addr, signal, sig_op, pe): - """TVM intrinsic to call nvshmem_putmem_signal_nbi_warp() - - Parameters - ---------- - dst: PrimExpr - The pointer to the symmetric address of the data object to be updated on the remote PE. - - src: PrimExpr - The pointer to the symmetric address or host/device address of data object containing the data to be copied. - - nelems: int - The number of bytes to put per warp. - - sig_addr: PrimExpr - The pointer to the symmetric address of the signal data object to be updated on the remote PE as a signal, must be uint64_t*. - - signal: uint64_t - The unsigned 64-bit value that is used for updating the remote sig_addr signal data object. - - sig_op: str - Signal operator that represents the type of update to be performed on the remote sig_addr signal data object. - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ # noqa: E501 - - return call_intrin( - "", "tirx.nvshmem_putmem_signal_nbi_warp", dst, src, nelems, sig_addr, signal, sig_op, pe - ) - - -def nvshmem_putmem_signal_nbi_block(dst, src, nelems, sig_addr, signal, sig_op, pe): - """TVM intrinsic to call nvshmem_putmem_signal_nbi_block() - - Parameters - ---------- - dst: PrimExpr - The pointer to the symmetric address of the data object to be updated on the remote PE. - - src: PrimExpr - The pointer to the symmetric address or host/device address of data object containing the data to be copied. - - nelems: int - The number of bytes to put per block. - - sig_addr: PrimExpr - The pointer to the symmetric address of the signal data object to be updated on the remote PE as a signal, must be uint64_t*. - - signal: uint64_t - The unsigned 64-bit value that is used for updating the remote sig_addr signal data object. - - sig_op: str - Signal operator that represents the type of update to be performed on the remote sig_addr signal data object. - - pe: int - The PE number of the remote PE. - - Returns - ------- - call : PrimExpr - The call expression. - """ # noqa: E501 - - return call_intrin( - "", "tirx.nvshmem_putmem_signal_nbi_block", dst, src, nelems, sig_addr, signal, sig_op, pe - ) - - -def nvshmem_fence(): - """TVM intrinsic to call nvshmem_fence() - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin("", "tirx.nvshmem_fence") - - -def nvshmem_barrier_all(): - """TVM intrinsic to call nvshmem_barrier_all() - - Returns - ------- - call : PrimExpr - The call expression. - """ - - return call_intrin("", "tirx.nvshmem_barrier_all") - - -######################################################## -# NKI builtins -######################################################## - - -def nki_load(res, data): - """TVM intrinsic to call nki load instruction - - Parameters - ---------- - res : BufferLoad - The result buffer. - - data: BufferLoad - The data buffer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_load", res, data) - - -def nki_store(res, data): - """TVM intrinsic to call nki store instruction - - Parameters - ---------- - res : BufferLoad - The result buffer. - - data: BufferLoad - The data buffer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_store", res, data) - - -def nki_tensor_copy(res, data): - """TVM intrinsic to call nki tensor copy instruction - - Parameters - ---------- - res : BufferLoad - The result buffer. - - data: BufferLoad - The data buffer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_tensor_copy", res, data) - - -def nki_matmul(res, lhs, rhs, accum=True): - """TVM intrinsic to call nki matmul instruction - - Parameters - ---------- - res : BufferLoad - The result buffer. - - lhs: BufferLoad - The left hand side buffer. - - rhs: BufferLoad - The right hand side buffer. - - accum: bool - Whether to accumulate the result. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_matmul", res, lhs, rhs, accum) - - -def nki_activation(result, data, opcode, bias=0.0, scale=1.0): - """TVM intrinsic to call nki activation instruction - - Parameters - ---------- - result : BufferLoad - The result buffer. - - data: BufferLoad - The data buffer. - - opcode: str - The opcode. - - bias: PrimExpr - The bias. - - scale: PrimExpr - The scale. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_activation", result, data, opcode, bias, scale) - - -def nki_reciprocal(result, data): - """TVM intrinsic to call nki reciprocal instruction - - Parameters - ---------- - result : BufferLoad - The result buffer. - - data: BufferLoad - The data buffer. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_reciprocal", result, data) - - -def nki_tensorreduce(result, data, opcode, negate, *axes): - """TVM intrinsic to call nki tensorreduce instruction - - Parameters - ---------- - result : BufferLoad - The result buffer. - - data: BufferLoad - The data buffer. - - opcode: str - The opcode. - - negate: bool - Whether to negate the result. - - axes: Tuple[int] - The axes to reduce over. - - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_tensorreduce", result, data, opcode, negate, *axes) - - -def nki_tensortensor(result, operand0, operand1, opcode): - """TVM intrinsic to call nki tensortensor instruction - - Parameters - ---------- - result : BufferLoad - The result buffer. - - operand0: BufferLoad - The first operand buffer. - - operand1: BufferLoad - The second operand buffer. - - opcode: str - The opcode. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_tensortensor", result, operand0, operand1, opcode) - - -def nki_tensorscalar(result, operand0, operand1, opcode, reverse=False): - """TVM intrinsic to call nki tensorscalar instruction - - Parameters - ---------- - result : BufferLoad - The result buffer. - - operand0: BufferLoad - The first operand buffer. - - operand1: PrimExpr - The second operand scalar. - - opcode: str - The opcode. - - reverse: bool - Whether to reverse the operands. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_tensorscalar", result, operand0, operand1, opcode, reverse) - - -def nki_memset(result, value): - """TVM intrinsic to call nki memset instruction - - Parameters - ---------- - result : BufferLoad - The result buffer. - - value: PrimExpr - The value to set. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_memset", result, value) - - -def nki_activation_reduce(reduce_res, act_res, data, opcode, reduce_opcode, bias=0.0, scale=1.0): - """TVM intrinsic to call nki activation reduce instruction - - act_res = act_op(data * scale + bias) - reduce_res = reduce_op(act_res) - - Parameters - ---------- - reduce_res : BufferLoad - The result buffer of reduction. - - act_res : BufferLoad - The result buffer of activation. - - data: BufferLoad - The data buffer. - - opcode: str - The opcode. - - reduce_opcode: str - The reduce opcode. - - bias: PrimExpr - The bias. - - scale: PrimExpr - The scale. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin( - "", - "tirx.nki_activation_reduce", - reduce_res, - act_res, - data, - opcode, - reduce_opcode, - bias, - scale, - ) - - -def nki_tensorscalar_reduce( - reduce_res, tensorscalar_res, operand0, operand1, opcode, reduce_opcode, reverse=False -): - """TVM intrinsic to call nki tensorscalar reduce instruction - - tensorscalar_res = tensorscalar_op(operand0, operand1) - reduce_res = reduce_op(tensorscalar_res) - - Parameters - ---------- - reduce_res : BufferLoad - The result buffer of reduction. - - tensorscalar_res : BufferLoad - The result buffer of tensorscalar. - - operand0: BufferLoad - The first operand buffer. - - operand1: PrimExpr - The second operand scalar. - - opcode: str - The opcode. - - reduce_opcode: str - The reduce opcode. - - reverse: bool - Whether to reverse the operands of tensorscalar. - """ - return call_intrin( - "", - "tirx.nki_tensorscalar_reduce", - reduce_res, - tensorscalar_res, - operand0, - operand1, - opcode, - reduce_opcode, - reverse, - ) - - -def nki_identity(result, size): - """TVM intrinsic to call nki identity instruction - - Parameters - ---------- - result : BufferLoad - The result buffer. - - size: PrimExpr - The size of the identity tensor. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_identity", result, size) - - -def nki_scalar_tensor_tensor( - result, data, operand0, operand1, opcode0, opcode1, reverse0=False, reverse1=False -): - """TVM intrinsic to call nki scalar tensor tensor instruction - (data op0 operand0) op1 (operand1) , where op0 is tensor-scalar and op1 is tensor-tensor - - Parameters - ---------- - result : BufferLoad - The result buffer. - - data: BufferLoad - The data buffer. - - operand0: PrimExpr - The first operand scalar. - - operand1: BufferLoad - The second operand buffer. - - opcode0: str - The first opcode. - - opcode1: str - The second opcode. - - reverse0: bool - Whether to reverse the first operand. - - reverse1: bool - Whether to reverse the second operand. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin( - "", - "tirx.nki_scalar_tensor_tensor", - result, - data, - operand0, - operand1, - opcode0, - opcode1, - reverse0, - reverse1, - ) - - -def nki_scalar_tensor_scalar( - result, data, operand0, operand1, opcode0, opcode1, reverse0=False, reverse1=False -): - """TVM intrinsic to call nki scalar tensor scalar instruction - (data op0 operand0) op1 (operand1) , where op0 and op1 are tensor-scalar - - Parameters - ---------- - result : BufferLoad - The result buffer. - - data: BufferLoad - The data buffer. - - operand0: PrimExpr - The first operand scalar. - - operand1: PrimExpr - The second operand scalar. - - opcode0: str - The first opcode. - - opcode1: str - The second opcode. - - reverse0: bool - Whether to reverse the first operand. - - reverse1: bool - Whether to reverse the second operand. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin( - "", - "tirx.nki_scalar_tensor_scalar", - result, - data, - operand0, - operand1, - opcode0, - opcode1, - reverse0, - reverse1, - ) - - -def nki_affine_select(result, pred, true_value, false_value): - """TVM intrinsic to call nki affine select instruction - - Parameters - ---------- - result : BufferLoad - The result buffer. - - pred: PrimExpr - The predicate. - - true_value: PrimExpr - The true value. - - false_value: PrimExpr - The false value. - - Returns - ------- - call : PrimExpr - The call expression. - """ - return call_intrin("", "tirx.nki_affine_select", result, pred, true_value, false_value) + return _ffi_api.break_loop(span) diff --git a/python/tvm/tirx/operator/intrinsics/_common.py b/python/tvm/tirx/operator/intrinsics/_common.py index 6a0509e83795..f85fefc4c64d 100644 --- a/python/tvm/tirx/operator/intrinsics/_common.py +++ b/python/tvm/tirx/operator/intrinsics/_common.py @@ -17,7 +17,7 @@ """Shared enum / value tables for PTX intrinsic schemas and user wrappers. Single source of truth. Both ``tvm.tirx.op`` (user wrappers that validate -arguments via ``_choice``) and ``tvm.tirx.operator.intrinsics.cuda.*`` +arguments via ``_choice``) and ``tvm.tirx.cuda.operator.intrinsics.*`` (schema declarations using ``Choice(choices=...)`` / ``IntAttr(choices=...)``) import from here. diff --git a/python/tvm/tirx/operator/tile_primitive/__init__.py b/python/tvm/tirx/operator/tile_primitive/__init__.py index f1e6dda01272..4f9aa93d0200 100644 --- a/python/tvm/tirx/operator/tile_primitive/__init__.py +++ b/python/tvm/tirx/operator/tile_primitive/__init__.py @@ -22,16 +22,9 @@ # code refers to the same ops. from .ops import * -# Dispatch infrastructure + per-target schedule registrations. +# Dispatch infrastructure. Per-backend schedule registrations are loaded via +# ``tvm.backend.load()``. from .dispatcher import fail, list_registered_schedules, predicate, register_dispatch from .registry import DispatchContext -from .cuda.copy import * -from .cuda.reduction import * -from .cuda.copy_async import * -from .cuda.permute_layout import * -from .cuda.gemm import * -from .cuda.gemm_async import * -from .cuda.elementwise import * -from .trn import * __all__ = ["DispatchContext", "fail", "list_registered_schedules", "predicate", "register_dispatch"] diff --git a/python/tvm/tirx/operator/tile_primitive/dispatch_context.py b/python/tvm/tirx/operator/tile_primitive/dispatch_context.py index 79fbcce8c843..b6bfad133329 100644 --- a/python/tvm/tirx/operator/tile_primitive/dispatch_context.py +++ b/python/tvm/tirx/operator/tile_primitive/dispatch_context.py @@ -169,6 +169,10 @@ def is_trn(self) -> bool: """Check if the target is Trainium.""" return self.target.kind.name == "trn" + def is_target(self, name: str) -> bool: + """Check if the target kind matches ``name``.""" + return self.target.kind.name == name + # -- scope predicates ---------------------------------------------------- # # Each ``is_`` returns True iff the op site is at that scope kind. diff --git a/python/tvm/tirx/script/builder/ir.py b/python/tvm/tirx/script/builder/ir.py index 754f41fa3b23..2c18c61136b8 100644 --- a/python/tvm/tirx/script/builder/ir.py +++ b/python/tvm/tirx/script/builder/ir.py @@ -2979,499 +2979,6 @@ def wrapped(*args, **kwargs): return wrapped -def _ptx_ldg32(reg, guard, addr, local_addr): - if isinstance(addr, Buffer): - addr = addr[0] - return _tir_op.call_intrin(reg.dtype, "tirx.ptx.ldg32", reg, guard, addr, local_addr) - - -_ptx_ldg32.__tir_op_name__ = "ptx.ldg32" - - -class PTXNamespace: - """The PTX instruction submodule.""" - - def __init__(self): - self.ldg32 = _ptx_ldg32 - self.ldmatrix = _dtype_forward(_tir_op.ptx_ldmatrix) - # Apache-compatible variant. Same lowered intrinsic as - # ``ldmatrix`` but accepts the historical ``(trans, num, dtype, - # local_ptr, local_offset, smem_ptr, smem_offset)`` form. Coexists - # with the fork-native version so upstream-derived tests keep - # working without rewriting their tirx code. - self.ldmatrix_legacy = _dtype_forward(_tir_op.ptx_ldmatrix_legacy) - self.stmatrix = _op_wrapper(_tir_op.ptx_stmatrix) - self.setmaxnreg: Callable[..., Any] = _op_wrapper(_tir_op.ptx_setmaxnreg) - self.elect_sync: Callable[..., Any] = _op_wrapper(_tir_op.ptx_elect_sync) - self.fetch_register: Callable[..., Any] = _op_wrapper(_tir_op.ptx_fetch_register) - self.ld = _op_wrapper(_tir_op.ptx_ld) - self.ld_acquire = _op_wrapper(_tir_op.ptx_ld_acquire) - self.ld_volatile = _op_wrapper(_tir_op.ptx_ld_volatile) - self.ld_global_acquire = _op_wrapper(_tir_op.ptx_ld_global_acquire) - self.red_scalar = _op_wrapper(_tir_op.ptx_red_scalar) - self.atom_scalar = _op_wrapper(_tir_op.ptx_atom_scalar) - self.prefetch_tensormap = _op_wrapper(_tir_op.ptx_prefetch_tensormap) - self.mbarrier_test_wait_parity = _op_wrapper(_tir_op.ptx_mbarrier_test_wait_parity) - self.cp_async_bulk_g2s_cta = _op_wrapper(_tir_op.ptx_cp_async_bulk_g2s_cta) - self.cp_async_bulk_g2s_cluster = _op_wrapper(_tir_op.ptx_cp_async_bulk_g2s_cluster) - self.cp_async_bulk_s2s_cluster = _op_wrapper(_tir_op.ptx_cp_async_bulk_s2s_cluster) - self.cp_async_bulk_s2g = _op_wrapper(_tir_op.ptx_cp_async_bulk_s2g) - self.st = _op_wrapper(_tir_op.ptx_st) - self.st_bulk = _op_wrapper(_tir_op.ptx_st_bulk) - self.fns_b32 = _op_wrapper(_tir_op.ptx_fns_b32) - self.add_rn_f32_bf16 = _op_wrapper(_tir_op.ptx_add_rn_f32_bf16) - self.mapa = _op_wrapper(_tir_op.ptx_mapa) - self.map_shared_rank = _op_wrapper(_tir_op.ptx_map_shared_rank) - self.any_sync = _op_wrapper(_tir_op.ptx_any_sync) - # Math operations - self.exp2 = _op_wrapper(_tir_op.ptx_exp2) - self.rcp = _op_wrapper(_tir_op.ptx_rcp) - self.reduce3_min_f32 = _op_wrapper(_tir_op.ptx_reduce3_min_f32) - self.reduce3_max_f32 = _op_wrapper(_tir_op.ptx_reduce3_max_f32) - # add/sub/mul/fma DPS form: (d_addr, a, b[, c], *, rounding, ftz[, sat]) - self.add_f32 = _op_wrapper(_tir_op.ptx_add_f32) - self.add_f32x2 = _op_wrapper(_tir_op.ptx_add_f32x2) - self.add_f64 = _op_wrapper(_tir_op.ptx_add_f64) - self.sub_f32 = _op_wrapper(_tir_op.ptx_sub_f32) - self.sub_f32x2 = _op_wrapper(_tir_op.ptx_sub_f32x2) - self.sub_f64 = _op_wrapper(_tir_op.ptx_sub_f64) - self.mul_f32 = _op_wrapper(_tir_op.ptx_mul_f32) - self.mul_f32x2 = _op_wrapper(_tir_op.ptx_mul_f32x2) - self.mul_f64 = _op_wrapper(_tir_op.ptx_mul_f64) - self.fma_f32 = _op_wrapper(_tir_op.ptx_fma_f32) - self.fma_f32x2 = _op_wrapper(_tir_op.ptx_fma_f32x2) - self.fma_f64 = _op_wrapper(_tir_op.ptx_fma_f64) - self.max_f32 = _op_wrapper(_tir_op.ptx_max_f32) - self.mma = MmaNamespace() - self.cp_async = CpAsyncNamespace() - self.wgmma = WgmmaNamespace() - self.mbarrier = MbarrierNamespace() - self.tcgen05 = Tcgen05Namespace() - self.bar = BarNamespace() - self.barrier = BarrierNamespace() - self.fence = FenceNamespace() - self.griddepcontrol = GriddepcontrolNamespace() - - -class MmaNamespace: - """The MMA instruction submodule.""" - - def __init__(self): - self.sp = _dtype_forward(_tir_op.ptx_mma_sp) - # Apache-compatible variant of ptx_mma. Coexists with the - # fork-native ``__call__`` form (``T.ptx.mma(...)``). - self.legacy = _dtype_forward(_tir_op.ptx_mma_legacy) - # __call__ corresponds to ptx_mma - self.__tir_call_op_name__ = "ptx_mma" - - def __call__(self, *args, **kwds): - return _dtype_forward(_tir_op.ptx_mma)(*args, **kwds) - - -class CpAsyncNamespace: - """The CpAsync instruction submodule.""" - - def __init__(self): - self.commit_group = _op_wrapper(_tir_op.ptx_cp_async_commit_group) - self.wait_group = _op_wrapper(_tir_op.ptx_cp_async_wait_group) - # Legacy variant: takes (dst_ptr, dst_offset, src_ptr, src_offset, - # cp_size). Offsets are folded into the pointers; coexists with - # the fork-native ``__call__`` form. - self.legacy = _dtype_forward(_tir_op.ptx_cp_async_legacy) - self.bulk = CpAsyncBulkNamespace() - self.mbarrier = CpAsyncMbarrierNamespace() - - def __call__(self, *args, **kwds): - # Accept the legacy 6-arg form ``(elem_dtype, dst, dst_off, src, - # src_off, cp_size)`` that the printer round-trips for the raw - # ``tirx.ptx_cp_async`` Call emitted by ``s_tir/transform/ - # InjectPTXAsyncCopy``. The pass-emitted Call has 5 args (no - # ``tvm_access_ptr`` fold) and a per-element-dtype Call.dtype, - # so build it directly. - if len(args) == 6 and isinstance(args[0], str) and "dtype" not in kwds: - import tvm - - elem_dtype, dst, dst_off, src, src_off, cp_size = args - return tvm.tirx.Call( - tvm.DataType(elem_dtype), - tvm.ir.Op.get("tirx.ptx_cp_async"), - [dst, dst_off, src, src_off, cp_size], - ) - return _dtype_forward(_tir_op.ptx_cp_async)(*args, **kwds) - - # __call__ corresponds to ptx_cp_async - __tir_call_op_name__ = "ptx_cp_async" - - -class CpAsyncBulkNamespace: - """The CpAsyncBulk instruction submodule.""" - - def __init__(self): - self.commit_group = _op_wrapper(_tir_op.ptx_cp_async_bulk_commit_group) - self.wait_group = _op_wrapper(_tir_op.ptx_cp_async_bulk_wait_group) - self.tensor = CpAsyncBulkTensorNamespace() - self.s2c = _op_wrapper(_tir_op.ptx_cp_async_bulk_shared_to_cluster) - - def __call__(self, *args, **kwds): - return _dtype_forward(_tir_op.ptx_cp_async_bulk)(*args, **kwds) - - # __call__ corresponds to ptx_cp_async_bulk - __tir_call_op_name__ = "ptx_cp_async_bulk" - - -class CpAsyncBulkTensorNamespace: - """The CpAsyncBulkTensor instruction submodule.""" - - def __init__(self): - self.g2c = _op_wrapper(_tir_op.ptx_cp_async_bulk_tensor_global_to_cluster) - self.g2c_tile_gather4 = _op_wrapper( - _tir_op.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster - ) - self.s2g = _op_wrapper(_tir_op.ptx_cp_async_bulk_tensor_shared_to_global) - self.s2g_reduce = _op_wrapper(_tir_op.ptx_cp_async_bulk_tensor_shared_to_global_reduce) - self.g2c_prefetch = _op_wrapper(_tir_op.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch) - - @staticmethod - def g2c_bar_addr( - dim, - dst_ptr, - bar_addr, - tensormap_addr, - cta_mask, - cta_group, - cache_hint, - *coords, - cache_policy=None, - ): - _tir_op._choice("cta_group", cta_group, _tir_op._TCGEN05_CTA_GROUP) - cache_policy, has_cache_policy = _tir_op._resolve_cache_policy(cache_hint, cache_policy) - return _tir_op.call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_global_to_cluster", - dim, - dst_ptr, - bar_addr, - tensormap_addr, - cta_mask, - cta_group, - cache_policy, - int(has_cache_policy), - 1, - *coords, - ) - - @staticmethod - def g2c_tile_gather4_bar_addr( - dim, - dst_ptr, - bar_addr, - tensormap_addr, - cta_mask, - cta_group, - cache_hint, - *coords, - cache_policy=None, - ): - _tir_op._choice("cta_group", cta_group, _tir_op._TCGEN05_CTA_GROUP) - cache_policy, has_cache_policy = _tir_op._resolve_cache_policy(cache_hint, cache_policy) - return _tir_op.call_intrin( - "", - "tirx.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster", - dim, - dst_ptr, - bar_addr, - tensormap_addr, - cta_mask, - cta_group, - cache_policy, - int(has_cache_policy), - 1, - *coords, - ) - - -class CpAsyncMbarrierNamespace: - """The CpAsyncMbarrier instruction submodule.""" - - def __init__(self): - self.arrive = _op_wrapper(_tir_op.ptx_cp_async_mbarrier_arrive) - - -class WgmmaNamespace: - """The WGMMA instruction submodule.""" - - def __init__(self): - self.fence: Callable[..., Any] = _op_wrapper(_tir_op.ptx_wgmma_fence) - self.commit_group = _op_wrapper(_tir_op.ptx_wgmma_commit_group) - self.wait_group = _op_wrapper(_tir_op.ptx_wgmma_wait_group) - self.noop_barrier = _op_wrapper(_tir_op.ptx_wgmma_noop_barrier) - self.mma_async = WgmmaMmaAsyncNamespace() - self.encode_matrix_descriptor = _op_wrapper(_tir_op.ptx_wgmma_encode_matrix_descriptor) - - -class WgmmaMmaAsyncNamespace: - """The WGMMA MMAAsync instruction submodule.""" - - def __init__(self): - self.ss = _op_wrapper(_tir_op.ptx_wgmma_mma_async_ss) - self.rs = _op_wrapper(_tir_op.ptx_wgmma_mma_async_rs) - - -class MbarrierNamespace: - """The Mbarrier instruction submodule.""" - - def __init__(self): - self.init = _op_wrapper(_tir_op.ptx_mbarrier_init) - self.try_wait = _op_wrapper(_tir_op.ptx_mbarrier_try_wait) - self.try_wait_once = _op_wrapper(_tir_op.ptx_mbarrier_try_wait_once) - self.arrive = MbarrierArriveNamespace() - - -class MbarrierArriveNamespace: - """The Mbarrier Arrive instruction submodule.""" - - def __init__(self): - self.expect_tx = _op_wrapper(_tir_op.ptx_mbarrier_arrive_expect_tx) - - def __call__(self, *args, **kwds): - return _op_wrapper(_tir_op.ptx_mbarrier_arrive)(*args, **kwds) - - # __call__ corresponds to ptx_mbarrier_arrive - __tir_call_op_name__ = "ptx_mbarrier_arrive" - - -class Tcgen05Namespace: - """The Tcgen05 instruction submodule.""" - - def __init__(self): - self.alloc = _op_wrapper(_tir_op.ptx_tcgen05_alloc) - self.dealloc = _op_wrapper(_tir_op.ptx_tcgen05_dealloc) - self.relinquish_alloc_permit = _op_wrapper(_tir_op.ptx_tcgen05_relinquish_alloc_permit) - self.encode_matrix_descriptor = _op_wrapper(_tir_op.ptx_tcgen05_encode_matrix_descriptor) - self.encode_instr_descriptor = _op_wrapper(_tir_op.ptx_tcgen05_encode_instr_descriptor) - self.encode_instr_descriptor_block_scaled = _op_wrapper( - _tir_op.ptx_tcgen05_encode_instr_descriptor_block_scaled - ) - self.ld = _op_wrapper(_tir_op.ptx_tcgen05_ld) - self.st = _op_wrapper(_tir_op.ptx_tcgen05_st) - self.cp = _op_wrapper(_tir_op.ptx_tcgen05_cp) - self.shift = _op_wrapper(_tir_op.ptx_tcgen05_shift) - self.commit = _op_wrapper(_tir_op.ptx_tcgen05_commit) - self.wait = Tcgen05WaitNamespace() - self.mma = Tcgen05MmaNamespace() - self.fence = Tcgen05FenceNamespace() - - -class Tcgen05FenceNamespace: - """The Tcgen05 Fence instruction submodule.""" - - def __init__(self): - self.before_thread_sync = _op_wrapper(_tir_op.ptx_tcgen05_fence_before_thread_sync) - self.after_thread_sync = _op_wrapper(_tir_op.ptx_tcgen05_fence_after_thread_sync) - - -class Tcgen05MmaNamespace: - """The Tcgen05 MMA instruction submodule.""" - - def __init__(self): - self.block_scale = _op_wrapper(_tir_op.ptx_tcgen05_mma_block_scale) - self.sp = Tcgen05MmaSpNamespace() - - def __call__(self, *args, **kwds): - return _op_wrapper(_tir_op.ptx_tcgen05_mma)(*args, **kwds) - - # __call__ corresponds to ptx_tcgen05_mma - __tir_call_op_name__ = "ptx_tcgen05_mma" - - -class Tcgen05MmaSpNamespace: - """Tcgen05 Sparse MMA instruction submodule.""" - - def __init__(self): - self.block_scale = _op_wrapper(_tir_op.ptx_tcgen05_mma_sp_block_scale) - - def __call__(self, *args, **kwds): - return _op_wrapper(_tir_op.ptx_tcgen05_mma_sp)(*args, **kwds) - - # __call__ corresponds to ptx_tcgen05_mma_sp - __tir_call_op_name__ = "ptx_tcgen05_mma_sp" - - -class Tcgen05WaitNamespace: - """The Tcgen05 Wait instruction submodule.""" - - def __init__(self): - self.ld = _op_wrapper(_tir_op.ptx_tcgen05_wait_ld) - self.st = _op_wrapper(_tir_op.ptx_tcgen05_wait_st) - - -class BarNamespace: - """The Bar instruction submodule.""" - - def __init__(self): - self.arrive = _op_wrapper(_tir_op.ptx_bar_arrive) - self.sync = _op_wrapper(_tir_op.ptx_bar_sync) - - -class BarrierNamespace: - """The Barrier instruction submodule.""" - - def __init__(self): - self.cluster = BarrierClusterNamespace() - - -class BarrierClusterNamespace: - """The BarrierCluster instruction submodule.""" - - def __init__(self): - self.arrive = _op_wrapper(_tir_op.ptx_barrier_cluster_arrive) - self.wait = _op_wrapper(_tir_op.ptx_barrier_cluster_wait) - - -class FenceNamespace: - """PTX fence instruction submodule.""" - - def __init__(self): - self.proxy_async = _op_wrapper(_tir_op.ptx_fence_proxy_async) - self.mbarrier_init = _op_wrapper(_tir_op.ptx_fence_mbarrier_init) - - def __call__(self, *args, **kwds): - return _op_wrapper(_tir_op.ptx_fence)(*args, **kwds) - - __tir_call_op_name__ = "ptx_fence" - - -class GriddepcontrolNamespace: - """PTX griddepcontrol instruction submodule (sm_90+).""" - - def __init__(self): - self.wait = _op_wrapper(_tir_op.ptx_griddepcontrol_wait) - self.launch_dependents = _op_wrapper(_tir_op.ptx_griddepcontrol_launch_dependents) - - -class CUDANamespace: - """The CUDA intrinsics submodule.""" - - def __init__(self): - self.atomic_add = _op_wrapper(_tir_op.cuda_atomic_add) - self.thread_fence = _op_wrapper(_tir_op.cuda_thread_fence) - self.warpgroup_sync = _op_wrapper(_tir_op.cuda_warpgroup_sync) - self.warp_sync = _op_wrapper(_tir_op.cuda_warp_sync) - self.warp_reduce = _op_wrapper(_tir_op.cuda_warp_reduce) - self.warp_sum = _op_wrapper(_tir_op.cuda_warp_sum) - self.warp_max = _op_wrapper(_tir_op.cuda_warp_max) - self.warp_min = _op_wrapper(_tir_op.cuda_warp_min) - self.cta_reduce = _op_wrapper(_tir_op.cuda_cta_reduce) - self.cta_sum = _op_wrapper(_tir_op.cuda_cta_sum) - self.cta_max = _op_wrapper(_tir_op.cuda_cta_max) - self.cta_min = _op_wrapper(_tir_op.cuda_cta_min) - self.copy_bytes = _op_wrapper(_tir_op.cuda_copy_bytes) - self.copy_128b = _op_wrapper(_tir_op.cuda_copy_128b) - self.copy_64b = _op_wrapper(_tir_op.cuda_copy_64b) - self.copy_32b = _op_wrapper(_tir_op.cuda_copy_32b) - self.copy_16b = _op_wrapper(_tir_op.cuda_copy_16b) - self.copy_8b = _op_wrapper(_tir_op.cuda_copy_8b) - self.cta_sync = _op_wrapper(_tir_op.cuda_cta_sync) - self.grid_sync = _op_wrapper(_tir_op.cuda_grid_sync) - self.cluster_sync = _op_wrapper(_tir_op.cuda_cluster_sync) - self.thread_rank = _op_wrapper(_tir_op.cuda_thread_rank) - self.trap_when_assert_failed = _op_wrapper(_tir_op.cuda_trap_when_assert_failed) - self.runtime_instr_desc = _op_wrapper(_tir_op.cuda_runtime_instr_desc) - self.half2float = _op_wrapper(_tir_op.cuda_half2float) - self.bfloat162float = _op_wrapper(_tir_op.cuda_bfloat162float) - self.float22half2 = _op_wrapper(_tir_op.cuda_float22half2) - self.half8tofloat8 = _op_wrapper(_tir_op.cuda_half8tofloat8) - self.float8tohalf8 = _op_wrapper(_tir_op.cuda_float8tohalf8) - self.syncthreads_and = _op_wrapper(_tir_op.cuda_syncthreads_and) - self.syncthreads_or = _op_wrapper(_tir_op.cuda_syncthreads_or) - self.nano_sleep = _op_wrapper(_tir_op.cuda_nano_sleep) - self.atomic_cas = _op_wrapper(_tir_op.cuda_atomic_cas) - self.func_call = _op_wrapper(_tir_op.cuda_func_call) - self.printf = _op_wrapper(_tir_op.cuda_printf) - self.ldg = _op_wrapper(_tir_op.cuda_ldg) - self.get_tmem_addr = _op_wrapper(_tir_op.cuda_get_tmem_addr) - self.cvta_generic_to_shared = _op_wrapper(_tir_op.cuda_cvta_generic_to_shared) - self.smem_addr_from_uint64 = _op_wrapper(_tir_op.cuda_smem_addr_from_uint64) - self.sm100_tma_2sm_mbarrier_addr = _op_wrapper(_tir_op.cuda_sm100_tma_2sm_mbarrier_addr) - self.uint_as_float = _op_wrapper(_tir_op.cuda_uint_as_float) - self.float_as_uint = _op_wrapper(_tir_op.cuda_float_as_uint) - self.ballot_sync = _op_wrapper(_tir_op.cuda_ballot_sync) - self.ffs_u32 = _op_wrapper(_tir_op.cuda_ffs_u32) - self.reduce_add_sync_u32 = _op_wrapper(_tir_op.cuda_reduce_add_sync_u32) - self.reduce_min_sync_u32 = _op_wrapper(_tir_op.cuda_reduce_min_sync_u32) - self.clock64 = _op_wrapper(_tir_op.cuda_clock64) - self.make_float2 = _op_wrapper(_tir_op.cuda_make_float2) - self.float2_x = _op_wrapper(_tir_op.cuda_float2_x) - self.float2_y = _op_wrapper(_tir_op.cuda_float2_y) - self.fmul2_rn = _op_wrapper(_tir_op.cuda_fmul2_rn) - self.fadd2_rn = _op_wrapper(_tir_op.cuda_fadd2_rn) - self.float22bfloat162_rn = _op_wrapper(_tir_op.cuda_float22bfloat162_rn) - self.float22bfloat162_rn_from_float2 = _op_wrapper( - _tir_op.cuda_float22bfloat162_rn_from_float2 - ) - self.bfloat1622float2 = _op_wrapper(_tir_op.cuda_bfloat1622float2) - self.hmin2 = _op_wrapper(_tir_op.cuda_hmin2) - self.hmax2 = _op_wrapper(_tir_op.cuda_hmax2) - self.fp8x4_e4m3_from_float4 = _op_wrapper(_tir_op.cuda_fp8x4_e4m3_from_float4) - setattr(self, "__shfl_sync", self._shfl_sync) - setattr(self, "__shfl_up_sync", self._shfl_up_sync) - setattr(self, "__shfl_down_sync", self._shfl_down_sync) - setattr(self, "__shfl_xor_sync", self._shfl_xor_sync) - setattr(self, "__activemask", self._activemask) - - @staticmethod - def _shfl_sync(mask, var, lane, width): - if isinstance(var, Buffer): - var = var[0] - return _tir_op.call_intrin(var.dtype, "tirx.cuda.__shfl_sync", mask, var, lane, width) - - @staticmethod - def _shfl_up_sync(mask, var, delta, width): - if isinstance(var, Buffer): - var = var[0] - return _tir_op.call_intrin(var.dtype, "tirx.cuda.__shfl_up_sync", mask, var, delta, width) - - @staticmethod - def _shfl_down_sync(mask, var, delta, width): - if isinstance(var, Buffer): - var = var[0] - return _tir_op.call_intrin(var.dtype, "tirx.cuda.__shfl_down_sync", mask, var, delta, width) - - @staticmethod - def _shfl_xor_sync(mask, var, lane_mask, width): - if isinstance(var, Buffer): - var = var[0] - return _tir_op.call_intrin( - var.dtype, "tirx.cuda.__shfl_xor_sync", mask, var, lane_mask, width - ) - - @staticmethod - def _activemask(): - return _tir_op.call_intrin("uint32", "tirx.cuda.__activemask") - - -class MetalNamespace: - """The Metal intrinsics submodule.""" - - @staticmethod - def simd_shuffle(var, lane): - if isinstance(var, Buffer): - var = var[0] - return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle", var, lane) - - @staticmethod - def simd_shuffle_up(var, delta): - if isinstance(var, Buffer): - var = var[0] - return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_up", var, delta) - - @staticmethod - def simd_shuffle_down(var, delta): - if isinstance(var, Buffer): - var = var[0] - return _tir_op.call_intrin(var.dtype, "tirx.metal.simd_shuffle_down", var, delta) - - class WebGPUNamespace: """The WebGPU intrinsics submodule.""" @@ -3494,100 +3001,14 @@ def subgroup_shuffle_down(var, delta): return _tir_op.call_intrin(var.dtype, "tirx.webgpu.subgroup_shuffle_down", var, delta) -class NVSHMEMNamespace: - """The NVSHMEM intrinsics submodule.""" - - def __init__(self): - self.my_pe = _op_wrapper(_tir_op.nvshmem_my_pe) - self.n_pes = _op_wrapper(_tir_op.nvshmem_n_pes) - self.signal_op = _op_wrapper(_tir_op.nvshmem_signal_op) - self.wait_until = _op_wrapper(_tir_op.nvshmem_wait_until) - self.quiet = _op_wrapper(_tir_op.nvshmem_quiet) - self.fence = _op_wrapper(_tir_op.nvshmem_fence) - self.barrier_all = _op_wrapper(_tir_op.nvshmem_barrier_all) - self.getmem_nbi = NVSHMEMGetMemNBINamespace() - self.putmem_nbi = NVSHMEMPutMemNBINamespace() - self.putmem_signal_nbi = NVSHMEMPutMemSignalNBINamespace() - - -class NVSHMEMGetMemNBINamespace: - """The NVSHMEM GetMemNBI intrinsics submodule.""" - - def __init__(self): - self.warp = _op_wrapper(_tir_op.nvshmem_getmem_nbi_warp) - self.block = _op_wrapper(_tir_op.nvshmem_getmem_nbi_block) - - def __call__(self, *args, **kwds): - return _op_wrapper(_tir_op.nvshmem_getmem_nbi)(*args, **kwds) - - # __call__ corresponds to nvshmem_getmem_nbi - __tir_call_op_name__ = "nvshmem_getmem_nbi" - - -class NVSHMEMPutMemNBINamespace: - """The NVSHMEM PutMemNBI intrinsics submodule.""" - - def __init__(self): - self.warp = _op_wrapper(_tir_op.nvshmem_putmem_nbi_warp) - self.block = _op_wrapper(_tir_op.nvshmem_putmem_nbi_block) - - def __call__(self, *args, **kwds): - return _op_wrapper(_tir_op.nvshmem_putmem_nbi)(*args, **kwds) - - # __call__ corresponds to nvshmem_putmem_nbi - __tir_call_op_name__ = "nvshmem_putmem_nbi" - - -class NVSHMEMPutMemSignalNBINamespace: - """The NVSHMEM PutMemSignalNBI intrinsics submodule.""" - - def __init__(self): - self.warp = _op_wrapper(_tir_op.nvshmem_putmem_signal_nbi_warp) - self.block = _op_wrapper(_tir_op.nvshmem_putmem_signal_nbi_block) - - def __call__(self, *args, **kwds): - return _op_wrapper(_tir_op.nvshmem_putmem_signal_nbi)(*args, **kwds) - - # __call__ corresponds to nvshmem_putmem_signal_nbi - __tir_call_op_name__ = "nvshmem_putmem_signal_nbi" - - -class NKINamespace: - """The NKI instructions submodule.""" - - def __init__(self): - self.load = _op_wrapper(_tir_op.nki_load) - self.store = _op_wrapper(_tir_op.nki_store) - self.tensor_copy = _op_wrapper(_tir_op.nki_tensor_copy) - self.matmul = _op_wrapper(_tir_op.nki_matmul) - self.activation = _op_wrapper(_tir_op.nki_activation) - self.activation_reduce = _op_wrapper(_tir_op.nki_activation_reduce) - self.reciprocal = _op_wrapper(_tir_op.nki_reciprocal) - self.tensorreduce = _op_wrapper(_tir_op.nki_tensorreduce) - self.tensortensor = _op_wrapper(_tir_op.nki_tensortensor) - self.tensorscalar = _op_wrapper(_tir_op.nki_tensorscalar) - self.tensorscalar_reduce = _op_wrapper(_tir_op.nki_tensorscalar_reduce) - self.scalar_tensor_tensor = _op_wrapper(_tir_op.nki_scalar_tensor_tensor) - self.scalar_tensor_scalar = _op_wrapper(_tir_op.nki_scalar_tensor_scalar) - self.memset = _op_wrapper(_tir_op.nki_memset) - self.identity = _op_wrapper(_tir_op.nki_identity) - self.affine_select = _op_wrapper(_tir_op.nki_affine_select) - - -ptx = PTXNamespace() -cuda = CUDANamespace() -metal = MetalNamespace() webgpu = WebGPUNamespace() -nvshmem = NVSHMEMNamespace() -nki = NKINamespace() # -# Register printer namespace mapping from the builder namespaces -# so that the TVMScript printer emits T.cuda/T.ptx/T.nvshmem/T.nki dotted names. -# This keeps parser and printer consistent using a single registration source. +# Register printer namespace mapping from the builder namespaces so the +# TVMScript printer emits dotted names that match parser namespaces. # -def _register_tir_namespace_printer_names(): +def _register_script_namespace_printer_names(ns_obj, dotted_prefix): def register_printer_name(op_name, script_name): try: ir.Op.get(op_name) @@ -3625,13 +3046,38 @@ def visit(ns_obj, dotted_prefix): for full_op_name in {flat_name, _tir_op._canonical_device_intrin_name(flat_name)}: register_printer_name(full_op_name, script_name) + visit(ns_obj, dotted_prefix) + + +def register_script_namespace(name: str, namespace: object) -> object: + """Register a TVMScript namespace on the TIRx builder facade.""" + globals()[name] = namespace + if "__all__" in globals() and name not in __all__: + __all__.append(name) + + import sys # pylint: disable=import-outside-toplevel + + for module_name in [ + "tvm.tirx.script.builder", + "tvm.tirx.script.parser", + "tvm.tirx.script", + "tvm.script.tirx", + ]: + module = sys.modules.get(module_name) + if module is None: + continue + setattr(module, name, namespace) + module_all = getattr(module, "__all__", None) + if isinstance(module_all, list) and name not in module_all: + module_all.append(name) + + _register_script_namespace_printer_names(namespace, name) + return namespace + + +def _register_tir_namespace_printer_names(): try: - visit(ptx, "ptx") - visit(cuda, "cuda") - visit(metal, "metal") - visit(webgpu, "webgpu") - visit(nvshmem, "nvshmem") - visit(nki, "nki") + _register_script_namespace_printer_names(webgpu, "webgpu") except Exception: # Best-effort registration; avoid import-time hard failure pass @@ -3713,6 +3159,7 @@ def visit(ns_obj, dotted_prefix): tvm_access_ptr = _op_wrapper(_tir_op.tvm_access_ptr) ptr_byte_offset = _op_wrapper(_tir_op.ptr_byte_offset) tvm_throw_last_error = _op_wrapper(_tir_op.tvm_throw_last_error) +print_buffer = _op_wrapper(_tir_op.print_buffer) tvm_stack_alloca = _op_wrapper(_tir_op.tvm_stack_alloca) tvm_stack_make_shape = _op_wrapper(_tir_op.tvm_stack_make_shape) tvm_stack_make_array = _op_wrapper(_tir_op.tvm_stack_make_array) @@ -3739,10 +3186,6 @@ def visit(ns_obj, dotted_prefix): tvm_warp_shuffle_down = _tir_op.tvm_warp_shuffle_down tvm_warp_shuffle_xor = _tir_op.tvm_warp_shuffle_xor tvm_warp_activemask = _tir_op.tvm_warp_activemask -make_filled_simdgroup_matrix = _op_wrapper(_tir_op.make_filled_simdgroup_matrix) -simdgroup_load = _op_wrapper(_tir_op.simdgroup_load) -simdgroup_store = _op_wrapper(_tir_op.simdgroup_store) -simdgroup_multiply_accumulate = _op_wrapper(_tir_op.simdgroup_multiply_accumulate) cooperative_tensor_fill = _op_wrapper(_tir_op.cooperative_tensor_fill) cooperative_tensor_load = _op_wrapper(_tir_op.cooperative_tensor_load) cooperative_tensor_store = _op_wrapper(_tir_op.cooperative_tensor_store) @@ -3759,11 +3202,6 @@ def visit(ns_obj, dotted_prefix): anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked) vscale = _op_wrapper(_tir_op.vscale) ignore_loop_partition = _op_wrapper(_tir_op.ignore_loop_partition) -print_buffer = _op_wrapper(_tir_op.print_buffer) -timer_init_cuda = _op_wrapper(_tir_op.timer_init_cuda) -timer_start_cuda = _op_wrapper(_tir_op.timer_start_cuda) -timer_end_cuda = _op_wrapper(_tir_op.timer_end_cuda) -timer_finalize_cuda = _op_wrapper(_tir_op.timer_finalize_cuda) reinterpret = _dtype_forward(_tir_op.reinterpret) call_extern = _dtype_forward(_tir_op.call_extern) @@ -3771,10 +3209,6 @@ def visit(ns_obj, dotted_prefix): call_llvm_intrin = _dtype_forward(_tir_op.call_llvm_intrin) call_llvm_pure_intrin = _dtype_forward(_tir_op.call_llvm_pure_intrin) call_pure_extern = _dtype_forward(_tir_op.call_pure_extern) -mma_store = _dtype_forward(_tir_op.mma_store) -mma_fill = _dtype_forward(_tir_op.mma_fill) -mma_store_legacy = _dtype_forward(_tir_op.mma_store_legacy) -mma_fill_legacy = _dtype_forward(_tir_op.mma_fill_legacy) vectorlow = _dtype_forward(_tir_op.vectorlow) vectorhigh = _dtype_forward(_tir_op.vectorhigh) vectorcombine = _dtype_forward(_tir_op.vectorcombine) @@ -4040,6 +3474,7 @@ def visit(ns_obj, dotted_prefix): "tvm_access_ptr", "ptr_byte_offset", "tvm_throw_last_error", + "print_buffer", "tvm_stack_alloca", "tvm_stack_make_shape", "tvm_stack_make_array", @@ -4071,18 +3506,10 @@ def visit(ns_obj, dotted_prefix): "tvm_warp_shuffle_down", "tvm_warp_shuffle_xor", "tvm_warp_activemask", - "make_filled_simdgroup_matrix", - "simdgroup_load", - "simdgroup_store", - "simdgroup_multiply_accumulate", "cooperative_tensor_fill", "cooperative_tensor_load", "cooperative_tensor_store", "cooperative_tensor_multiply_accumulate", - "mma_store", - "mma_fill", - "mma_store_legacy", - "mma_fill_legacy", "vectorlow", "vectorhigh", "vectorcombine", @@ -4155,11 +3582,6 @@ def visit(ns_obj, dotted_prefix): "get_active_lane_mask", "call_kernel", "ignore_loop_partition", - "print_buffer", - "timer_init_cuda", - "timer_start_cuda", - "timer_end_cuda", - "timer_finalize_cuda", ] __all__ += [ @@ -4184,16 +3606,12 @@ def visit(ns_obj, dotted_prefix): "cta_id", "cta_id_in_cluster", "cta_id_in_pair", - "cuda", "decl_scalar", "device_entry", "lane_id", "local_scalar", "meta_class", - "metal", - "nki", - "nvshmem", - "ptx", + "register_script_namespace", "scalar_wrapper", "scope_id", "shared_scalar", diff --git a/python/tvm/tirx/transform/__init__.py b/python/tvm/tirx/transform/__init__.py index b0fcb5442da2..6b86a59e2f85 100644 --- a/python/tvm/tirx/transform/__init__.py +++ b/python/tvm/tirx/transform/__init__.py @@ -20,4 +20,3 @@ from .function_pass import prim_func_pass, PrimFuncPass from .transform import * -from . import trn diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py index 297872eae6d1..72a5b96202d2 100644 --- a/python/tvm/tirx/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -503,6 +503,17 @@ def Filter(fcond: Callable): return _ffi_api.Filter(fcond) # type: ignore +def TilePrimitiveDispatch(): + """Lower TIRx tile primitive calls through the active backend dispatch table. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.TilePrimitiveDispatch() # type: ignore + + def LowerTIRx(): """Lower TIR to a lower-level IR. diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc index fa2d7254ee38..8ff1a8b17ecf 100644 --- a/src/arith/const_int_bound.cc +++ b/src/arith/const_int_bound.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -792,12 +793,14 @@ class ConstIntBoundAnalyzer::Impl * topi.math.ceil_log2, and can appear in iteration bounds. */ static ffi::Optional FindCeilLog2Arg(const CastNode* op) { + static const Op& ceil_op = Op::Get("tirx.ceil"); + static const Op& log2_op = Op::Get("tirx.log2"); if (op->dtype.is_int()) { if (auto as_call = op->value.as()) { - if (as_call->op.same_as(Op::Get("tirx.ceil"))) { + if (as_call->op.same_as(ceil_op)) { PrimExpr ceil_arg = as_call->args[0]; if (auto arg_call = ceil_arg.as()) { - if (arg_call->op.same_as(Op::Get("tirx.log2"))) { + if (arg_call->op.same_as(log2_op)) { PrimExpr log_arg = arg_call->args[0]; return log_arg; } diff --git a/src/arith/ir_mutator_with_analyzer.cc b/src/arith/ir_mutator_with_analyzer.cc index 6fbd4c4551f3..8aa821fa453a 100644 --- a/src/arith/ir_mutator_with_analyzer.cc +++ b/src/arith/ir_mutator_with_analyzer.cc @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -215,10 +216,10 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const IfThenElseNode* op) { return constraint_scope_.WithNewScope([&]() -> Stmt { PrimExpr condition = this->VisitExpr(op->condition); PrimExpr real_condition = condition; - static auto op_likely = Op::Get("tirx.likely"); if (auto call = condition.as()) { - if (call->op.same_as(op_likely)) { + static const Op& likely_op = Op::Get("tirx.likely"); + if (call->op.same_as(likely_op)) { real_condition = call->args[0]; } } @@ -287,8 +288,8 @@ Stmt IRMutatorWithAnalyzer::VisitStmt_(const SeqStmtNode* op) { PrimExpr IRMutatorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else - static auto op_if_then_else = Op::Get("tirx.if_then_else"); - if (op->op.same_as(op_if_then_else)) { + static const Op& if_then_else_op = Op::Get("tirx.if_then_else"); + if (op->op.same_as(if_then_else_op)) { PrimExpr cond = this->VisitExpr(op->args[0]); PrimExpr true_value, false_value; constraint_scope_.WithNewScope([&]() { diff --git a/src/arith/ir_visitor_with_analyzer.cc b/src/arith/ir_visitor_with_analyzer.cc index 7c0f2458fa77..269703d1a90f 100644 --- a/src/arith/ir_visitor_with_analyzer.cc +++ b/src/arith/ir_visitor_with_analyzer.cc @@ -22,6 +22,7 @@ */ #include "ir_visitor_with_analyzer.h" +#include #include #include #include @@ -96,8 +97,8 @@ void IRVisitorWithAnalyzer::VisitStmt_(const SeqStmtNode* op) { void IRVisitorWithAnalyzer::VisitExpr_(const CallNode* op) { // add condition context to if_then_else - static auto op_if_then_else = Op::Get("tirx.if_then_else"); - if (op->op.same_as(op_if_then_else)) { + static const Op& if_then_else_op = Op::Get("tirx.if_then_else"); + if (op->op.same_as(if_then_else_op)) { PrimExpr cond = op->args[0]; this->VisitExpr(op->args[0]); constraint_scope_.WithNewScope([&]() { diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 2120aaa1a859..bec509188311 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -2316,7 +2317,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { // the operator overload will eagerly constant fold. return op->args[0] << op->args[1]; } - } else if (op->op.same_as(Op::Get("tirx.ceil"))) { + } + static const Op& ceil_op = Op::Get("tirx.ceil"); + static const Op& log2_op = Op::Get("tirx.log2"); + static const Op& clz_op = Op::Get("tirx.clz"); + if (op->op.same_as(ceil_op)) { PrimExpr ceil_arg = op->args[0]; if (auto arg_int = op->args[0].as()) { return cast(op->dtype, IntImm(arg_int->dtype, arg_int->value)); @@ -2325,7 +2330,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } else if (auto arg_call = ceil_arg.as()) { // ceil(log2(cast(n,"float64"))) is used as the implementation of // topi.math.ceil_log2, and appears in iteration bounds. - if (arg_call->op.same_as(Op::Get("tirx.log2"))) { + if (arg_call->op.same_as(log2_op)) { PrimExpr log_arg = arg_call->args[0]; if (auto as_float = log_arg.as()) { // ceil(log2(n)) can be simplified, and should produce the @@ -2335,7 +2340,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const CallNode* op) { } } } - } else if (op->op.same_as(Op::Get("tirx.clz"))) { + } else if (op->op.same_as(clz_op)) { if (const auto* arg_int = op->args[0].as()) { int bits = arg_int->dtype.bits(); if (arg_int->value == 0) return make_const(op->dtype, bits); diff --git a/src/target/cuda/codegen_cuda.cc b/src/backend/cuda/codegen/codegen_cuda.cc similarity index 96% rename from src/target/cuda/codegen_cuda.cc rename to src/backend/cuda/codegen/codegen_cuda.cc index 8c269fcc7e99..aa2ef63b149a 100644 --- a/src/target/cuda/codegen_cuda.cc +++ b/src/backend/cuda/codegen/codegen_cuda.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include @@ -34,9 +35,9 @@ #include #include -#include "../../runtime/thread_storage_scope.h" -#include "../../tirx/transform/ir_utils.h" -#include "../build_common.h" +#include "../../../runtime/thread_storage_scope.h" +#include "../../../target/build_common.h" +#include "../../../tirx/transform/ir_utils.h" #include "cuda_fallback_module.h" #include "literal/cuda_half_t.h" #include "literal/cuda_int8_t.h" @@ -953,7 +954,25 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { } } - if (op->op.same_as(builtin::tvm_fill_fragment())) { + static const Op& tvm_fill_fragment_op = Op::Get("tirx.tvm_fill_fragment"); + static const Op& tvm_load_matrix_sync_op = Op::Get("tirx.tvm_load_matrix_sync"); + static const Op& tvm_store_matrix_sync_op = Op::Get("tirx.tvm_store_matrix_sync"); + static const Op& tvm_mma_sync_op = Op::Get("tirx.tvm_mma_sync"); + static const Op& tvm_bmma_sync_op = Op::Get("tirx.tvm_bmma_sync"); + static const Op& ptx_mma_op = Op::Get("tirx.ptx_mma"); + static const Op& ptx_mma_sp_op = Op::Get("tirx.ptx_mma_sp"); + static const Op& mma_store_op = Op::Get("tirx.mma_store"); + static const Op& mma_fill_op = Op::Get("tirx.mma_fill"); + static const Op& ptx_mma_legacy_op = Op::Get("tirx.ptx_mma_legacy"); + static const Op& ptx_ldmatrix_legacy_op = Op::Get("tirx.ptx_ldmatrix_legacy"); + static const Op& mma_store_legacy_op = Op::Get("tirx.mma_store_legacy"); + static const Op& mma_fill_legacy_op = Op::Get("tirx.mma_fill_legacy"); + static const Op& ptx_cp_async_bulk_op = Op::Get("tirx.ptx_cp_async_bulk"); + static const Op& ptx_cp_async_mbarrier_arrive_op = Op::Get("tirx.ptx_cp_async_mbarrier_arrive"); + static const Op& ptx_ldg32_op = Op::Get("tirx.ptx.ldg32"); + static const Op& cuda_func_call_op = Op::Get("tirx.cuda_func_call"); + + if (op->op.same_as(tvm_fill_fragment_op)) { codegen_tags_.insert("mma"); TVM_FFI_ICHECK_EQ(op->args.size(), 6U); os << "nvcuda::wmma::fill_fragment("; @@ -963,7 +982,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "], "; this->PrintExpr(op->args[5], os); os << ")"; - } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { + } else if (op->op.same_as(tvm_load_matrix_sync_op)) { codegen_tags_.insert("mma"); TVM_FFI_ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::load_matrix_sync("; @@ -975,7 +994,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << ", "; this->PrintExpr(op->args[6], os); os << ")"; - } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { + } else if (op->op.same_as(tvm_store_matrix_sync_op)) { codegen_tags_.insert("mma"); TVM_FFI_ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::store_matrix_sync("; @@ -992,7 +1011,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { TVM_FFI_THROW(InternalError) << "Invalid parameters"; } os << ")"; - } else if (op->op.same_as(builtin::tvm_mma_sync())) { + } else if (op->op.same_as(tvm_mma_sync_op)) { codegen_tags_.insert("mma"); TVM_FFI_ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::mma_sync("; @@ -1002,7 +1021,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } - } else if (op->op.same_as(builtin::tvm_bmma_sync())) { + } else if (op->op.same_as(tvm_bmma_sync_op)) { codegen_tags_.insert("mma"); TVM_FFI_ICHECK_EQ(op->args.size(), 8U); os << "nvcuda::wmma::bmma_sync("; @@ -1012,7 +1031,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->PrintExpr(op->args[i * 2 + 1], os); os << "]" << ((i < 3) ? ", " : ")"); } - } else if (IsOp(op, builtin::ptx_mma(), "tirx.ptx.mma")) { + } else if (IsOp(op, ptx_mma_op, "tirx.ptx.mma")) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col // arg 2: B layout: row/col @@ -1047,7 +1066,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); this->stream << asm_code; - } else if (IsOp(op, builtin::ptx_mma_sp(), "tirx.ptx.mma_sp")) { + } else if (IsOp(op, ptx_mma_sp_op, "tirx.ptx.mma_sp")) { // arg 0: shape: mXnXkX // arg 1: A layout: row/col // arg 2: B layout: row/col @@ -1085,7 +1104,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_offset, b_ref, b_offset, c_ref, c_offset, metadata, metadata_offset, sparse_selector, "", true, saturate); this->stream << asm_code; - } else if (op->op.same_as(builtin::mma_store())) { + } else if (op->op.same_as(mma_store_op)) { int m = Downcast(op->args[0])->value; int n = Downcast(op->args[1])->value; std::string dst = this->PrintExpr(op->args[2]); @@ -1135,7 +1154,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { << " + local_id];\n"; os << "}\n"; - } else if (op->op.same_as(builtin::mma_fill())) { + } else if (op->op.same_as(mma_fill_op)) { std::string num_elem = this->PrintExpr(op->args[0]); std::string dst = this->PrintExpr(op->args[1]); std::string dst_offset = this->PrintExpr(op->args[2]); @@ -1143,7 +1162,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; os << dst << "[" << dst_offset << " + i] = 0.0;"; os << "}\n"; - } else if (IsOp(op, tvm::tirx::builtin::ptx_mma_legacy(), "tirx.ptx.mma_legacy")) { + } else if (IsOp(op, ptx_mma_legacy_op, "tirx.ptx.mma_legacy")) { // args: shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, // a_ptr_var, a_offset, b_ptr_var, b_offset, // c_ptr_var, c_offset, saturate, [bit_op] @@ -1166,7 +1185,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->stream << PrintMMAAssembly(shape, A_layout, B_layout, A_dtype, B_dtype, C_dtype, a_ref, a_bias, b_ref, b_bias, c_ref, c_bias, "", "", "", bit_op, false, saturate); - } else if (IsOp(op, tvm::tirx::builtin::ptx_ldmatrix_legacy(), "tirx.ptx.ldmatrix_legacy")) { + } else if (IsOp(op, ptx_ldmatrix_legacy_op, "tirx.ptx.ldmatrix_legacy")) { // args: trans, num, type, local_ptr_var, local_offset, smem_ptr_var, smem_offset codegen_tags_.insert("mma"); TVM_FFI_ICHECK_EQ(op->args.size(), 7U); @@ -1194,7 +1213,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { this->stream << PrintLoadMatrixAssembly(trans, num, type_str, local_ptr, local_offset, smem_ptr, smem_offset); } - } else if (op->op.same_as(tvm::tirx::builtin::mma_store_legacy())) { + } else if (op->op.same_as(mma_store_legacy_op)) { // args: m, n, dst_ptr, src_ptr_var, src_offset, dst_stride // (dst_ptr is typically an access_ptr Call that already encodes // dst.elem_offset and the global pointer cast.) @@ -1235,7 +1254,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << dst << "[" << this->PrintExpr(dst_ind) << "] = " << src << "[" << src_offset << " + local_id];\n"; os << "}\n"; - } else if (op->op.same_as(tvm::tirx::builtin::mma_fill_legacy())) { + } else if (op->op.same_as(mma_fill_legacy_op)) { // args: local_size, local_ptr_var, offset std::string num_elem = this->PrintExpr(op->args[0]); std::string dst = this->PrintExpr(op->args[1]); @@ -1243,7 +1262,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "for (int i = 0; i < " << num_elem << "; ++i) {\n"; os << dst << "[" << dst_offset << " + i] = 0.0;"; os << "}\n"; - } else if (IsOp(op, builtin::ptx_cp_async_bulk(), "tirx.ptx.cp_async_bulk")) { + } else if (IsOp(op, ptx_cp_async_bulk_op, "tirx.ptx.cp_async_bulk")) { codegen_tags_.insert("cast_smem_ptr_to_int"); std::string dst = this->PrintExpr(op->args[0]); std::string dst_offset = this->PrintExpr(op->args[1]); @@ -1257,8 +1276,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string barrier_arr = barrier_name_ + "_" + std::to_string(barrier_arr_id); std::string barrier = barrier_arr + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintCpAsyncBulkAsm(dst, dst_offset, src, src_offset, size, barrier); - } else if (IsOp(op, builtin::ptx_cp_async_mbarrier_arrive(), - "tirx.ptx.cp_async_mbarrier_arrive")) { + } else if (IsOp(op, ptx_cp_async_mbarrier_arrive_op, "tirx.ptx.cp_async_mbarrier_arrive")) { codegen_tags_.insert("cast_smem_ptr_to_int"); int barrier_arr_id = Downcast(op->args[0])->value; int barrier_id = Downcast(op->args[1])->value; @@ -1268,7 +1286,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { std::string barrier_arr = barrier_name_ + "_" + std::to_string(barrier_arr_id); std::string barrier = barrier_arr + "[" + std::to_string(barrier_id) + "]"; this->stream << PrintCpAsyncBarrierAsm(barrier); - } else if (IsOp(op, builtin::ptx_ldg32(), "tirx.ptx.ldg32")) { + } else if (IsOp(op, ptx_ldg32_op, "tirx.ptx.ldg32")) { /* asm volatile ( "{.reg .pred p;\n" @@ -1530,7 +1548,7 @@ void CodeGenCUDA::VisitExpr_(const CallNode* op, std::ostream& os) { os << "}\n" << "// print_buffer ends\n"; - } else if (op->op.same_as(builtin::cuda_func_call()) || + } else if (op->op.same_as(cuda_func_call_op) || (op->op.as() && op->op.as().value()->name == "tirx.cuda.func_call")) { print_cuda_func_call(op, os); } else if (op->op.same_as(builtin::thread_return())) { @@ -1554,7 +1572,8 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { TVM_FFI_ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; this->VisitStmt(op->body); - auto commit_group = Call(DataType::Void(), builtin::ptx_cp_async_commit_group(), {}); + static const Op& ptx_cp_async_commit_group_op = Op::Get("tirx.ptx_cp_async_commit_group"); + auto commit_group = Call(DataType::Void(), ptx_cp_async_commit_group_op, {}); this->PrintIndent(); this->VisitExpr(commit_group, this->stream); this->stream << ";\n"; @@ -1565,7 +1584,8 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { TVM_FFI_ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0."; auto wait_cnt = wait_attrs.second; - auto wait_group = Call(DataType::Void(), builtin::ptx_cp_async_wait_group(), {wait_cnt}); + static const Op& ptx_cp_async_wait_group_op = Op::Get("tirx.ptx_cp_async_wait_group"); + auto wait_group = Call(DataType::Void(), ptx_cp_async_wait_group_op, {wait_cnt}); this->PrintIndent(); this->VisitExpr(wait_group, this->stream); this->stream << ";\n"; @@ -2121,10 +2141,16 @@ ffi::Module BuildCUDA(IRModule mod, Target target) { ffi::Bytes(code.data(), code.size()), ffi::String("cuda"), ExtractFuncInfo(mod), source_map); } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterCudaCodegen() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.cuda", BuildCUDA); } +TVM_FFI_STATIC_INIT_BLOCK() { RegisterCudaCodegen(); } + } // namespace codegen } // namespace tvm diff --git a/src/target/cuda/codegen_cuda.h b/src/backend/cuda/codegen/codegen_cuda.h similarity index 99% rename from src/target/cuda/codegen_cuda.h rename to src/backend/cuda/codegen/codegen_cuda.h index 91d640ee5d78..92ca3cab34a4 100644 --- a/src/target/cuda/codegen_cuda.h +++ b/src/backend/cuda/codegen/codegen_cuda.h @@ -31,7 +31,7 @@ #include #include -#include "../source/codegen_c.h" +#include "../../../target/source/codegen_c.h" namespace tvm { namespace codegen { diff --git a/src/target/cuda/cuda_fallback_module.cc b/src/backend/cuda/codegen/cuda_fallback_module.cc similarity index 99% rename from src/target/cuda/cuda_fallback_module.cc rename to src/backend/cuda/codegen/cuda_fallback_module.cc index a0da7fea1eb5..bdf60f362aa6 100644 --- a/src/target/cuda/cuda_fallback_module.cc +++ b/src/backend/cuda/codegen/cuda_fallback_module.cc @@ -33,7 +33,7 @@ #include #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" namespace tvm { namespace target { diff --git a/src/target/cuda/cuda_fallback_module.h b/src/backend/cuda/codegen/cuda_fallback_module.h similarity index 98% rename from src/target/cuda/cuda_fallback_module.h rename to src/backend/cuda/codegen/cuda_fallback_module.h index f328e09c8bdb..7b4bce77d47d 100644 --- a/src/target/cuda/cuda_fallback_module.h +++ b/src/backend/cuda/codegen/cuda_fallback_module.h @@ -38,8 +38,8 @@ #include -#include "../../runtime/metadata.h" -#include "../../support/env.h" +#include "../../../runtime/metadata.h" +#include "../../../support/env.h" namespace tvm { namespace target { diff --git a/src/target/cuda/intrin_rule_cuda.cc b/src/backend/cuda/codegen/intrin_rule_cuda.cc similarity index 89% rename from src/target/cuda/intrin_rule_cuda.cc rename to src/backend/cuda/codegen/intrin_rule_cuda.cc index a9aadf1aeed8..6a799aea9458 100644 --- a/src/target/cuda/intrin_rule_cuda.cc +++ b/src/backend/cuda/codegen/intrin_rule_cuda.cc @@ -21,10 +21,11 @@ * \file intrin_rule_cuda.cc * \brief CUDA intrinsic rules. */ +#include #include #include -#include "../intrin_rule.h" +#include "../../../target/intrin_rule.h" namespace tvm { namespace codegen { @@ -127,21 +128,26 @@ struct CUDAPopcount { struct CUDAWarpIntrinsic { const Op operator()(DataType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { - return Op::Get("tirx.cuda.__shfl_sync"); + static const Op& cuda_shfl_sync_op = Op::Get("tirx.cuda.__shfl_sync"); + return cuda_shfl_sync_op; } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { - return Op::Get("tirx.cuda.__shfl_up_sync"); + static const Op& cuda_shfl_up_sync_op = Op::Get("tirx.cuda.__shfl_up_sync"); + return cuda_shfl_up_sync_op; } else if (orig_op.same_as(builtin::tvm_warp_shuffle_down())) { - return Op::Get("tirx.cuda.__shfl_down_sync"); + static const Op& cuda_shfl_down_sync_op = Op::Get("tirx.cuda.__shfl_down_sync"); + return cuda_shfl_down_sync_op; } else { TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_xor())); - return Op::Get("tirx.cuda.__shfl_xor_sync"); + static const Op& cuda_shfl_xor_sync_op = Op::Get("tirx.cuda.__shfl_xor_sync"); + return cuda_shfl_xor_sync_op; } } }; static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { const CallNode* call = e.as(); - return Call(call->dtype, Op::Get("tirx.cuda.__activemask"), call->args); + static const Op& cuda_active_mask_op = Op::Get("tirx.cuda.__activemask"); + return Call(call->dtype, cuda_active_mask_op, call->args); } template @@ -153,6 +159,8 @@ static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { return Call(call->dtype, T()(call->dtype, Downcast(call->op)), cuda_args); } +void RegisterCudaIntrinRules() { + // clang-format off TVM_REGISTER_OP("tirx.clz") .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); @@ -203,7 +211,8 @@ TVM_REGISTER_OP("tirx.log10") TVM_REGISTER_OP("tirx.tan") // Now the fast math version of tan and the default version of tan are same. - .set_attr("cuda.fastmath.FLowerIntrinsic", DispatchPureExtern) + .set_attr("cuda.fastmath.FLowerIntrinsic", + DispatchPureExtern) .set_attr("cuda.FLowerIntrinsic", DispatchPureExtern); TVM_REGISTER_OP("tirx.cos") @@ -267,7 +276,8 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_sync") .add_argument("lane", "Expr", "The source thread id.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) - .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), + 10) .set_attr("TScriptPrinterName", ffi::String("cuda.__shfl_sync"), 10) .set_attr("TGlobalSymbol", "__shfl_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) @@ -280,7 +290,8 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_up_sync") .add_argument("delta", "Expr", "The source lane id offset to be added.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) - .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), + 10) .set_attr("TScriptPrinterName", ffi::String("cuda.__shfl_up_sync"), 10) .set_attr("TGlobalSymbol", "__shfl_up_sync") @@ -294,9 +305,10 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_down_sync") .add_argument("delta", "Expr", "The source lane id offset to be subtracted.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) - .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) - .set_attr("TScriptPrinterName", ffi::String("cuda.__shfl_down_sync"), - 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), + 10) + .set_attr("TScriptPrinterName", + ffi::String("cuda.__shfl_down_sync"), 10) .set_attr("TGlobalSymbol", "__shfl_down_sync") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) .set_attr("cuda.need_warp_shuffle", true); @@ -308,7 +320,8 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_xor_sync") .add_argument("lane_mask", "Expr", "The lane mask.") .add_argument("width", "Expr", "The warp thread width, must be a power of 2.") .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) - .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), + 10) .set_attr("TScriptPrinterName", ffi::String("cuda.__shfl_xor_sync"), 10) .set_attr("TGlobalSymbol", "__shfl_xor_sync") @@ -318,11 +331,17 @@ TVM_REGISTER_OP("tirx.cuda.__shfl_xor_sync") TVM_REGISTER_OP("tirx.cuda.__activemask") .set_num_inputs(0) .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) - .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), 10) - .set_attr("TScriptPrinterName", ffi::String("cuda.__activemask"), 10) + .set_attr("TDeviceIntrinsicNamespace", ffi::String("cuda"), + 10) + .set_attr("TScriptPrinterName", ffi::String("cuda.__activemask"), + 10) .set_attr("TGlobalSymbol", "__activemask") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)) .set_attr("cuda.need_warp_shuffle", true); + // clang-format on +} + +TVM_FFI_STATIC_INIT_BLOCK() { RegisterCudaIntrinRules(); } } // namespace intrin } // namespace codegen diff --git a/src/target/cuda/literal/cuda_half_t.h b/src/backend/cuda/codegen/literal/cuda_half_t.h similarity index 100% rename from src/target/cuda/literal/cuda_half_t.h rename to src/backend/cuda/codegen/literal/cuda_half_t.h diff --git a/src/target/cuda/literal/cuda_int8_t.h b/src/backend/cuda/codegen/literal/cuda_int8_t.h similarity index 100% rename from src/target/cuda/literal/cuda_int8_t.h rename to src/backend/cuda/codegen/literal/cuda_int8_t.h diff --git a/src/target/cuda/llvm/codegen_nvptx.cc b/src/backend/cuda/codegen/llvm/codegen_nvptx.cc similarity index 97% rename from src/target/cuda/llvm/codegen_nvptx.cc rename to src/backend/cuda/codegen/llvm/codegen_nvptx.cc index 64dea2991539..e523e2b22aab 100644 --- a/src/target/cuda/llvm/codegen_nvptx.cc +++ b/src/backend/cuda/codegen/llvm/codegen_nvptx.cc @@ -52,9 +52,9 @@ #include #include -#include "../../build_common.h" -#include "../../llvm/codegen_llvm.h" -#include "../../llvm/llvm_instance.h" +#include "../../../../target/build_common.h" +#include "../../../../target/llvm/codegen_llvm.h" +#include "../../../../target/llvm/llvm_instance.h" #include "../cuda_fallback_module.h" namespace tvm { @@ -354,7 +354,11 @@ ffi::Module BuildNVPTX(IRModule mod, Target target) { ffi::String("ptx"), ExtractFuncInfo(mod), source_map); } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterNVPTXCodegen() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.nvptx", BuildNVPTX) @@ -363,6 +367,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { }); } +TVM_FFI_STATIC_INIT_BLOCK() { RegisterNVPTXCodegen(); } + } // namespace codegen } // namespace tvm diff --git a/src/target/cuda/ptx.cc b/src/backend/cuda/codegen/ptx.cc similarity index 97% rename from src/target/cuda/ptx.cc rename to src/backend/cuda/codegen/ptx.cc index 66a072e2099f..0d7f08b0d80a 100644 --- a/src/target/cuda/ptx.cc +++ b/src/backend/cuda/codegen/ptx.cc @@ -23,13 +23,15 @@ #include "ptx.h" +#include + #include #include #include #include #include -#include "../../support/utils.h" +#include "../../../support/utils.h" namespace tvm { namespace codegen { @@ -149,25 +151,27 @@ inline DataType DTypeFromString(const std::string str) { } } -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tirx.intrinsics.cuda.PTXDTypeFromString", - [](const std::string& str) -> int { return static_cast(DTypeFromString(str)); }); -} - /*! * \brief Get the string representation of given PTX data type. */ inline std::string DTypeToString(DataType dtype) { return dtype_str[static_cast(dtype)]; } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterCudaPTXHelpers() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def( - "tirx.intrinsics.cuda.PTXDTypeToString", - [](const int dtype) -> std::string { return DTypeToString(static_cast(dtype)); }); + refl::GlobalDef() + .def("tirx.intrinsics.cuda.PTXDTypeFromString", + [](const std::string& str) -> int { return static_cast(DTypeFromString(str)); }) + .def("tirx.intrinsics.cuda.PTXDTypeToString", [](const int dtype) -> std::string { + return DTypeToString(static_cast(dtype)); + }); } +TVM_FFI_STATIC_INIT_BLOCK() { RegisterCudaPTXHelpers(); } + /*! * \brief Get the number of bits of given PTX data type. */ diff --git a/src/target/cuda/ptx.h b/src/backend/cuda/codegen/ptx.h similarity index 100% rename from src/target/cuda/ptx.h rename to src/backend/cuda/codegen/ptx.h diff --git a/src/backend/cuda/codegen/register.cc b/src/backend/cuda/codegen/register.cc new file mode 100644 index 000000000000..d1fa868183c2 --- /dev/null +++ b/src/backend/cuda/codegen/register.cc @@ -0,0 +1,143 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief CUDA compiler backend static registration. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace tvm { + +namespace backend { +namespace cuda { + +bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, ffi::Any* val) { + using runtime::DeviceAPI; + DeviceAPI* api = DeviceAPI::Get(device, true); + if (api == nullptr) { + return false; + } + api->GetAttr(device, runtime::kExist, val); + int exists = val->cast(); + if (!exists) { + return false; + } + DeviceAPI::Get(device)->GetAttr(device, flag, val); + return true; +} + +void CheckOrSetAttr(ffi::Map* attrs, const ffi::String& name, + const ffi::String& value) { + auto iter = attrs->find(name); + if (iter == attrs->end()) { + attrs->Set(name, value); + } else { + auto str = (*iter).second.try_cast(); + TVM_FFI_CHECK(str && str.value() == value, ValueError) + << "Expects \"" << name << "\" to be \"" << value << "\", but gets: " << (*iter).second; + } +} + +bool StartsWith(const ffi::String& str, const char* prefix) { + return std::string(str).rfind(prefix, 0) == 0; +} + +ffi::Map UpdateCUDAAttrs(ffi::Map target) { + if (target.count("arch")) { + ffi::String archStr = Downcast(target.at("arch")); + TVM_FFI_CHECK(StartsWith(archStr, "sm_"), ValueError) + << "CUDA target gets an invalid CUDA arch: -arch=" << archStr; + } else { + int archInt; + ffi::Any version; + if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) { + LOG(WARNING) << "Unable to detect CUDA version, default to \"-arch=sm_50\" instead"; + archInt = 50; + } else { + archInt = std::stod(version.cast()) * 10 + 0.1; + } + if (archInt >= 90) { + target.Set("arch", ffi::String("sm_") + std::to_string(archInt) + "a"); + } else { + target.Set("arch", ffi::String("sm_") + std::to_string(archInt)); + } + } + return target; +} + +ffi::Map UpdateNVPTXAttrs(ffi::Map target) { + CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda"); + if (target.count("mcpu")) { + ffi::String mcpu = Downcast(target.at("mcpu")); + TVM_FFI_CHECK(StartsWith(mcpu, "sm_"), ValueError) + << "NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu; + } else { + int arch; + ffi::Any version; + if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) { + LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_50\" instead"; + arch = 50; + } else { + arch = std::stod(version.cast()) * 10 + 0.1; + } + target.Set("mcpu", ffi::String("sm_") + std::to_string(arch)); + } + return target; +} + +void RegisterTargetKinds() { + namespace refl = tvm::ffi::reflection; + + TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) + .add_attr_option("mcpu") + .add_attr_option("arch") + .add_attr_option("max_shared_memory_per_block") + .add_attr_option("max_threads_per_block") + .add_attr_option("thread_warp_size", refl::DefaultValue(32)) + .add_attr_option("registers_per_block") + .add_attr_option("l2_cache_size_bytes") + .add_attr_option("max_num_threads", + refl::DefaultValue(1024)) // TODO(@zxybazh): deprecate it + .set_default_keys({"cuda", "gpu"}) + .set_target_canonicalizer(UpdateCUDAAttrs); + + TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option("max_num_threads", refl::DefaultValue(1024)) + .add_attr_option("thread_warp_size", refl::DefaultValue(32)) + .set_default_keys({"cuda", "gpu"}) + .set_target_canonicalizer(UpdateNVPTXAttrs); +} + +} // namespace cuda +} // namespace backend +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { tvm::backend::cuda::RegisterTargetKinds(); } diff --git a/src/backend/cuda/op/register.cc b/src/backend/cuda/op/register.cc new file mode 100644 index 000000000000..51a7edbfcf4e --- /dev/null +++ b/src/backend/cuda/op/register.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief CUDA backend op static registration. + */ +#include +#include + +namespace tvm { +namespace tirx { +namespace builtin { +void RegisterCudaTargetBuiltins(); +} // namespace builtin +} // namespace tirx +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { tvm::tirx::builtin::RegisterCudaTargetBuiltins(); } diff --git a/src/tirx/op/target_builtin/cuda.cc b/src/backend/cuda/op/target_builtin.cc similarity index 94% rename from src/tirx/op/target_builtin/cuda.cc rename to src/backend/cuda/op/target_builtin.cc index 91a84dbda32f..9ebaaf590343 100644 --- a/src/tirx/op/target_builtin/cuda.cc +++ b/src/backend/cuda/op/target_builtin.cc @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file @@ -19,7 +18,7 @@ */ /*! - * \file tir/op/target_builtin/cuda.cc + * \file backend/cuda/op/target_builtin.cc * * builtin intrinsic operators specific to CUDA target. */ @@ -33,15 +32,25 @@ namespace tvm { namespace tirx { namespace builtin { -#define TIRX_DEFINE_BUILTIN_FUNC(OpName) \ - const Op& OpName() { \ - static const Op& op = Op::Get("tirx." #OpName); \ - return op; \ - } \ - TVM_TIRX_REGISTER_OP(#OpName) +#define TIRX_DEFINE_BUILTIN_FUNC(OpName) \ + OpRegEntry::RegisterOrGet("tirx." #OpName) \ + .set_name() \ + .set_attr("TScriptPrinterName", ffi::String(#OpName), 1) \ + .set_attr("TIRxOpCategory", ffi::String("builtin"), /*plevel=*/1) + +namespace { +void RegisterDeviceIntrinsicAliases(); +} + +void RegisterCudaTargetBuiltins() { + // clang-format off +static bool registered = false; +if (registered) return; +registered = true; TIRX_DEFINE_BUILTIN_FUNC(tvm_load_matrix_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kReadState)); + .set_attr("TCallEffectKind", + static_cast(CallEffectKind::kReadState)); TIRX_DEFINE_BUILTIN_FUNC(tvm_mma_sync) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -81,12 +90,8 @@ TIRX_DEFINE_BUILTIN_FUNC(mma_store_legacy) TIRX_DEFINE_BUILTIN_FUNC(mma_fill_legacy) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -const Op& ptx_ldg32() { - static const Op& op = Op::Get("tirx.ptx.ldg32"); - return op; -} - -TVM_REGISTER_OP("tirx.ptx.ldg32") +OpRegEntry::RegisterOrGet("tirx.ptx.ldg32") + .set_name() .set_num_inputs(4) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)) .set_attr("TScriptPrinterName", ffi::String("ptx.ldg32"), 20) @@ -184,19 +189,16 @@ TIRX_DEFINE_BUILTIN_FUNC(ptx_elect_sync) TIRX_DEFINE_BUILTIN_FUNC(ptx_fence_mbarrier_init) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -const Op& ptx_fetch_register() { - static const Op& op = Op::Get("tirx.ptx.fetch_register"); - return op; -} - -TVM_REGISTER_OP("tirx.ptx.fetch_register") +OpRegEntry::RegisterOrGet("tirx.ptx.fetch_register") + .set_name() .set_num_inputs(-1) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)) .set_attr("TIRxOpCategory", ffi::String("device_intrin")) .set_attr("TDeviceIntrinsicNamespace", ffi::String("ptx")) .set_attr("TScriptPrinterName", ffi::String("ptx.fetch_register")); -TVM_REGISTER_OP("tirx.ptx_fetch_register") +OpRegEntry::RegisterOrGet("tirx.ptx_fetch_register") + .set_name() .set_num_inputs(-1) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)) .set_attr("TIRxOpCategory", ffi::String("device_intrin")) @@ -314,6 +316,18 @@ TIRX_DEFINE_BUILTIN_FUNC(ptx_map_shared_rank) TIRX_DEFINE_BUILTIN_FUNC(cuda_func_call) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); +TIRX_DEFINE_BUILTIN_FUNC(timer_init_cuda) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + +TIRX_DEFINE_BUILTIN_FUNC(timer_start_cuda) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + +TIRX_DEFINE_BUILTIN_FUNC(timer_end_cuda) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + +TIRX_DEFINE_BUILTIN_FUNC(timer_finalize_cuda) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + TIRX_DEFINE_BUILTIN_FUNC(nvshmem_my_pe) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -362,6 +376,10 @@ TIRX_DEFINE_BUILTIN_FUNC(nvshmem_fence) TIRX_DEFINE_BUILTIN_FUNC(nvshmem_barrier_all) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); +RegisterDeviceIntrinsicAliases(); + // clang-format on +} + namespace { struct DeviceIntrinsicRegistration { @@ -567,17 +585,18 @@ const DeviceIntrinsicRegistration kDeviceIntrinsics[] = { TIRX_DEVICE_INTRIN_ALIAS(ptx_wgmma_wait_group, ptx, kOpaque), }; -const bool kDeviceIntrinsicAliasesRegistered = []() { +void RegisterDeviceIntrinsicAliases() { for (const auto& reg : kDeviceIntrinsics) { RegisterDeviceIntrinsic(reg); } - return true; -}(); +} #undef TIRX_DEVICE_INTRIN_ALIAS } // namespace +#undef TIRX_DEFINE_BUILTIN_FUNC + } // namespace builtin } // namespace tirx } // namespace tvm diff --git a/src/runtime/cuda/cuda_common.h b/src/backend/cuda/runtime/cuda_common.h similarity index 98% rename from src/runtime/cuda/cuda_common.h rename to src/backend/cuda/runtime/cuda_common.h index 7fe2e0d1672b..183a2e870229 100644 --- a/src/runtime/cuda/cuda_common.h +++ b/src/backend/cuda/runtime/cuda_common.h @@ -29,7 +29,7 @@ #include -#include "../workspace_pool.h" +#include "../../../runtime/workspace_pool.h" namespace tvm { namespace runtime { diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/backend/cuda/runtime/cuda_device_api.cc similarity index 100% rename from src/runtime/cuda/cuda_device_api.cc rename to src/backend/cuda/runtime/cuda_device_api.cc diff --git a/src/runtime/cuda/cuda_module.cc b/src/backend/cuda/runtime/cuda_module.cc similarity index 98% rename from src/runtime/cuda/cuda_module.cc rename to src/backend/cuda/runtime/cuda_module.cc index 3f182afb8245..8ea734da6906 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/backend/cuda/runtime/cuda_module.cc @@ -37,10 +37,10 @@ #include #include -#include "../../support/bytes_io.h" -#include "../metadata.h" -#include "../pack_args.h" -#include "../thread_storage_scope.h" +#include "../../../runtime/metadata.h" +#include "../../../runtime/pack_args.h" +#include "../../../runtime/thread_storage_scope.h" +#include "../../../support/bytes_io.h" #include "cuda_common.h" namespace tvm { diff --git a/src/runtime/cuda/l2_cache_flush.cc b/src/backend/cuda/runtime/l2_cache_flush.cc similarity index 96% rename from src/runtime/cuda/l2_cache_flush.cc rename to src/backend/cuda/runtime/l2_cache_flush.cc index 5a1d4da0e70a..ca6987733c06 100644 --- a/src/runtime/cuda/l2_cache_flush.cc +++ b/src/backend/cuda/runtime/l2_cache_flush.cc @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -#include "../../../3rdparty/nvbench/l2_cache_flush.h" +#include "../../../../3rdparty/nvbench/l2_cache_flush.h" #include #include diff --git a/src/runtime/vm/cuda/cuda_graph_builtin.cc b/src/backend/cuda/runtime/vm/cuda_graph_builtin.cc similarity index 99% rename from src/runtime/vm/cuda/cuda_graph_builtin.cc rename to src/backend/cuda/runtime/vm/cuda_graph_builtin.cc index f9438bb487e6..6d60042b19b7 100644 --- a/src/runtime/vm/cuda/cuda_graph_builtin.cc +++ b/src/backend/cuda/runtime/vm/cuda_graph_builtin.cc @@ -28,8 +28,8 @@ #include #include -#include "../../../support/utils.h" -#include "../../cuda/cuda_common.h" +#include "../../../../support/utils.h" +#include "../cuda_common.h" namespace tvm { namespace runtime { namespace vm { diff --git a/src/target/hexagon/hexagon_fallback_module.cc b/src/backend/hexagon/codegen/hexagon_fallback_module.cc similarity index 97% rename from src/target/hexagon/hexagon_fallback_module.cc rename to src/backend/hexagon/codegen/hexagon_fallback_module.cc index 3b82291fdd3d..5d87de55ee96 100644 --- a/src/target/hexagon/hexagon_fallback_module.cc +++ b/src/backend/hexagon/codegen/hexagon_fallback_module.cc @@ -33,7 +33,7 @@ #include #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" namespace tvm { namespace target { @@ -64,7 +64,7 @@ class HexagonFallbackModuleNode : public ffi::ModuleObj { ffi::Bytes SaveToBytes() const final { // NOTE: serialization format MUST remain byte-identical to - // HexagonModuleNode::SaveToBytes in src/runtime/hexagon/hexagon_module.cc + // HexagonModuleNode::SaveToBytes in src/backend/hexagon/runtime/hexagon_module.cc // (the source of truth). Both produce a kind="hexagon" artifact that // the loader (ffi.Module.load_from_bytes.hexagon, registered only when // USE_HEXAGON=ON) deserializes. If the real impl's format changes, diff --git a/src/target/hexagon/hexagon_fallback_module.h b/src/backend/hexagon/codegen/hexagon_fallback_module.h similarity index 98% rename from src/target/hexagon/hexagon_fallback_module.h rename to src/backend/hexagon/codegen/hexagon_fallback_module.h index 3a42ac5b8bd1..f2d134b88fcf 100644 --- a/src/target/hexagon/hexagon_fallback_module.h +++ b/src/backend/hexagon/codegen/hexagon_fallback_module.h @@ -39,8 +39,8 @@ #include -#include "../../runtime/metadata.h" -#include "../../support/env.h" +#include "../../../runtime/metadata.h" +#include "../../../support/env.h" namespace tvm { namespace target { diff --git a/src/target/hexagon/llvm/codegen_hexagon.cc b/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc similarity index 98% rename from src/target/hexagon/llvm/codegen_hexagon.cc rename to src/backend/hexagon/codegen/llvm/codegen_hexagon.cc index a5503d209ba7..184aaedcca90 100644 --- a/src/target/hexagon/llvm/codegen_hexagon.cc +++ b/src/backend/hexagon/codegen/llvm/codegen_hexagon.cc @@ -56,11 +56,11 @@ #include #include -#include "../../../runtime/file_utils.h" -#include "../../../runtime/metadata.h" -#include "../../build_common.h" -#include "../../llvm/codegen_cpu.h" -#include "../../llvm/llvm_instance.h" +#include "../../../../runtime/file_utils.h" +#include "../../../../runtime/metadata.h" +#include "../../../../target/build_common.h" +#include "../../../../target/llvm/codegen_cpu.h" +#include "../../../../target/llvm/llvm_instance.h" #include "../hexagon_fallback_module.h" namespace tvm { @@ -589,7 +589,11 @@ ffi::Module BuildHexagon(IRModule mod, Target target) { ExtractFuncInfo(mod), std::move(source)); } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterHexagonCodegen() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.hexagon", BuildHexagon) diff --git a/src/target/hexagon/llvm/intrin_rule_hexagon.cc b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc similarity index 89% rename from src/target/hexagon/llvm/intrin_rule_hexagon.cc rename to src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc index 23a4e6a52a14..0a4ca893b631 100644 --- a/src/target/hexagon/llvm/intrin_rule_hexagon.cc +++ b/src/backend/hexagon/codegen/llvm/intrin_rule_hexagon.cc @@ -23,10 +23,9 @@ #include #include -#include "../../llvm/intrin_rule_llvm.h" +#include "../../../../target/llvm/intrin_rule_llvm.h" -#define TVM_REGISTER_QHL_OP_FP16(INTRIN_FUNC, WRAPPER_FUNC, NUM_SIGN) \ - std::string tvm_qhl_ahf_##INTRIN_FUNC = WRAPPER_FUNC; \ +#define TVM_REGISTER_QHL_OP_FP16(INTRIN_FUNC, NUM_SIGN) \ TVM_REGISTER_OP("tirx." #INTRIN_FUNC) \ .set_attr( \ "hexagon.FLowerIntrinsic", \ @@ -38,6 +37,14 @@ namespace codegen { namespace llvm { using tirx::FLowerIntrinsic; +std::string tvm_qhl_ahf_ceil = "tvm_vect_qhmath_hvx_ceil_ahf"; +std::string tvm_qhl_ahf_cos = "tvm_vect_qhmath_hvx_cos_ahf"; +std::string tvm_qhl_ahf_exp = "tvm_vect_qhmath_hvx_exp_ahf"; +std::string tvm_qhl_ahf_floor = "tvm_vect_qhmath_hvx_floor_ahf"; +std::string tvm_qhl_ahf_sin = "tvm_vect_qhmath_hvx_sin_ahf"; +std::string tvm_qhl_ahf_pow = "tvm_vect_qhmath_hvx_pow_ahf"; +std::string tvm_qhl_ahf_sqrt = "tvm_vect_qhmath_hvx_sqrt_ahf"; + inline PrimExpr TVMExternCall(const tirx::CallNode* call, const std::string& fname) { ffi::Array new_args = {tirx::StringImm(fname)}; for (PrimExpr arg : call->args) { @@ -75,6 +82,12 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { return tirx::Call(call->dtype, tirx::builtin::call_llvm_pure_intrin(), new_args); } +void RegisterHexagonIntrinRules() { + static bool registered = false; + if (registered) return; + registered = true; + + // clang-format off TVM_REGISTER_OP("tirx.fma") .set_attr("hexagon.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); @@ -199,19 +212,21 @@ TVM_REGISTER_OP("tirx.sigmoid") return one / (one + exp(-x)); }); -TVM_REGISTER_QHL_OP_FP16(ceil, "tvm_vect_qhmath_hvx_ceil_ahf", 1) +TVM_REGISTER_QHL_OP_FP16(ceil, 1) -TVM_REGISTER_QHL_OP_FP16(cos, "tvm_vect_qhmath_hvx_cos_ahf", 1) +TVM_REGISTER_QHL_OP_FP16(cos, 1) -TVM_REGISTER_QHL_OP_FP16(exp, "tvm_vect_qhmath_hvx_exp_ahf", 1) +TVM_REGISTER_QHL_OP_FP16(exp, 1) -TVM_REGISTER_QHL_OP_FP16(floor, "tvm_vect_qhmath_hvx_floor_ahf", 1) +TVM_REGISTER_QHL_OP_FP16(floor, 1) -TVM_REGISTER_QHL_OP_FP16(sin, "tvm_vect_qhmath_hvx_sin_ahf", 1) +TVM_REGISTER_QHL_OP_FP16(sin, 1) -TVM_REGISTER_QHL_OP_FP16(pow, "tvm_vect_qhmath_hvx_pow_ahf", 2) +TVM_REGISTER_QHL_OP_FP16(pow, 2) -TVM_REGISTER_QHL_OP_FP16(sqrt, "tvm_vect_qhmath_hvx_sqrt_ahf", 1) +TVM_REGISTER_QHL_OP_FP16(sqrt, 1) + // clang-format on +} } // namespace llvm } // namespace codegen diff --git a/src/backend/hexagon/codegen/register.cc b/src/backend/hexagon/codegen/register.cc new file mode 100644 index 000000000000..d6576bfee74c --- /dev/null +++ b/src/backend/hexagon/codegen/register.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief Hexagon compiler backend static registration. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace backend { +namespace hexagon { + +void RegisterTargetKind() { + TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) + .add_attr_option>("mattr") + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option>("llvm-options") + .add_attr_option("num-cores") + .add_attr_option("vtcm-capacity") + .set_default_keys({"hexagon", "cpu"}); +} + +} // namespace hexagon +} // namespace backend + +#ifdef TVM_LLVM_VERSION +namespace codegen { +void RegisterHexagonCodegen(); +namespace llvm { +void RegisterHexagonIntrinRules(); +} // namespace llvm +} // namespace codegen +#endif +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { + tvm::backend::hexagon::RegisterTargetKind(); +#ifdef TVM_LLVM_VERSION + tvm::codegen::llvm::RegisterHexagonIntrinRules(); + tvm::codegen::RegisterHexagonCodegen(); +#endif +} diff --git a/src/runtime/hexagon/README.md b/src/backend/hexagon/runtime/README.md similarity index 100% rename from src/runtime/hexagon/README.md rename to src/backend/hexagon/runtime/README.md diff --git a/src/runtime/hexagon/hexagon_buffer.cc b/src/backend/hexagon/runtime/hexagon_buffer.cc similarity index 100% rename from src/runtime/hexagon/hexagon_buffer.cc rename to src/backend/hexagon/runtime/hexagon_buffer.cc diff --git a/src/runtime/hexagon/hexagon_buffer.h b/src/backend/hexagon/runtime/hexagon_buffer.h similarity index 100% rename from src/runtime/hexagon/hexagon_buffer.h rename to src/backend/hexagon/runtime/hexagon_buffer.h diff --git a/src/runtime/hexagon/hexagon_buffer_manager.h b/src/backend/hexagon/runtime/hexagon_buffer_manager.h similarity index 100% rename from src/runtime/hexagon/hexagon_buffer_manager.h rename to src/backend/hexagon/runtime/hexagon_buffer_manager.h diff --git a/src/runtime/hexagon/hexagon_common.cc b/src/backend/hexagon/runtime/hexagon_common.cc similarity index 100% rename from src/runtime/hexagon/hexagon_common.cc rename to src/backend/hexagon/runtime/hexagon_common.cc diff --git a/src/runtime/hexagon/hexagon_common.h b/src/backend/hexagon/runtime/hexagon_common.h similarity index 100% rename from src/runtime/hexagon/hexagon_common.h rename to src/backend/hexagon/runtime/hexagon_common.h diff --git a/src/runtime/hexagon/hexagon_device_api.cc b/src/backend/hexagon/runtime/hexagon_device_api.cc similarity index 99% rename from src/runtime/hexagon/hexagon_device_api.cc rename to src/backend/hexagon/runtime/hexagon_device_api.cc index ae0e0862dfc2..4cd6552bc86f 100644 --- a/src/runtime/hexagon/hexagon_device_api.cc +++ b/src/backend/hexagon/runtime/hexagon_device_api.cc @@ -31,7 +31,7 @@ #include #include -#include "../workspace_pool.h" +#include "../../../runtime/workspace_pool.h" #include "hexagon_common.h" namespace tvm { diff --git a/src/runtime/hexagon/hexagon_device_api.h b/src/backend/hexagon/runtime/hexagon_device_api.h similarity index 100% rename from src/runtime/hexagon/hexagon_device_api.h rename to src/backend/hexagon/runtime/hexagon_device_api.h diff --git a/src/runtime/hexagon/hexagon_htp.cc b/src/backend/hexagon/runtime/hexagon_htp.cc similarity index 100% rename from src/runtime/hexagon/hexagon_htp.cc rename to src/backend/hexagon/runtime/hexagon_htp.cc diff --git a/src/runtime/hexagon/hexagon_htp.h b/src/backend/hexagon/runtime/hexagon_htp.h similarity index 100% rename from src/runtime/hexagon/hexagon_htp.h rename to src/backend/hexagon/runtime/hexagon_htp.h diff --git a/src/runtime/hexagon/hexagon_hvx.cc b/src/backend/hexagon/runtime/hexagon_hvx.cc similarity index 100% rename from src/runtime/hexagon/hexagon_hvx.cc rename to src/backend/hexagon/runtime/hexagon_hvx.cc diff --git a/src/runtime/hexagon/hexagon_hvx.h b/src/backend/hexagon/runtime/hexagon_hvx.h similarity index 100% rename from src/runtime/hexagon/hexagon_hvx.h rename to src/backend/hexagon/runtime/hexagon_hvx.h diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/backend/hexagon/runtime/hexagon_module.cc similarity index 94% rename from src/runtime/hexagon/hexagon_module.cc rename to src/backend/hexagon/runtime/hexagon_module.cc index f0f801ef45ad..30d5b0d233ba 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/backend/hexagon/runtime/hexagon_module.cc @@ -23,7 +23,7 @@ * only through the FFI registry keys "ffi.Module.create.hexagon" and * "ffi.Module.load_from_bytes.hexagon". No exported header — * codegen-side construction goes through - * src/target/hexagon/hexagon_fallback_module.h. + * src/backend/hexagon/codegen/hexagon_fallback_module.h. * * This carrier holds the linked Hexagon `.so` blob in memory. The * existing HexagonModuleNode does not perform `dlopen` (Hexagon @@ -41,8 +41,8 @@ #include #include -#include "../../support/bytes_io.h" -#include "../metadata.h" +#include "../../../runtime/metadata.h" +#include "../../../support/bytes_io.h" namespace tvm { namespace runtime { @@ -92,7 +92,7 @@ class HexagonModuleNode : public ffi::ModuleObj { // and is NEVER serialized — it is lost on save/load round-trip // (matches upstream behavior; the receiver rebuilds source from code // bytes if possible). HexagonFallbackModuleNode::SaveToBytes (in - // src/target/hexagon/hexagon_fallback_module.cc) MUST mirror this + // src/backend/hexagon/codegen/hexagon_fallback_module.cc) MUST mirror this // format byte-for-byte; see one-way comment there. std::string buffer; support::BytesOutStream stream(&buffer); @@ -139,7 +139,7 @@ static ffi::Module HexagonModuleLoadFromBytes(const ffi::Bytes& bytes) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; // Registry: "ffi.Module.create.hexagon" — codegen-time Hexagon module factory. - // Used by src/target/hexagon/hexagon_fallback_module.h:HexagonModuleCreateWithFallback. + // Used by src/backend/hexagon/codegen/hexagon_fallback_module.h:HexagonModuleCreateWithFallback. // Registry: "ffi.Module.load_from_bytes.hexagon" — disk loader. Only // this (real) module registers a loader; the fallback is codegen-only. refl::GlobalDef() diff --git a/src/runtime/hexagon/hexagon_power_manager.cc b/src/backend/hexagon/runtime/hexagon_power_manager.cc similarity index 100% rename from src/runtime/hexagon/hexagon_power_manager.cc rename to src/backend/hexagon/runtime/hexagon_power_manager.cc diff --git a/src/runtime/hexagon/hexagon_power_manager.h b/src/backend/hexagon/runtime/hexagon_power_manager.h similarity index 100% rename from src/runtime/hexagon/hexagon_power_manager.h rename to src/backend/hexagon/runtime/hexagon_power_manager.h diff --git a/src/runtime/hexagon/hexagon_thread_manager.cc b/src/backend/hexagon/runtime/hexagon_thread_manager.cc similarity index 100% rename from src/runtime/hexagon/hexagon_thread_manager.cc rename to src/backend/hexagon/runtime/hexagon_thread_manager.cc diff --git a/src/runtime/hexagon/hexagon_thread_manager.h b/src/backend/hexagon/runtime/hexagon_thread_manager.h similarity index 100% rename from src/runtime/hexagon/hexagon_thread_manager.h rename to src/backend/hexagon/runtime/hexagon_thread_manager.h diff --git a/src/runtime/hexagon/hexagon_user_dma.cc b/src/backend/hexagon/runtime/hexagon_user_dma.cc similarity index 100% rename from src/runtime/hexagon/hexagon_user_dma.cc rename to src/backend/hexagon/runtime/hexagon_user_dma.cc diff --git a/src/runtime/hexagon/hexagon_user_dma.h b/src/backend/hexagon/runtime/hexagon_user_dma.h similarity index 100% rename from src/runtime/hexagon/hexagon_user_dma.h rename to src/backend/hexagon/runtime/hexagon_user_dma.h diff --git a/src/runtime/hexagon/hexagon_user_dma_descriptors.h b/src/backend/hexagon/runtime/hexagon_user_dma_descriptors.h similarity index 100% rename from src/runtime/hexagon/hexagon_user_dma_descriptors.h rename to src/backend/hexagon/runtime/hexagon_user_dma_descriptors.h diff --git a/src/runtime/hexagon/hexagon_user_dma_instructions.h b/src/backend/hexagon/runtime/hexagon_user_dma_instructions.h similarity index 100% rename from src/runtime/hexagon/hexagon_user_dma_instructions.h rename to src/backend/hexagon/runtime/hexagon_user_dma_instructions.h diff --git a/src/runtime/hexagon/hexagon_user_dma_registers.h b/src/backend/hexagon/runtime/hexagon_user_dma_registers.h similarity index 100% rename from src/runtime/hexagon/hexagon_user_dma_registers.h rename to src/backend/hexagon/runtime/hexagon_user_dma_registers.h diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.cc b/src/backend/hexagon/runtime/hexagon_vtcm_pool.cc similarity index 100% rename from src/runtime/hexagon/hexagon_vtcm_pool.cc rename to src/backend/hexagon/runtime/hexagon_vtcm_pool.cc diff --git a/src/runtime/hexagon/hexagon_vtcm_pool.h b/src/backend/hexagon/runtime/hexagon_vtcm_pool.h similarity index 100% rename from src/runtime/hexagon/hexagon_vtcm_pool.h rename to src/backend/hexagon/runtime/hexagon_vtcm_pool.h diff --git a/src/runtime/hexagon/ops/conv2d.h b/src/backend/hexagon/runtime/ops/conv2d.h similarity index 100% rename from src/runtime/hexagon/ops/conv2d.h rename to src/backend/hexagon/runtime/ops/conv2d.h diff --git a/src/runtime/hexagon/ops/conv2d_fp16_hvx.cc b/src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc similarity index 100% rename from src/runtime/hexagon/ops/conv2d_fp16_hvx.cc rename to src/backend/hexagon/runtime/ops/conv2d_fp16_hvx.cc diff --git a/src/runtime/hexagon/ops/conv2d_quant_hvx.cc b/src/backend/hexagon/runtime/ops/conv2d_quant_hvx.cc similarity index 100% rename from src/runtime/hexagon/ops/conv2d_quant_hvx.cc rename to src/backend/hexagon/runtime/ops/conv2d_quant_hvx.cc diff --git a/src/runtime/hexagon/ops/conv_utils.cc b/src/backend/hexagon/runtime/ops/conv_utils.cc similarity index 100% rename from src/runtime/hexagon/ops/conv_utils.cc rename to src/backend/hexagon/runtime/ops/conv_utils.cc diff --git a/src/runtime/hexagon/profiler/README.md b/src/backend/hexagon/runtime/profiler/README.md similarity index 100% rename from src/runtime/hexagon/profiler/README.md rename to src/backend/hexagon/runtime/profiler/README.md diff --git a/src/runtime/hexagon/profiler/lwp_handler.S b/src/backend/hexagon/runtime/profiler/lwp_handler.S similarity index 100% rename from src/runtime/hexagon/profiler/lwp_handler.S rename to src/backend/hexagon/runtime/profiler/lwp_handler.S diff --git a/src/runtime/hexagon/profiler/prof_utils.cc b/src/backend/hexagon/runtime/profiler/prof_utils.cc similarity index 100% rename from src/runtime/hexagon/profiler/prof_utils.cc rename to src/backend/hexagon/runtime/profiler/prof_utils.cc diff --git a/src/runtime/hexagon/profiler/prof_utils.h b/src/backend/hexagon/runtime/profiler/prof_utils.h similarity index 100% rename from src/runtime/hexagon/profiler/prof_utils.h rename to src/backend/hexagon/runtime/profiler/prof_utils.h diff --git a/src/runtime/hexagon/qhl/qhl_wrapper.cc b/src/backend/hexagon/runtime/qhl/qhl_wrapper.cc similarity index 100% rename from src/runtime/hexagon/qhl/qhl_wrapper.cc rename to src/backend/hexagon/runtime/qhl/qhl_wrapper.cc diff --git a/src/runtime/hexagon/ring_buffer.h b/src/backend/hexagon/runtime/ring_buffer.h similarity index 100% rename from src/runtime/hexagon/ring_buffer.h rename to src/backend/hexagon/runtime/ring_buffer.h diff --git a/src/runtime/hexagon/rpc/android/session.cc b/src/backend/hexagon/runtime/rpc/android/session.cc similarity index 96% rename from src/runtime/hexagon/rpc/android/session.cc rename to src/backend/hexagon/runtime/rpc/android/session.cc index 6052225e68f4..928a4d72709e 100644 --- a/src/runtime/hexagon/rpc/android/session.cc +++ b/src/backend/hexagon/runtime/rpc/android/session.cc @@ -35,9 +35,9 @@ extern "C" { #include -#include "../../../rpc/rpc_channel.h" -#include "../../../rpc/rpc_endpoint.h" -#include "../../../rpc/rpc_session.h" +#include "../../../../../runtime/rpc/rpc_channel.h" +#include "../../../../../runtime/rpc/rpc_endpoint.h" +#include "../../../../../runtime/rpc/rpc_session.h" #include "../hexagon_rpc.h" namespace tvm { diff --git a/src/runtime/hexagon/rpc/android_bash.sh.template b/src/backend/hexagon/runtime/rpc/android_bash.sh.template similarity index 100% rename from src/runtime/hexagon/rpc/android_bash.sh.template rename to src/backend/hexagon/runtime/rpc/android_bash.sh.template diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/backend/hexagon/runtime/rpc/hexagon/rpc_server.cc similarity index 98% rename from src/runtime/hexagon/rpc/hexagon/rpc_server.cc rename to src/backend/hexagon/runtime/rpc/hexagon/rpc_server.cc index 40dc3f34b73d..2e7255e170f2 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/backend/hexagon/runtime/rpc/hexagon/rpc_server.cc @@ -36,9 +36,9 @@ extern "C" { #include #include -#include "../../../rpc/minrpc/minrpc_server.h" -#include "../../hexagon/hexagon_common.h" -#include "../../hexagon/hexagon_device_api.h" +#include "../../../../../runtime/rpc/minrpc/minrpc_server.h" +#include "../../hexagon_common.h" +#include "../../hexagon_device_api.h" #include "../../profiler/prof_utils.h" #include "hexagon_rpc.h" diff --git a/src/runtime/hexagon/rpc/hexagon_rpc.idl b/src/backend/hexagon/runtime/rpc/hexagon_rpc.idl similarity index 100% rename from src/runtime/hexagon/rpc/hexagon_rpc.idl rename to src/backend/hexagon/runtime/rpc/hexagon_rpc.idl diff --git a/src/runtime/hexagon/rpc/simulator/hexagon_sim_proto.h b/src/backend/hexagon/runtime/rpc/simulator/hexagon_sim_proto.h similarity index 100% rename from src/runtime/hexagon/rpc/simulator/hexagon_sim_proto.h rename to src/backend/hexagon/runtime/rpc/simulator/hexagon_sim_proto.h diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/backend/hexagon/runtime/rpc/simulator/rpc_server.cc similarity index 99% rename from src/runtime/hexagon/rpc/simulator/rpc_server.cc rename to src/backend/hexagon/runtime/rpc/simulator/rpc_server.cc index 61bd055dc4fe..c96ddc3fec6b 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/backend/hexagon/runtime/rpc/simulator/rpc_server.cc @@ -28,7 +28,7 @@ #include #include -#include "../../../rpc/minrpc/minrpc_server.h" +#include "../../../../../runtime/rpc/minrpc/minrpc_server.h" #include "../../hexagon_common.h" #include "../../profiler/prof_utils.h" #include "hexagon_sim_proto.h" diff --git a/src/runtime/hexagon/rpc/simulator/session.cc b/src/backend/hexagon/runtime/rpc/simulator/session.cc similarity index 99% rename from src/runtime/hexagon/rpc/simulator/session.cc rename to src/backend/hexagon/runtime/rpc/simulator/session.cc index 918614afcde7..f4524d4c1780 100644 --- a/src/runtime/hexagon/rpc/simulator/session.cc +++ b/src/backend/hexagon/runtime/rpc/simulator/session.cc @@ -36,9 +36,9 @@ #include #include -#include "../../../rpc/rpc_channel.h" -#include "../../../rpc/rpc_endpoint.h" -#include "../../../rpc/rpc_session.h" +#include "../../../../../runtime/rpc/rpc_channel.h" +#include "../../../../../runtime/rpc/rpc_endpoint.h" +#include "../../../../../runtime/rpc/rpc_session.h" #include "hexagon_sim_proto.h" #define CHECKED_CALL(func, ...) \ diff --git a/src/target/metal/codegen_metal.cc b/src/backend/metal/codegen/codegen_metal.cc similarity index 95% rename from src/target/metal/codegen_metal.cc rename to src/backend/metal/codegen/codegen_metal.cc index 22f97fa9ce84..b68840f32752 100644 --- a/src/target/metal/codegen_metal.cc +++ b/src/backend/metal/codegen/codegen_metal.cc @@ -35,8 +35,8 @@ #include #include -#include "../../runtime/thread_storage_scope.h" -#include "../build_common.h" +#include "../../../runtime/thread_storage_scope.h" +#include "../../../target/build_common.h" #include "metal_fallback_module.h" namespace tvm { @@ -382,7 +382,13 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT TVM_FFI_ICHECK(col_val == 8 && row_val == 8) << "Only 8x8 matrix is supported, but got " << col_val << "x" << row_val; }; - if (op->op.same_as(builtin::make_filled_simdgroup_matrix())) { + + static const Op& make_filled_simdgroup_matrix_op = Op::Get("tirx.make_filled_simdgroup_matrix"); + static const Op& simdgroup_load_op = Op::Get("tirx.simdgroup_load"); + static const Op& simdgroup_store_op = Op::Get("tirx.simdgroup_store"); + static const Op& simdgroup_multiply_accumulate_op = Op::Get("tirx.simdgroup_multiply_accumulate"); + + if (op->op.same_as(make_filled_simdgroup_matrix_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 5); Var var = Downcast(op->args[0]); // Get the data type of the simdgroup matrix @@ -394,19 +400,19 @@ void CodeGenMetal::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT os << PrintExpr(var) << "[" << PrintExpr(op->args[1]) << "] = make_filled_simdgroup_matrix<" << dtype_str << ", " << PrintExpr(op->args[3]) << ", " << PrintExpr(op->args[4]) << ">(" << PrintExpr(op->args[2]) << ")"; - } else if (op->op.same_as(builtin::simdgroup_load())) { + } else if (op->op.same_as(simdgroup_load_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 7); f_check_simdgroup_shape(op->args[4], op->args[5]); os << "simdgroup_load(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; - } else if (op->op.same_as(builtin::simdgroup_store())) { + } else if (op->op.same_as(simdgroup_store_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 7); f_check_simdgroup_shape(op->args[4], op->args[5]); os << "simdgroup_store(" << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " << PrintExpr(op->args[2]) << ", " << PrintExpr(op->args[3]) << ", 0, " << PrintExpr(op->args[6]) << ")"; - } else if (op->op.same_as(builtin::simdgroup_multiply_accumulate())) { + } else if (op->op.same_as(simdgroup_multiply_accumulate_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 8); os << "simdgroup_multiply_accumulate(" // << PrintExpr(op->args[0]) << "[" << PrintExpr(op->args[1]) << "], " // @@ -491,7 +497,11 @@ ffi::Module BuildMetal(IRModule mod, Target target) { ExtractFuncInfo(mod), std::move(source)); } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterMetalCodegen() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.metal", BuildMetal); } diff --git a/src/target/metal/codegen_metal.h b/src/backend/metal/codegen/codegen_metal.h similarity index 98% rename from src/target/metal/codegen_metal.h rename to src/backend/metal/codegen/codegen_metal.h index 77884660a1f0..b92608aecfa1 100644 --- a/src/target/metal/codegen_metal.h +++ b/src/backend/metal/codegen/codegen_metal.h @@ -29,7 +29,7 @@ #include #include -#include "../source/codegen_c.h" +#include "../../../target/source/codegen_c.h" namespace tvm { namespace codegen { diff --git a/src/target/metal/intrin_rule_metal.cc b/src/backend/metal/codegen/intrin_rule_metal.cc similarity index 91% rename from src/target/metal/intrin_rule_metal.cc rename to src/backend/metal/codegen/intrin_rule_metal.cc index 54417c6cdc94..f3f844ff361a 100644 --- a/src/target/metal/intrin_rule_metal.cc +++ b/src/backend/metal/codegen/intrin_rule_metal.cc @@ -23,7 +23,7 @@ */ #include -#include "../intrin_rule.h" +#include "../../../target/intrin_rule.h" namespace tvm { namespace codegen { @@ -33,12 +33,15 @@ using tirx::FLowerIntrinsic; struct MetalWarpIntrinsic { const Op operator()(DataType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { - return Op::Get("tirx.metal.simd_shuffle"); + static const Op& metal_simd_shuffle_op = Op::Get("tirx.metal.simd_shuffle"); + return metal_simd_shuffle_op; } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { - return Op::Get("tirx.metal.simd_shuffle_up"); + static const Op& metal_simd_shuffle_up_op = Op::Get("tirx.metal.simd_shuffle_up"); + return metal_simd_shuffle_up_op; } else { TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); - return Op::Get("tirx.metal.simd_shuffle_down"); + static const Op& metal_simd_shuffle_down_op = Op::Get("tirx.metal.simd_shuffle_down"); + return metal_simd_shuffle_down_op; } } }; @@ -52,6 +55,8 @@ static PrimExpr DispatchMetalShuffle(const PrimExpr& e) { return Call(call->dtype, T()(call->dtype, Downcast(call->op)), metal_args); } +void RegisterMetalIntrinRules() { + // clang-format off TVM_REGISTER_OP("tirx.clz") .set_attr("metal.FLowerIntrinsic", DispatchPureExtern); @@ -146,7 +151,8 @@ TVM_REGISTER_OP("tirx.metal.simd_shuffle") .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) .set_attr("TDeviceIntrinsicNamespace", ffi::String("metal"), 10) - .set_attr("TScriptPrinterName", ffi::String("metal.simd_shuffle"), 10) + .set_attr("TScriptPrinterName", ffi::String("metal.simd_shuffle"), + 10) .set_attr("TGlobalSymbol", "simd_shuffle") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -157,8 +163,8 @@ TVM_REGISTER_OP("tirx.metal.simd_shuffle_up") .set_attr("TIRxOpCategory", ffi::String("device_intrin"), 10) .set_attr("TDeviceIntrinsicNamespace", ffi::String("metal"), 10) - .set_attr("TScriptPrinterName", ffi::String("metal.simd_shuffle_up"), - 10) + .set_attr("TScriptPrinterName", + ffi::String("metal.simd_shuffle_up"), 10) .set_attr("TGlobalSymbol", "simd_shuffle_up") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -173,6 +179,8 @@ TVM_REGISTER_OP("tirx.metal.simd_shuffle_down") ffi::String("metal.simd_shuffle_down"), 10) .set_attr("TGlobalSymbol", "simd_shuffle_down") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + // clang-format on +} } // namespace intrin } // namespace codegen diff --git a/src/target/metal/metal_fallback_module.cc b/src/backend/metal/codegen/metal_fallback_module.cc similarity index 99% rename from src/target/metal/metal_fallback_module.cc rename to src/backend/metal/codegen/metal_fallback_module.cc index 0954c307b4cb..77fdaaaa88ce 100644 --- a/src/target/metal/metal_fallback_module.cc +++ b/src/backend/metal/codegen/metal_fallback_module.cc @@ -33,7 +33,7 @@ #include #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" namespace tvm { namespace target { diff --git a/src/target/metal/metal_fallback_module.h b/src/backend/metal/codegen/metal_fallback_module.h similarity index 98% rename from src/target/metal/metal_fallback_module.h rename to src/backend/metal/codegen/metal_fallback_module.h index b30ebd991289..7a0e27d4e4cd 100644 --- a/src/target/metal/metal_fallback_module.h +++ b/src/backend/metal/codegen/metal_fallback_module.h @@ -38,8 +38,8 @@ #include -#include "../../runtime/metadata.h" -#include "../../support/env.h" +#include "../../../runtime/metadata.h" +#include "../../../support/env.h" namespace tvm { namespace target { diff --git a/src/backend/metal/codegen/register.cc b/src/backend/metal/codegen/register.cc new file mode 100644 index 000000000000..e90651775e57 --- /dev/null +++ b/src/backend/metal/codegen/register.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief Metal compiler backend static registration. + */ +#include +#include +#include +#include +#include + +namespace tvm { + +namespace backend { +namespace metal { + +void RegisterTargetKind() { + namespace refl = tvm::ffi::reflection; + + // Metal limits the number of kernel arguments. `max_function_args` captures that bound. + TVM_REGISTER_TARGET_KIND("metal", kDLMetal) + .add_attr_option("max_num_threads", refl::DefaultValue(256)) + .add_attr_option("max_threads_per_block", refl::DefaultValue(256)) + .add_attr_option("max_shared_memory_per_block", refl::DefaultValue(32768)) + .add_attr_option("thread_warp_size", refl::DefaultValue(16)) + .add_attr_option("max_function_args", refl::DefaultValue(31)) + .set_default_keys({"metal", "gpu"}); +} + +} // namespace metal +} // namespace backend + +namespace codegen { +void RegisterMetalCodegen(); +namespace intrin { +void RegisterMetalIntrinRules(); +} // namespace intrin +} // namespace codegen +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { + tvm::backend::metal::RegisterTargetKind(); + tvm::codegen::intrin::RegisterMetalIntrinRules(); + tvm::codegen::RegisterMetalCodegen(); +} diff --git a/src/backend/metal/op/register.cc b/src/backend/metal/op/register.cc new file mode 100644 index 000000000000..6a12ee1282c8 --- /dev/null +++ b/src/backend/metal/op/register.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief Metal backend op static registration. + */ +#include +#include + +namespace tvm { +namespace tirx { +namespace builtin { +void RegisterMetalTargetBuiltins(); +} // namespace builtin +} // namespace tirx +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { tvm::tirx::builtin::RegisterMetalTargetBuiltins(); } diff --git a/src/backend/metal/op/target_builtin.cc b/src/backend/metal/op/target_builtin.cc new file mode 100644 index 000000000000..f3bd50b7ab59 --- /dev/null +++ b/src/backend/metal/op/target_builtin.cc @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file backend/metal/op/target_builtin.cc + * + * builtin intrinsic operators specific to Metal target. + */ +#include +#include + +namespace tvm { +namespace tirx { +namespace builtin { + +#define TIRX_DEFINE_BUILTIN_FUNC(OpName) \ + OpRegEntry::RegisterOrGet("tirx." #OpName) \ + .set_name() \ + .set_attr("TScriptPrinterName", ffi::String(#OpName), 1) \ + .set_attr("TIRxOpCategory", ffi::String("builtin"), /*plevel=*/1) + +void RegisterMetalTargetBuiltins() { + // clang-format off +static bool registered = false; +if (registered) return; +registered = true; + +TIRX_DEFINE_BUILTIN_FUNC(make_filled_simdgroup_matrix) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + +TIRX_DEFINE_BUILTIN_FUNC(simdgroup_load) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + +TIRX_DEFINE_BUILTIN_FUNC(simdgroup_store) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + +TIRX_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate) + .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + // clang-format on +} + +#undef TIRX_DEFINE_BUILTIN_FUNC + +} // namespace builtin +} // namespace tirx +} // namespace tvm diff --git a/src/runtime/metal/metal_common.h b/src/backend/metal/runtime/metal_common.h similarity index 99% rename from src/runtime/metal/metal_common.h rename to src/backend/metal/runtime/metal_common.h index 184cbca05fb8..bf2f44920af4 100644 --- a/src/runtime/metal/metal_common.h +++ b/src/backend/metal/runtime/metal_common.h @@ -41,7 +41,7 @@ #include #include -#include "../workspace_pool.h" +#include "../../../runtime/workspace_pool.h" /* Macro for convenience in using AutoReleasePoolWrapper. * With this macro we can add AutoReleasePoolWrapper to our ObjC code in more diff --git a/src/runtime/metal/metal_device_api.mm b/src/backend/metal/runtime/metal_device_api.mm similarity index 100% rename from src/runtime/metal/metal_device_api.mm rename to src/backend/metal/runtime/metal_device_api.mm diff --git a/src/runtime/metal/metal_module.mm b/src/backend/metal/runtime/metal_module.mm similarity index 98% rename from src/runtime/metal/metal_module.mm rename to src/backend/metal/runtime/metal_module.mm index 782fc92235ba..aff2152a1363 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/backend/metal/runtime/metal_module.mm @@ -37,11 +37,11 @@ #include #include #include -#include "../../support/bytes_io.h" -#include "../file_utils.h" -#include "../metadata.h" -#include "../pack_args.h" -#include "../thread_storage_scope.h" +#include "../../../runtime/file_utils.h" +#include "../../../runtime/metadata.h" +#include "../../../runtime/pack_args.h" +#include "../../../runtime/thread_storage_scope.h" +#include "../../../support/bytes_io.h" #include "metal_common.h" namespace tvm { diff --git a/src/target/opencl/codegen_opencl.cc b/src/backend/opencl/codegen/codegen_opencl.cc similarity index 98% rename from src/target/opencl/codegen_opencl.cc rename to src/backend/opencl/codegen/codegen_opencl.cc index 7016f0fbbf06..5bad02e55824 100644 --- a/src/target/opencl/codegen_opencl.cc +++ b/src/backend/opencl/codegen/codegen_opencl.cc @@ -29,9 +29,9 @@ #include #include -#include "../../runtime/opencl/texture.h" -#include "../../runtime/thread_storage_scope.h" -#include "../build_common.h" +#include "../../../runtime/thread_storage_scope.h" +#include "../../../target/build_common.h" +#include "../runtime/texture.h" #include "opencl_fallback_module.h" namespace tvm { @@ -723,7 +723,11 @@ ffi::Module BuildOpenCL(IRModule mod, Target target) { ExtractFuncInfo(mod), std::move(source)); } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterOpenCLCodegen() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.opencl", BuildOpenCL); } @@ -736,7 +740,11 @@ ffi::String DeviceScopeCompatibilityFromTarget(Target target, ffi::String memory return memory_scope; } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterOpenCLDeviceScopeCompatibility() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("DeviceScopeCompatibility.opencl", DeviceScopeCompatibilityFromTarget); } diff --git a/src/target/opencl/codegen_opencl.h b/src/backend/opencl/codegen/codegen_opencl.h similarity index 98% rename from src/target/opencl/codegen_opencl.h rename to src/backend/opencl/codegen/codegen_opencl.h index 5f0d3e94b393..d588a18c2029 100644 --- a/src/target/opencl/codegen_opencl.h +++ b/src/backend/opencl/codegen/codegen_opencl.h @@ -29,7 +29,7 @@ #include #include -#include "../source/codegen_c.h" +#include "../../../target/source/codegen_c.h" namespace tvm { namespace codegen { diff --git a/src/target/opencl/intrin_rule_opencl.cc b/src/backend/opencl/codegen/intrin_rule_opencl.cc similarity index 93% rename from src/target/opencl/intrin_rule_opencl.cc rename to src/backend/opencl/codegen/intrin_rule_opencl.cc index 6f76af4b0e35..f0f58be84d10 100644 --- a/src/target/opencl/intrin_rule_opencl.cc +++ b/src/backend/opencl/codegen/intrin_rule_opencl.cc @@ -24,13 +24,33 @@ #include #include -#include "../intrin_rule.h" +#include "../../../target/intrin_rule.h" namespace tvm { namespace codegen { namespace intrin { using tirx::FLowerIntrinsic; +// There is no warp shuffle instruction in standard OpenCL. When shuffle is used, assume Intel's +// shuffle extension. +static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { + const CallNode* call = e.as(); + TVM_FFI_ICHECK(call != nullptr); + TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size + arith::Analyzer analyzer; + TVM_FFI_ICHECK(analyzer->CanProve(call->args[3] == call->args[4])) + << "Intel warp shuffle dose not support width != warp_size"; + ffi::Array opencl_args{ + {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; + return Call(call->dtype, builtin::call_pure_extern(), opencl_args); +} + +void RegisterOpenCLIntrinRules() { + static bool registered = false; + if (registered) return; + registered = true; + + // clang-format off TVM_REGISTER_OP("tirx.clz") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); @@ -109,22 +129,10 @@ TVM_REGISTER_OP("tirx.cos") TVM_REGISTER_OP("tirx.cosh") .set_attr("opencl.FLowerIntrinsic", DispatchPureExtern); -// There is no warp shuffle instruction in standard OpenCL -// When shuffle is used, we assume it is intel's shuffle extension -static PrimExpr DispatchIntelShuffle(const PrimExpr& e) { - const CallNode* call = e.as(); - TVM_FFI_ICHECK(call != nullptr); - TVM_FFI_ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size - arith::Analyzer analyzer; - TVM_FFI_ICHECK(analyzer->CanProve(call->args[3] == call->args[4])) - << "Intel warp shuffle dose not support width != warp_size"; - ffi::Array opencl_args{ - {StringImm("intel_sub_group_shuffle"), call->args[1], call->args[2]}}; - return Call(call->dtype, builtin::call_pure_extern(), opencl_args); -} - TVM_REGISTER_OP("tirx.tvm_warp_shuffle") .set_attr("opencl.FLowerIntrinsic", DispatchIntelShuffle); + // clang-format on +} } // namespace intrin } // namespace codegen diff --git a/src/target/opencl/opencl_fallback_module.cc b/src/backend/opencl/codegen/opencl_fallback_module.cc similarity index 97% rename from src/target/opencl/opencl_fallback_module.cc rename to src/backend/opencl/codegen/opencl_fallback_module.cc index 454b25471591..171e15581cfd 100644 --- a/src/target/opencl/opencl_fallback_module.cc +++ b/src/backend/opencl/codegen/opencl_fallback_module.cc @@ -34,7 +34,7 @@ #include #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" namespace tvm { namespace target { @@ -66,7 +66,7 @@ class OpenCLFallbackModuleNode : public ffi::ModuleObj { ffi::Bytes SaveToBytes() const final { // NOTE: serialization format MUST remain byte-identical to - // OpenCLModuleNode::SaveToBytes in src/runtime/opencl/opencl_module.cc + // OpenCLModuleNode::SaveToBytes in src/backend/opencl/runtime/opencl_module.cc // (the source of truth). Both produce a kind="opencl" artifact // that the loader (ffi.Module.load_from_bytes.opencl, registered // only when USE_OPENCL=ON) deserializes. If the real impl's diff --git a/src/target/opencl/opencl_fallback_module.h b/src/backend/opencl/codegen/opencl_fallback_module.h similarity index 98% rename from src/target/opencl/opencl_fallback_module.h rename to src/backend/opencl/codegen/opencl_fallback_module.h index ef401c75478d..326667fff89c 100644 --- a/src/target/opencl/opencl_fallback_module.h +++ b/src/backend/opencl/codegen/opencl_fallback_module.h @@ -38,8 +38,8 @@ #include -#include "../../runtime/metadata.h" -#include "../../support/env.h" +#include "../../../runtime/metadata.h" +#include "../../../support/env.h" namespace tvm { namespace target { diff --git a/src/backend/opencl/codegen/register.cc b/src/backend/opencl/codegen/register.cc new file mode 100644 index 000000000000..1b64cefeb717 --- /dev/null +++ b/src/backend/opencl/codegen/register.cc @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief OpenCL compiler backend static registration. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace backend { +namespace opencl { + +void RegisterTargetKind() { + namespace refl = tvm::ffi::reflection; + + TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) + .add_attr_option("max_threads_per_block", refl::DefaultValue(256)) + .add_attr_option("max_shared_memory_per_block", refl::DefaultValue(16384)) + .add_attr_option("max_num_threads", refl::DefaultValue(256)) + .add_attr_option("thread_warp_size", refl::DefaultValue(1)) + .add_attr_option("texture_spatial_limit", refl::DefaultValue(16384)) + .add_attr_option("texture_depth_limit", refl::DefaultValue(2048)) + // Qualcomm OpenCL runtimes may crash when the number of kernel arguments is too large. + .add_attr_option("max_function_args", refl::DefaultValue(128)) + .add_attr_option("image_base_address_alignment", refl::DefaultValue(64)) + .set_default_keys({"opencl", "gpu"}); +} + +} // namespace opencl +} // namespace backend + +namespace codegen { +void RegisterOpenCLCodegen(); +void RegisterOpenCLDeviceScopeCompatibility(); +namespace intrin { +void RegisterOpenCLIntrinRules(); +} // namespace intrin +} // namespace codegen +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { + tvm::backend::opencl::RegisterTargetKind(); + tvm::codegen::intrin::RegisterOpenCLIntrinRules(); + tvm::codegen::RegisterOpenCLCodegen(); + tvm::codegen::RegisterOpenCLDeviceScopeCompatibility(); +} diff --git a/src/runtime/opencl/opencl_common.h b/src/backend/opencl/runtime/opencl_common.h similarity index 99% rename from src/runtime/opencl/opencl_common.h rename to src/backend/opencl/runtime/opencl_common.h index 7b9a76dc3a8e..3b99fa166def 100644 --- a/src/runtime/opencl/opencl_common.h +++ b/src/backend/opencl/runtime/opencl_common.h @@ -70,10 +70,10 @@ #include #include -#include "../file_utils.h" -#include "../metadata.h" -#include "../pack_args.h" -#include "../thread_storage_scope.h" +#include "../../../runtime/file_utils.h" +#include "../../../runtime/metadata.h" +#include "../../../runtime/pack_args.h" +#include "../../../runtime/thread_storage_scope.h" #include "texture.h" namespace tvm { diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/backend/opencl/runtime/opencl_device_api.cc similarity index 99% rename from src/runtime/opencl/opencl_device_api.cc rename to src/backend/opencl/runtime/opencl_device_api.cc index 14823f18b3cb..eeb8e95ad543 100644 --- a/src/runtime/opencl/opencl_device_api.cc +++ b/src/backend/opencl/runtime/opencl_device_api.cc @@ -27,8 +27,8 @@ #include -#include "../../support/env.h" -#include "../memory/pooled_allocator.h" +#include "../../../runtime/memory/pooled_allocator.h" +#include "../../../support/env.h" #include "opencl_common.h" #ifdef OPENCL_ENABLE_HOST_PTR diff --git a/src/runtime/opencl/opencl_module.cc b/src/backend/opencl/runtime/opencl_module.cc similarity index 97% rename from src/runtime/opencl/opencl_module.cc rename to src/backend/opencl/runtime/opencl_module.cc index 63d721d4e30f..ff81212b7803 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/backend/opencl/runtime/opencl_module.cc @@ -21,9 +21,10 @@ * \file opencl_module.cc * \brief Plugin-only OpenCL runtime module. Built only when * USE_OPENCL=ON. No exported header — codegen-side construction - * goes through src/target/opencl/opencl_fallback_module.h:OpenCLModuleCreateWithFallback, - * which dispatches to "ffi.Module.create.opencl" registered - * below when this file is linked into the build. + * goes through + * src/backend/opencl/codegen/opencl_fallback_module.h:OpenCLModuleCreateWithFallback, which + * dispatches to "ffi.Module.create.opencl" registered below when this file is linked into the + * build. */ #include #include @@ -34,7 +35,7 @@ #include #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" #include "opencl_common.h" #include "source_utils.h" @@ -168,7 +169,7 @@ ffi::Optional OpenCLModuleNodeBase::GetFunction(const ffi::String ffi::Bytes OpenCLModuleNode::SaveToBytes() const { // NOTE: serialization format MUST remain byte-identical to // target::OpenCLFallbackModuleNode::SaveToBytes in - // src/target/opencl/opencl_fallback_module.cc. This file is the + // src/backend/opencl/codegen/opencl_fallback_module.cc. This file is the // source of truth; the fallback follows. // 3 fields only — the source map is in-memory inspection material // and is NEVER serialized (matches upstream behavior for all @@ -401,7 +402,7 @@ ffi::Module OpenCLModuleLoadFromBytes(const ffi::Bytes& bytes) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; // Registry: "ffi.Module.create.opencl" — codegen-time OpenCL module factory. - // Used by src/target/opencl/opencl_fallback_module.h:OpenCLModuleCreateWithFallback. + // Used by src/backend/opencl/codegen/opencl_fallback_module.h:OpenCLModuleCreateWithFallback. // Registry: "ffi.Module.load_from_bytes.opencl" — disk loader. Only this // (real) module registers a loader; the fallback is codegen-only. refl::GlobalDef() diff --git a/src/runtime/opencl/opencl_wrapper/README.md b/src/backend/opencl/runtime/opencl_wrapper/README.md similarity index 100% rename from src/runtime/opencl/opencl_wrapper/README.md rename to src/backend/opencl/runtime/opencl_wrapper/README.md diff --git a/src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc b/src/backend/opencl/runtime/opencl_wrapper/opencl_wrapper.cc similarity index 100% rename from src/runtime/opencl/opencl_wrapper/opencl_wrapper.cc rename to src/backend/opencl/runtime/opencl_wrapper/opencl_wrapper.cc diff --git a/src/runtime/opencl/source_utils.h b/src/backend/opencl/runtime/source_utils.h similarity index 100% rename from src/runtime/opencl/source_utils.h rename to src/backend/opencl/runtime/source_utils.h diff --git a/src/runtime/opencl/texture.h b/src/backend/opencl/runtime/texture.h similarity index 100% rename from src/runtime/opencl/texture.h rename to src/backend/opencl/runtime/texture.h diff --git a/src/target/rocm/llvm/codegen_amdgpu.cc b/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc similarity index 97% rename from src/target/rocm/llvm/codegen_amdgpu.cc rename to src/backend/rocm/codegen/llvm/codegen_amdgpu.cc index 12a8aed79bd8..22ce75cddade 100644 --- a/src/target/rocm/llvm/codegen_amdgpu.cc +++ b/src/backend/rocm/codegen/llvm/codegen_amdgpu.cc @@ -49,10 +49,10 @@ #include #include -#include "../../../runtime/metadata.h" -#include "../../build_common.h" -#include "../../llvm/codegen_llvm.h" -#include "../../llvm/llvm_instance.h" +#include "../../../../runtime/metadata.h" +#include "../../../../target/build_common.h" +#include "../../../../target/llvm/codegen_llvm.h" +#include "../../../../target/llvm/llvm_instance.h" #include "../rocm_fallback_module.h" namespace tvm { @@ -321,7 +321,11 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) { ExtractFuncInfo(mod), std::move(source)); } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterAMDGPUCodegen() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("target.build.rocm", BuildAMDGPU) diff --git a/src/target/rocm/llvm/intrin_rule_rocm.cc b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc similarity index 97% rename from src/target/rocm/llvm/intrin_rule_rocm.cc rename to src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc index 8849270b03b5..8bd0497a0d59 100644 --- a/src/target/rocm/llvm/intrin_rule_rocm.cc +++ b/src/backend/rocm/codegen/llvm/intrin_rule_rocm.cc @@ -31,8 +31,8 @@ #include -#include "../../intrin_rule.h" -#include "../../llvm/intrin_rule_llvm.h" +#include "../../../../target/intrin_rule.h" +#include "../../../../target/llvm/intrin_rule_llvm.h" namespace tvm { namespace codegen { @@ -106,6 +106,8 @@ inline PrimExpr DispatchShuffle(const PrimExpr& e) { namespace llvm { using tirx::FLowerIntrinsic; +void RegisterROCMIntrinRules() { + // clang-format off // dummy because we don't have the activemask TVM_REGISTER_OP("tirx.tvm_warp_activemask") .set_attr("rocm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { @@ -208,6 +210,8 @@ TVM_REGISTER_OP("tirx.erf") // TVM_REGISTER_OP("tirx.exp10") // .set_attr("rocm.FLowerIntrinsic", // DispatchLLVMPureIntrin<::llvm::Intrinsic::exp10, 1>); + // clang-format on +} } // namespace llvm } // namespace codegen diff --git a/src/backend/rocm/codegen/register.cc b/src/backend/rocm/codegen/register.cc new file mode 100644 index 000000000000..4d0c8b738b2b --- /dev/null +++ b/src/backend/rocm/codegen/register.cc @@ -0,0 +1,148 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief ROCm compiler backend static registration. + */ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace tvm { + +namespace backend { +namespace rocm { + +std::string ExtractStringWithPrefix(const std::string& str, const std::string& prefix) { + if (str.find(prefix) != 0) return ""; + std::size_t pos = prefix.length(); + while (pos < str.length() && (std::isdigit(str[pos]) || std::isalpha(str[pos]))) { + ++pos; + } + return str.substr(prefix.length(), pos - prefix.length()); +} + +bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, ffi::Any* val) { + using runtime::DeviceAPI; + DeviceAPI* api = DeviceAPI::Get(device, true); + if (api == nullptr) { + return false; + } + api->GetAttr(device, runtime::kExist, val); + int exists = val->cast(); + if (!exists) { + return false; + } + DeviceAPI::Get(device)->GetAttr(device, flag, val); + return true; +} + +void CheckOrSetAttr(ffi::Map* attrs, const ffi::String& name, + const ffi::String& value) { + auto iter = attrs->find(name); + if (iter == attrs->end()) { + attrs->Set(name, value); + } else { + auto str = (*iter).second.try_cast(); + TVM_FFI_CHECK(str && str.value() == value, ValueError) + << "Expects \"" << name << "\" to be \"" << value << "\", but gets: " << (*iter).second; + } +} + +ffi::Map UpdateROCmAttrs(ffi::Map target) { + CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc"); + std::string arch = "gfx900"; + if (target.count("mcpu")) { + ffi::String mcpu = Downcast(target.at("mcpu")); + arch = ExtractStringWithPrefix(mcpu, "gfx"); + TVM_FFI_CHECK(!arch.empty(), ValueError) + << "ROCm target gets an invalid GFX version: -mcpu=" << mcpu; + } else { + ffi::Any val; + if (const auto f_get_rocm_arch = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_get_arch")) { + arch = (*f_get_rocm_arch)().cast(); + } + target.Set("mcpu", ffi::String(arch)); + } + + ffi::Any val; + int version; + if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &val)) { + LOG(WARNING) << "Unable to detect ROCm version, assuming >= 3.5"; + version = 305; + } else { + version = val.cast(); + } + if (version < 305) { + ffi::Array mattr; + if (target.count("mattr")) { + mattr = Downcast>(target.at("mattr")); + } + mattr.push_back("-code-object-v3"); + target.Set("mattr", mattr); + } + return target; +} + +void RegisterTargetKind() { + namespace refl = tvm::ffi::reflection; + + TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) + .add_attr_option("mcpu") + .add_attr_option("mtriple") + .add_attr_option>("mattr") + // TODO(masahi): Support querying from a target device + // On RDNA cards, thread_warp_size should be 32 + .add_attr_option("max_num_threads", refl::DefaultValue(256)) + .add_attr_option("max_threads_per_block", refl::DefaultValue(256)) + .add_attr_option("max_shared_memory_per_block", refl::DefaultValue(65536)) + .add_attr_option("thread_warp_size", refl::DefaultValue(64)) + .set_default_keys({"rocm", "gpu"}) + .set_target_canonicalizer(UpdateROCmAttrs); +} + +} // namespace rocm +} // namespace backend + +namespace codegen { +#ifdef TVM_LLVM_VERSION +void RegisterAMDGPUCodegen(); +namespace llvm { +void RegisterROCMIntrinRules(); +} // namespace llvm +#endif +} // namespace codegen +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { + tvm::backend::rocm::RegisterTargetKind(); +#ifdef TVM_LLVM_VERSION + tvm::codegen::llvm::RegisterROCMIntrinRules(); + tvm::codegen::RegisterAMDGPUCodegen(); +#endif +} diff --git a/src/target/rocm/rocm_fallback_module.cc b/src/backend/rocm/codegen/rocm_fallback_module.cc similarity index 99% rename from src/target/rocm/rocm_fallback_module.cc rename to src/backend/rocm/codegen/rocm_fallback_module.cc index 4492be30bb6f..5e2ca48e5dee 100644 --- a/src/target/rocm/rocm_fallback_module.cc +++ b/src/backend/rocm/codegen/rocm_fallback_module.cc @@ -33,7 +33,7 @@ #include #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" namespace tvm { namespace target { diff --git a/src/target/rocm/rocm_fallback_module.h b/src/backend/rocm/codegen/rocm_fallback_module.h similarity index 98% rename from src/target/rocm/rocm_fallback_module.h rename to src/backend/rocm/codegen/rocm_fallback_module.h index e9ef4cc07390..a83604407d9a 100644 --- a/src/target/rocm/rocm_fallback_module.h +++ b/src/backend/rocm/codegen/rocm_fallback_module.h @@ -38,8 +38,8 @@ #include -#include "../../runtime/metadata.h" -#include "../../support/env.h" +#include "../../../runtime/metadata.h" +#include "../../../support/env.h" namespace tvm { namespace target { diff --git a/src/runtime/rocm/rocm_common.h b/src/backend/rocm/runtime/rocm_common.h similarity index 98% rename from src/runtime/rocm/rocm_common.h rename to src/backend/rocm/runtime/rocm_common.h index 3c07561c38d1..966d38735d02 100644 --- a/src/runtime/rocm/rocm_common.h +++ b/src/backend/rocm/runtime/rocm_common.h @@ -30,7 +30,7 @@ #include -#include "../workspace_pool.h" +#include "../../../runtime/workspace_pool.h" namespace tvm { namespace runtime { diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/backend/rocm/runtime/rocm_device_api.cc similarity index 100% rename from src/runtime/rocm/rocm_device_api.cc rename to src/backend/rocm/runtime/rocm_device_api.cc diff --git a/src/runtime/rocm/rocm_module.cc b/src/backend/rocm/runtime/rocm_module.cc similarity index 98% rename from src/runtime/rocm/rocm_module.cc rename to src/backend/rocm/runtime/rocm_module.cc index b940804d34cc..b4b39de69e54 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/backend/rocm/runtime/rocm_module.cc @@ -36,10 +36,10 @@ #include #include -#include "../../support/bytes_io.h" -#include "../metadata.h" -#include "../pack_args.h" -#include "../thread_storage_scope.h" +#include "../../../runtime/metadata.h" +#include "../../../runtime/pack_args.h" +#include "../../../runtime/thread_storage_scope.h" +#include "../../../support/bytes_io.h" #include "rocm_common.h" namespace tvm { diff --git a/src/target/source/codegen_trn.cc b/src/backend/trn/codegen/codegen_trn.cc similarity index 91% rename from src/target/source/codegen_trn.cc rename to src/backend/trn/codegen/codegen_trn.cc index 6a2eb7168ff4..9b798c3dc8f3 100644 --- a/src/target/source/codegen_trn.cc +++ b/src/backend/trn/codegen/codegen_trn.cc @@ -33,8 +33,8 @@ #include #include -#include "../../runtime/thread_storage_scope.h" -#include "../build_common.h" +#include "../../../runtime/thread_storage_scope.h" +#include "../../../target/build_common.h" namespace tvm { namespace codegen { @@ -360,22 +360,39 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL auto is_op = [&](const Op& compat, const char* canonical_name) { return op->op.same_as(compat) || (op_node != nullptr && op_node->name == canonical_name); }; - if (is_op(builtin::nki_matmul(), "tirx.nki.matmul")) { + static const Op& nki_matmul_op = Op::Get("tirx.nki_matmul"); + static const Op& nki_load_op = Op::Get("tirx.nki_load"); + static const Op& nki_store_op = Op::Get("tirx.nki_store"); + static const Op& nki_tensor_copy_op = Op::Get("tirx.nki_tensor_copy"); + static const Op& nki_activation_op = Op::Get("tirx.nki_activation"); + static const Op& nki_reciprocal_op = Op::Get("tirx.nki_reciprocal"); + static const Op& nki_tensortensor_op = Op::Get("tirx.nki_tensortensor"); + static const Op& nki_tensorscalar_op = Op::Get("tirx.nki_tensorscalar"); + static const Op& nki_memset_op = Op::Get("tirx.nki_memset"); + static const Op& nki_tensorreduce_op = Op::Get("tirx.nki_tensorreduce"); + static const Op& nki_activation_reduce_op = Op::Get("tirx.nki_activation_reduce"); + static const Op& nki_tensorscalar_reduce_op = Op::Get("tirx.nki_tensorscalar_reduce"); + static const Op& nki_identity_op = Op::Get("tirx.nki_identity"); + static const Op& nki_scalar_tensor_tensor_op = Op::Get("tirx.nki_scalar_tensor_tensor"); + static const Op& nki_scalar_tensor_scalar_op = Op::Get("tirx.nki_scalar_tensor_scalar"); + static const Op& nki_affine_select_op = Op::Get("tirx.nki_affine_select"); + + if (is_op(nki_matmul_op, "tirx.nki.matmul")) { TVM_FFI_ICHECK_EQ(op->args.size(), 4); std::string accum = is_one(op->args[3]) ? " += " : " = "; os << PrintExpr(op->args[0]) << accum; ctx_.is_matmul_input = true; os << "nisa.nc_matmul(" << PrintExpr(op->args[1]) << "," << PrintExpr(op->args[2]); - } else if (is_op(builtin::nki_load(), "tirx.nki.load")) { + } else if (is_op(nki_load_op, "tirx.nki.load")) { TVM_FFI_ICHECK_EQ(op->args.size(), 2); os << PrintExpr(op->args[0]) << " = nl.load(" << PrintExpr(op->args[1]); - } else if (is_op(builtin::nki_store(), "tirx.nki.store")) { + } else if (is_op(nki_store_op, "tirx.nki.store")) { TVM_FFI_ICHECK_EQ(op->args.size(), 2); os << "nl.store(" << PrintExpr(op->args[0]) << ", " << PrintExpr(op->args[1]); - } else if (is_op(builtin::nki_tensor_copy(), "tirx.nki.tensor_copy")) { + } else if (is_op(nki_tensor_copy_op, "tirx.nki.tensor_copy")) { TVM_FFI_ICHECK_EQ(op->args.size(), 2); os << PrintExpr(op->args[0]) << " = nisa.tensor_copy(" << PrintExpr(op->args[1]); - } else if (is_op(builtin::nki_activation(), "tirx.nki.activation")) { + } else if (is_op(nki_activation_op, "tirx.nki.activation")) { TVM_FFI_ICHECK_EQ(op->args.size(), 5); // nki_activation(result, data, opcode, bias, scale) TVM_FFI_ICHECK(opcode_map_.count(op->args[2].as()->value)); @@ -383,17 +400,17 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL os << PrintExpr(op->args[0]) << " = nisa.activation(op=" << nki_op << ", data=" << PrintExpr(op->args[1]) << ","; os << "bias=" << PrintExpr(op->args[3]) << ", scale=" << PrintExpr(op->args[4]); - } else if (is_op(builtin::nki_reciprocal(), "tirx.nki.reciprocal")) { + } else if (is_op(nki_reciprocal_op, "tirx.nki.reciprocal")) { TVM_FFI_ICHECK_EQ(op->args.size(), 2); os << PrintExpr(op->args[0]) << " = nisa.reciprocal(" << PrintExpr(op->args[1]); - } else if (is_op(builtin::nki_tensortensor(), "tirx.nki.tensortensor")) { + } else if (is_op(nki_tensortensor_op, "tirx.nki.tensortensor")) { TVM_FFI_ICHECK_EQ(op->args.size(), 4); // nki_tensortensor(result, data1, data2, opcode) TVM_FFI_ICHECK(opcode_map_.count(op->args[3].as()->value)); std::string nki_op = opcode_map_[op->args[3].as()->value]; os << PrintExpr(op->args[0]) << " = nisa.tensor_tensor(" << PrintExpr(op->args[1]) << ", "; os << PrintExpr(op->args[2]) << ", op=" << nki_op; - } else if (is_op(builtin::nki_tensorscalar(), "tirx.nki.tensorscalar")) { + } else if (is_op(nki_tensorscalar_op, "tirx.nki.tensorscalar")) { TVM_FFI_ICHECK_EQ(op->args.size(), 5); // nki_tensorscalar(result, operand0, operand1, opcode, reverse) TVM_FFI_ICHECK(opcode_map_.count(op->args[3].as()->value)); @@ -402,13 +419,13 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL os << PrintExpr(op->args[0]) << " = nisa.tensor_scalar(" << PrintExpr(op->args[1]) << ", operand0="; os << PrintExpr(op->args[2]) << ", op0=" << nki_op << ", reverse0=" << PrintBool(reverse); - } else if (is_op(builtin::nki_memset(), "tirx.nki.memset")) { + } else if (is_op(nki_memset_op, "tirx.nki.memset")) { TVM_FFI_ICHECK_GE(op->args.size(), 2); // result, value os << PrintExpr(op->args[0]) << " = " << PrintExpr(op->args[1]); TVM_FFI_ICHECK(!ctx_.mask.defined()) << "memset cannot have mask"; return; - } else if (is_op(builtin::nki_tensorreduce(), "tirx.nki.tensorreduce")) { + } else if (is_op(nki_tensorreduce_op, "tirx.nki.tensorreduce")) { TVM_FFI_ICHECK(op->args.size() >= 5) << "nki_tensorreduce expects at least 5 arguments, but got " << op->args.size(); // nki_tensorreduce(result, data, opcode, negate, *axes) @@ -418,7 +435,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL Array axes(op->args.begin() + 4, op->args.end()); os << PrintExpr(op->args[0]) << " = nisa.tensor_reduce(data=" << PrintExpr(op->args[1]) << ", op=" << nki_op << ", negate=" << PrintBool(negate) << ", axis=" << axes; - } else if (is_op(builtin::nki_activation_reduce(), "tirx.nki.activation_reduce")) { + } else if (is_op(nki_activation_reduce_op, "tirx.nki.activation_reduce")) { TVM_FFI_ICHECK(op->args.size() == 7) << "nki_activation_reduce expects 7 arguments, but got " << op->args.size(); // nki_activation_reduce(reduce_res, act_res, data, opcode, reduce_opcode, bias, scale) @@ -430,7 +447,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL << ", op=" << nki_op; os << ", reduce_op=" << reduce_nki_op << ", reduce_res=" << PrintExpr(op->args[0]) << ", bias=" << PrintExpr(op->args[5]) << ", scale=" << PrintExpr(op->args[6]); - } else if (is_op(builtin::nki_tensorscalar_reduce(), "tirx.nki.tensorscalar_reduce")) { + } else if (is_op(nki_tensorscalar_reduce_op, "tirx.nki.tensorscalar_reduce")) { TVM_FFI_ICHECK(op->args.size() == 7) << "nki_tensorscalar_reduce expects 7 arguments, but got " << op->args.size(); // nki_tensorscalar_reduce(reduce_res, tensorscalar_res, operand0, operand1, opcode, @@ -444,7 +461,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL << ", op0=" << nki_op << ", operand0=" << PrintExpr(op->args[3]) << ", reduce_op=" << reduce_nki_op << ", reduce_res=" << PrintExpr(op->args[0]) << ", reverse0=" << PrintBool(reverse); - } else if (is_op(builtin::nki_identity(), "tirx.nki.identity")) { + } else if (is_op(nki_identity_op, "tirx.nki.identity")) { // nki_identity(result, size) TVM_FFI_ICHECK_EQ(op->args.size(), 2); auto identity_np_name = name_supply_->FreshName("identity_np"); @@ -454,7 +471,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL os << ' '; } os << PrintExpr(op->args[0]) << " = nl.load(" << identity_np_name; - } else if (is_op(builtin::nki_scalar_tensor_tensor(), "tirx.nki.scalar_tensor_tensor")) { + } else if (is_op(nki_scalar_tensor_tensor_op, "tirx.nki.scalar_tensor_tensor")) { TVM_FFI_ICHECK_EQ(op->args.size(), 8); // nki_scalar_tensor_tensor(result, data, operand0, operand1, opcode0, opcode1, reverse0, // reverse1) @@ -468,7 +485,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL << ", operand0=" << PrintExpr(op->args[2]) << ", op0=" << nki_op0 << ", reverse0=" << PrintBool(reverse0) << ", operand1=" << PrintExpr(op->args[3]) << ", op1=" << nki_op1 << ", reverse1=" << PrintBool(reverse1); - } else if (is_op(builtin::nki_scalar_tensor_scalar(), "tirx.nki.scalar_tensor_scalar")) { + } else if (is_op(nki_scalar_tensor_scalar_op, "tirx.nki.scalar_tensor_scalar")) { TVM_FFI_ICHECK_EQ(op->args.size(), 8); // nki_scalar_tensor_scalar(result, data, operand0, operand1, opcode0, opcode1, reverse0, // reverse1) @@ -482,7 +499,7 @@ void CodeGenTrainium::VisitExpr_(const CallNode* op, std::ostream& os) { // NOL << ", operand0=" << PrintExpr(op->args[2]) << ", op0=" << nki_op0 << ", reverse0=" << PrintBool(reverse0) << ", operand1=" << PrintExpr(op->args[3]) << ", op1=" << nki_op1 << ", reverse1=" << PrintBool(reverse1); - } else if (is_op(builtin::nki_affine_select(), "tirx.nki.affine_select")) { + } else if (is_op(nki_affine_select_op, "tirx.nki.affine_select")) { TVM_FFI_ICHECK_EQ(op->args.size(), 4); // nki_affine_select(result, pred, true_value, false_value) os << PrintExpr(op->args[0]) << " = nisa.affine_select(pred=" << PrintExpr(op->args[1]) @@ -667,7 +684,11 @@ void CodeGenTrainium::VisitExpr_(const OrNode* op, std::ostream& os) { os << PrintExpr(op->a) << " | " << PrintExpr(op->b); } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterTRNCodegen() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.trn", BuildTrainium); } diff --git a/src/target/source/codegen_trn.h b/src/backend/trn/codegen/codegen_trn.h similarity index 98% rename from src/target/source/codegen_trn.h rename to src/backend/trn/codegen/codegen_trn.h index 648446513929..2c3b5fd37393 100644 --- a/src/target/source/codegen_trn.h +++ b/src/backend/trn/codegen/codegen_trn.h @@ -30,7 +30,7 @@ #include #include -#include "codegen_c.h" +#include "../../../target/source/codegen_c.h" namespace tvm { namespace codegen { diff --git a/src/backend/trn/codegen/register.cc b/src/backend/trn/codegen/register.cc new file mode 100644 index 000000000000..ca5e5876bbef --- /dev/null +++ b/src/backend/trn/codegen/register.cc @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief Trainium compiler backend static registration. + */ +#include +#include +#include +#include +#include + +namespace tvm { + +namespace backend { +namespace trn { + +void RegisterTargetKind() { + TVM_REGISTER_TARGET_KIND("trn", kDLTrn) + .add_attr_option("partition_size", 128) + .add_attr_option("max_sbuf_size_per_partition", 196608) + .add_attr_option("max_psum_size_per_partition", 16384) + .add_attr_option("num-cores"); +} + +} // namespace trn +} // namespace backend + +namespace codegen { +void RegisterTRNCodegen(); +} // namespace codegen + +namespace tirx { +namespace transform { +void RegisterTRNTransforms(); +} // namespace transform +} // namespace tirx +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { + tvm::backend::trn::RegisterTargetKind(); + tvm::codegen::RegisterTRNCodegen(); + tvm::tirx::transform::RegisterTRNTransforms(); +} diff --git a/src/backend/trn/op/register.cc b/src/backend/trn/op/register.cc new file mode 100644 index 000000000000..365d7d6a9651 --- /dev/null +++ b/src/backend/trn/op/register.cc @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief Trainium backend op static registration. + */ +#include +#include + +namespace tvm { +namespace tirx { +namespace builtin { +void RegisterTRNTargetBuiltins(); +} // namespace builtin +} // namespace tirx +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { tvm::tirx::builtin::RegisterTRNTargetBuiltins(); } diff --git a/src/tirx/op/target_builtin/trn.cc b/src/backend/trn/op/target_builtin.cc similarity index 86% rename from src/tirx/op/target_builtin/trn.cc rename to src/backend/trn/op/target_builtin.cc index e9df7669cfb1..baf7ce05bb17 100644 --- a/src/tirx/op/target_builtin/trn.cc +++ b/src/backend/trn/op/target_builtin.cc @@ -19,7 +19,7 @@ */ /*! - * \file tir/op/target_builtin/trn.cc + * \file backend/trn/op/target_builtin.cc * * builtin intrinsic operators specific to Trainium target. */ @@ -33,12 +33,21 @@ namespace tvm { namespace tirx { namespace builtin { -#define TIRX_DEFINE_BUILTIN_FUNC(OpName) \ - const Op& OpName() { \ - static const Op& op = Op::Get("tirx." #OpName); \ - return op; \ - } \ - TVM_TIRX_REGISTER_OP(#OpName) +#define TIRX_DEFINE_BUILTIN_FUNC(OpName) \ + OpRegEntry::RegisterOrGet("tirx." #OpName) \ + .set_name() \ + .set_attr("TScriptPrinterName", ffi::String(#OpName), 1) \ + .set_attr("TIRxOpCategory", ffi::String("builtin"), /*plevel=*/1) + +namespace { +void RegisterNKIIntrinsicAliases(); +} + +void RegisterTRNTargetBuiltins() { + // clang-format off +static bool registered = false; +if (registered) return; +registered = true; TIRX_DEFINE_BUILTIN_FUNC(nki_load).set_attr( "TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -88,6 +97,10 @@ TIRX_DEFINE_BUILTIN_FUNC(nki_scalar_tensor_scalar) TIRX_DEFINE_BUILTIN_FUNC(nki_affine_select) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); +RegisterNKIIntrinsicAliases(); + // clang-format on +} + namespace { void RegisterNKIIntrinsic(const char* flat_name) { @@ -138,15 +151,16 @@ const char* kNKIIntrinsics[] = { "nki_tensortensor", }; -const bool kNKIIntrinsicAliasesRegistered = []() { +void RegisterNKIIntrinsicAliases() { for (const char* op_name : kNKIIntrinsics) { RegisterNKIIntrinsic(op_name); } - return true; -}(); +} } // namespace +#undef TIRX_DEFINE_BUILTIN_FUNC + } // namespace builtin } // namespace tirx } // namespace tvm diff --git a/src/backend/trn/transform/lower_trainium_layout.cc b/src/backend/trn/transform/lower_trainium_layout.cc new file mode 100644 index 000000000000..b6b2cdcb3209 --- /dev/null +++ b/src/backend/trn/transform/lower_trainium_layout.cc @@ -0,0 +1,361 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file lower_trainium_layout.cc + * \brief Trainium-specific TIRx layout lowering. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include "../../../arith/ir_mutator_with_analyzer.h" + +namespace tvm { +namespace tirx { + +static bool IsTrainiumLayout(const TileLayoutNode* layout) { + if (layout == nullptr) { + return false; + } + return !std::any_of(layout->shard.begin(), layout->shard.end(), [](const Iter& iter) { + return iter->axis->IsMemoryAxis() && !iter->axis.same_as(Axis::Get("F")) && + !iter->axis.same_as(Axis::Get("P")) && !iter->axis.same_as(Axis::Get("Bank")); + }); +} + +class TrainiumLayoutApplier : public arith::IRMutatorWithAnalyzer { + public: + static std::pair> Lower( + const Stmt& stmt, const ffi::Map buffer_map) { + arith::Analyzer ana; + TrainiumLayoutApplier storage_lower(ana.get()); + std::unordered_map new_buffer_map; + std::vector param_flattened_buffers; + for (const auto& kv : buffer_map) { + if (kv.second->layout.defined()) { + param_flattened_buffers.push_back(storage_lower.GetFlattenedBuffer(kv.second)); + Buffer buffer = kv.second; + auto* writer = buffer.CopyOnWrite(); + writer->layout = std::nullopt; + new_buffer_map[kv.first] = buffer; + } else { + new_buffer_map[kv.first] = kv.second; + } + } + auto new_stmt = storage_lower(stmt); + for (const auto& buf : param_flattened_buffers) { + new_stmt = SeqStmt::Flatten(DeclBuffer(buf), std::move(new_stmt)); + } + return std::make_pair(new_stmt, ffi::Map(new_buffer_map)); + } + + protected: + using IRMutatorWithAnalyzer::VisitExpr_; + using IRMutatorWithAnalyzer::VisitStmt_; + + explicit TrainiumLayoutApplier(arith::AnalyzerObj* analyzer) + : arith::IRMutatorWithAnalyzer(analyzer) {} + + ffi::Any VisitAny(const ffi::Any& any) { + if (any == nullptr) { + return any; + } + if (auto buffer = any.as()) { + return GetFlattenedBuffer(buffer.value()); + } else if (auto prim_expr = any.as()) { + return VisitExpr(prim_expr.value()); + } else if (auto stmt = any.as()) { + return VisitStmt(stmt.value()); + } + return any; + } + + Stmt VisitStmt_(const AllocBufferNode* op) final { + if (!op->buffer->layout.defined()) { + return ffi::GetRef(op); + } + auto buffer = GetFlattenedBuffer(op->buffer, /*is_alloc=*/true); + if (buffer.same_as(op->buffer)) { + return ffi::GetRef(op); + } + auto n = CopyOnWrite(op); + n->buffer = buffer; + return Stmt(n); + } + + Stmt VisitStmt_(const DeclBufferNode* op) final { + auto buffer = GetFlattenedBuffer(op->buffer); + if (buffer.same_as(op->buffer)) { + return ffi::GetRef(op); + } + auto n = CopyOnWrite(op); + n->buffer = buffer; + return Stmt(n); + } + + Buffer GetFlattenedBuffer(Buffer buf, bool is_alloc = false) { + auto it = buffer_remap_.find(buf); + if (it != buffer_remap_.end()) { + return it->second; + } + auto trn_layout = buf->layout.as(); + Buffer flattened; + tirx::BufferNode* writer; + if (IsTrainiumLayout(trn_layout)) { + ffi::Array new_shape = + buf.scope() == "trn.psum" ? ffi::Array{trn_layout->GetSpan(ffi::String("Bank")), + trn_layout->GetSize(ffi::String("P")), + trn_layout->GetSpan(ffi::String("F"))} + : ffi::Array{trn_layout->GetSize(ffi::String("P")), + trn_layout->GetSpan(ffi::String("F"))}; + flattened = buf; + writer = flattened.CopyOnWrite(); + writer->shape = new_shape; + writer->strides = {}; + writer->axis_separators = {}; + } else if (is_alloc) { + if (auto tile_layout = buf->layout.as(); + tile_layout && tile_layout->HasThreadAxis()) { + arith::Analyzer ana; + PrimExpr mem_span = make_const(DataType::Int(32), 1); + for (const auto& iter : tile_layout->shard) { + if (iter->axis->IsMemoryAxis()) { + mem_span = mem_span + (iter->extent - 1) * iter->stride; + } + } + for (const auto& iter : tile_layout->replica) { + if (iter->axis->IsMemoryAxis()) { + mem_span = mem_span + (iter->extent - 1) * iter->stride; + } + } + for (const auto& [axis, off] : tile_layout->offset) { + if (axis->IsMemoryAxis()) { + mem_span = mem_span + off; + } + } + flattened = buf; + writer = flattened.CopyOnWrite(); + writer->shape = {ana->Simplify(mem_span)}; + writer->strides = {}; + writer->axis_separators = {}; + } else { + flattened = buf.GetFlattenedBuffer(); + writer = flattened.CopyOnWrite(); + } + } else { + flattened = buf.GetFlattenedBuffer(); + writer = flattened.CopyOnWrite(); + } + if (flattened->dtype == DataType::Bool()) { + writer->dtype = DataType::Int(8); + } + for (size_t i = 0; i < flattened->shape.size(); ++i) { + writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i])); + } + writer->layout = std::nullopt; + writer->elem_offset = StmtExprMutator::VisitExpr(buf->elem_offset); + + buffer_remap_[buf] = flattened; + return flattened; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + bool store_returns_bool = (op->value.dtype() == DataType::Bool()); + store = VisitBufferAccess(store); + + if (store_returns_bool) { + TVM_FFI_ICHECK_EQ(store->buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + auto writer = store.CopyOnWrite(); + writer->value = tvm::cast(DataType::Int(8), store->value); + return std::move(store); + } + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + bool load_returns_bool = (op->dtype == DataType::Bool()); + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + load = VisitBufferAccess(load); + if (load_returns_bool) { + TVM_FFI_ICHECK_EQ(load->buffer->dtype, DataType::Int(8)) + << "Expected int8 backing array for boolean tensor"; + load.CopyOnWrite()->dtype = DataType::Int(8); + return tvm::cast(DataType::Bool(), load); + } else { + return std::move(load); + } + } + + Stmt VisitStmt_(const tirx::TilePrimitiveCallNode* op) final { + ffi::Array args = op->args; + args.MutateByApply([this](ffi::Any arg) -> ffi::Any { return VisitAny(arg); }); + if (args.same_as(op->args)) { + return ffi::GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->args = std::move(args); + return Stmt(n); + } + } + + ffi::Array GetSimplifiedElemOffset(const Buffer& buffer, + const ffi::Array& indices) { + if (buffer->layout.defined()) { + auto tile_layout = buffer->layout.value().as(); + if (IsTrainiumLayout(tile_layout)) { + auto coord = buffer->layout.value()->Apply(indices, buffer->shape); + std::vector res; + for (const auto& axis : buffer.scope() == "trn.psum" + ? ffi::Array{"Bank", "P", "F"} + : ffi::Array{"P", "F"}) { + auto it = coord.find(ffi::String(axis)); + if (it != coord.end()) { + res.push_back(analyzer_->Simplify((*it).second)); + } else { + res.push_back(0); + } + } + return res; + } + if (tile_layout && tile_layout->HasThreadAxis()) { + LOG(FATAL) << "Cannot lower direct BufferLoad/BufferStore on a buffer with thread-axis " + << "layout: unable to verify that the coordinate matches the current thread. " + << "Use .view() + .local() to decompose thread and memory axes."; + } + auto res = buffer->layout.value()->Canonicalize()->Apply(indices, buffer->shape); + TVM_FFI_ICHECK_EQ(res.size(), 1) << "Expected a single element offset"; + return {analyzer_->Simplify((*res.begin()).second)}; + } + auto flattened_indices = buffer->ElemOffset(indices, true); + TVM_FFI_ICHECK_EQ(flattened_indices.size(), 1) << "Expected a single element offset"; + return {analyzer_->Simplify(flattened_indices[0])}; + } + + template + Node VisitBufferAccess(Node node) { + TVM_FFI_ICHECK(node->buffer.defined()); + if (!node->buffer->layout.defined()) { + return node; + } + auto flattened_indices = GetSimplifiedElemOffset(node->buffer, node->indices); + Buffer flattened_buffer = GetFlattenedBuffer(node->buffer); + auto writer = node.CopyOnWrite(); + writer->buffer = flattened_buffer; + writer->indices = flattened_indices; + return node; + } + + std::unordered_map buffer_remap_; +}; + +class TrainiumBufferOffsetRemover : public StmtExprMutator { + public: + static Stmt Remove(const Stmt& stmt) { return TrainiumBufferOffsetRemover()(stmt); } + + private: + PrimExpr VisitExpr_(const tirx::CallNode* call) final { + if (call->op.same_as(tirx::builtin::buffer_offset())) { + auto buffer_load = Downcast(call->args[0]); + TVM_FFI_ICHECK_EQ(buffer_load->indices.size(), 1) << "Expected a single index"; + return buffer_load->indices[0]; + } + return StmtExprMutator::VisitExpr_(call); + } + + Stmt VisitStmt_(const DeclBufferNode* op) { + auto buffer = op->buffer; + auto elem_offset = this->VisitExpr(buffer->elem_offset); + if (elem_offset.same_as(buffer->elem_offset)) { + return StmtExprMutator::VisitStmt_(op); + } else { + auto n_buffer = buffer.CopyOnWrite(); + n_buffer->elem_offset = std::move(elem_offset); + buffer_remap_[op->buffer] = buffer; + auto n = CopyOnWrite(op); + n->buffer = ffi::GetRef(n_buffer); + return Stmt(n); + } + } + + using StmtExprMutator::VisitExpr_; + using StmtExprMutator::VisitStmt_; + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + store = VisitBufferAccess(store); + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + load = VisitBufferAccess(load); + return std::move(load); + } + + template + Node VisitBufferAccess(Node node) { + TVM_FFI_ICHECK(node->buffer.defined()); + auto it = buffer_remap_.find(node->buffer); + if (it != buffer_remap_.end()) { + auto writer = node.CopyOnWrite(); + writer->buffer = it->second; + return node; + } + return node; + } + + std::unordered_map buffer_remap_; +}; + +namespace transform { + +Pass LowerTrainiumLayout() { + auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { + auto* n = f.CopyOnWrite(); + std::tie(n->body, n->buffer_map) = TrainiumLayoutApplier::Lower(n->body, n->buffer_map); + n->body = TrainiumBufferOffsetRemover::Remove(n->body); + return f; + }; + return CreatePrimFuncPass(pass_func, 0, "tirx.backend.trn.LowerTrainiumLayout", {}); +} + +void RegisterTRNTransforms() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("tirx.backend.trn.transform.LowerTrainiumLayout", LowerTrainiumLayout); +} + +} // namespace transform +} // namespace tirx +} // namespace tvm diff --git a/src/target/vulkan/build_vulkan.cc b/src/backend/vulkan/codegen/build_vulkan.cc similarity index 88% rename from src/target/vulkan/build_vulkan.cc rename to src/backend/vulkan/codegen/build_vulkan.cc index ffadf5a24004..8fc846d1e33c 100644 --- a/src/target/vulkan/build_vulkan.cc +++ b/src/backend/vulkan/codegen/build_vulkan.cc @@ -27,9 +27,9 @@ #include #include -#include "../../runtime/vulkan/spirv_shader.h" -#include "../../support/bytes_io.h" -#include "../build_common.h" +#include "../../../support/bytes_io.h" +#include "../../../target/build_common.h" +#include "../runtime/spirv_shader.h" #include "spirv_utils.h" #include "vulkan_fallback_module.h" @@ -41,7 +41,7 @@ ffi::Module BuildSPIRV(IRModule mod, Target target) { // Serialize each SPIRVShader to ffi::Bytes for the unified per-kernel // smap shape. Each value is a self-packed SPIRVShader (flag + data // vector); the Vulkan runtime (USE_VULKAN=ON) deserializes via the - // inverse helper in src/runtime/vulkan/vulkan_module.cc. + // inverse helper in src/backend/vulkan/runtime/vulkan_module.cc. ffi::Map shader_bytes; for (auto& kv : smap) { std::string buf; @@ -57,7 +57,11 @@ ffi::Module BuildSPIRV(IRModule mod, Target target) { ExtractFuncInfo(mod), std::move(source)); } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterVulkanCodegen() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.vulkan", [](IRModule mod, Target target) { return BuildSPIRV(mod, target); }); diff --git a/src/target/vulkan/codegen_spirv.cc b/src/backend/vulkan/codegen/codegen_spirv.cc similarity index 98% rename from src/target/vulkan/codegen_spirv.cc rename to src/backend/vulkan/codegen/codegen_spirv.cc index 0afd35026916..4828dd2d5eb3 100644 --- a/src/target/vulkan/codegen_spirv.cc +++ b/src/backend/vulkan/codegen/codegen_spirv.cc @@ -30,9 +30,9 @@ #include -#include "../../runtime/pack_args.h" -#include "../../runtime/vulkan/vulkan_common.h" -#include "../../tirx/transform/ir_utils.h" +#include "../../../runtime/pack_args.h" +#include "../../../tirx/transform/ir_utils.h" +#include "../runtime/vulkan_common.h" namespace tvm { namespace codegen { @@ -401,7 +401,14 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { << "SPIR-V shader cannot make extern calls. Graph contains extern \"" << Downcast(op->args[0]) << "\""; return spirv::Value(); - } else if (op->op.same_as(builtin::tvm_fill_fragment())) { + } + + static const Op& tvm_fill_fragment_op = Op::Get("tirx.tvm_fill_fragment"); + static const Op& tvm_load_matrix_sync_op = Op::Get("tirx.tvm_load_matrix_sync"); + static const Op& tvm_mma_sync_op = Op::Get("tirx.tvm_mma_sync"); + static const Op& tvm_store_matrix_sync_op = Op::Get("tirx.tvm_store_matrix_sync"); + + if (op->op.same_as(tvm_fill_fragment_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 6U); const VarNode* buffer_node = op->args[0].as(); TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node)); @@ -420,7 +427,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { builder_->MakeInst(spv::OpStore, ptr, init_val, spv::MemoryAccessMaskNone); return spirv::Value(); - } else if (op->op.same_as(builtin::tvm_load_matrix_sync())) { + } else if (op->op.same_as(tvm_load_matrix_sync_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 8U); const VarNode* buffer_node = op->args[0].as(); TVM_FFI_ICHECK(buffer_node && fragment_info_.count(buffer_node)); @@ -444,7 +451,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { (layout != "row_major") ? t_val : f_val); builder_->MakeInst(spv::OpStore, dst_ptr, loaded, spv::MemoryAccessMaskNone); return spirv::Value(); - } else if (op->op.same_as(builtin::tvm_mma_sync())) { + } else if (op->op.same_as(tvm_mma_sync_op)) { const VarNode* buffer_d = op->args[0].as(); const VarNode* buffer_a = op->args[2].as(); const VarNode* buffer_b = op->args[4].as(); @@ -481,7 +488,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { loaded_a, loaded_b, loaded_c); builder_->MakeInst(spv::OpStore, ptr_d, result, spv::MemoryAccessMaskNone); return spirv::Value(); - } else if (op->op.same_as(builtin::tvm_store_matrix_sync())) { + } else if (op->op.same_as(tvm_store_matrix_sync_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 8U); const VarNode* buffer_node = op->args[0].as(); PrimExpr index = op->args[4]; diff --git a/src/target/vulkan/codegen_spirv.h b/src/backend/vulkan/codegen/codegen_spirv.h similarity index 98% rename from src/target/vulkan/codegen_spirv.h rename to src/backend/vulkan/codegen/codegen_spirv.h index e0e41b9b1526..46fbcb696b6f 100644 --- a/src/target/vulkan/codegen_spirv.h +++ b/src/backend/vulkan/codegen/codegen_spirv.h @@ -37,8 +37,8 @@ #include #include -#include "../../runtime/thread_storage_scope.h" -#include "../../runtime/vulkan/spirv_shader.h" +#include "../../../runtime/thread_storage_scope.h" +#include "../runtime/spirv_shader.h" #include "ir_builder.h" #include "spirv_support.h" diff --git a/src/target/vulkan/intrin_rule_spirv.cc b/src/backend/vulkan/codegen/intrin_rule_spirv.cc similarity index 92% rename from src/target/vulkan/intrin_rule_spirv.cc rename to src/backend/vulkan/codegen/intrin_rule_spirv.cc index 4b1ffc4b6d7f..14287562d9e4 100644 --- a/src/target/vulkan/intrin_rule_spirv.cc +++ b/src/backend/vulkan/codegen/intrin_rule_spirv.cc @@ -27,7 +27,7 @@ #include #include -#include "../intrin_rule.h" +#include "../../../target/intrin_rule.h" namespace tvm { namespace codegen { @@ -61,6 +61,13 @@ inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) { namespace intrin { using tirx::FLowerIntrinsic; + +void RegisterVulkanLowerIntrinRules() { + static bool registered = false; + if (registered) return; + registered = true; + + // clang-format off TVM_REGISTER_OP("tirx.floor") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); @@ -140,10 +147,19 @@ TVM_REGISTER_OP("tirx.pow") TVM_REGISTER_OP("tirx.erf") .set_attr("vulkan.FLowerIntrinsic", codegen::intrin ::DispatchFastErf); + // clang-format on +} } // namespace intrin namespace legalize { using tirx::FLegalize; + +void RegisterVulkanLegalizeRules() { + static bool registered = false; + if (registered) return; + registered = true; + + // clang-format off TVM_REGISTER_OP("tirx.clz") .set_attr("vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr { const tirx::CallNode* call = e.as(); @@ -166,7 +182,15 @@ TVM_REGISTER_OP("tirx.clz") } return PrimExpr(arg.dtype().bits() - 1) - msb; }); + // clang-format on +} } // namespace legalize + +void RegisterVulkanIntrinRules() { + intrin::RegisterVulkanLowerIntrinRules(); + legalize::RegisterVulkanLegalizeRules(); +} + } // namespace spirv } // namespace codegen } // namespace tvm diff --git a/src/target/vulkan/ir_builder.cc b/src/backend/vulkan/codegen/ir_builder.cc similarity index 100% rename from src/target/vulkan/ir_builder.cc rename to src/backend/vulkan/codegen/ir_builder.cc diff --git a/src/target/vulkan/ir_builder.h b/src/backend/vulkan/codegen/ir_builder.h similarity index 100% rename from src/target/vulkan/ir_builder.h rename to src/backend/vulkan/codegen/ir_builder.h diff --git a/src/backend/vulkan/codegen/register.cc b/src/backend/vulkan/codegen/register.cc new file mode 100644 index 000000000000..92d0d115a84d --- /dev/null +++ b/src/backend/vulkan/codegen/register.cc @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief Vulkan compiler backend static registration. + */ +#include +#include +#include +#include +#include + +namespace tvm { +namespace backend { +namespace vulkan { + +void RegisterTargetKind() { + namespace refl = tvm::ffi::reflection; + + TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) + .add_attr_option>("mattr") + .add_attr_option("supports_float16") + .add_attr_option("supports_float32", refl::DefaultValue(true)) + .add_attr_option("supports_float64") + .add_attr_option("supports_int8") + .add_attr_option("supports_int16") + .add_attr_option("supports_int32", refl::DefaultValue(true)) + .add_attr_option("supports_int64") + .add_attr_option("supports_8bit_buffer") + .add_attr_option("supports_16bit_buffer") + .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_push_descriptor") + .add_attr_option("supports_dedicated_allocation") + .add_attr_option("supports_integer_dot_product") + .add_attr_option("supports_cooperative_matrix") + .add_attr_option("supported_subgroup_operations") + .add_attr_option("max_num_threads", refl::DefaultValue(256)) + .add_attr_option("max_threads_per_block", refl::DefaultValue(256)) + .add_attr_option("thread_warp_size", refl::DefaultValue(1)) + .add_attr_option("max_block_size_x") + .add_attr_option("max_block_size_y") + .add_attr_option("max_block_size_z") + .add_attr_option("max_push_constants_size") + .add_attr_option("max_uniform_buffer_range") + .add_attr_option("max_storage_buffer_range") + .add_attr_option("max_per_stage_descriptor_storage_buffer") + .add_attr_option("max_shared_memory_per_block") + .add_attr_option("device_type") + .add_attr_option("device_name") + .add_attr_option("driver_name") + .add_attr_option("driver_version") + .add_attr_option("vulkan_api_version") + .add_attr_option("max_spirv_version") + .set_default_keys({"vulkan", "gpu"}); +} + +} // namespace vulkan +} // namespace backend + +#ifdef TVM_ENABLE_SPIRV +namespace codegen { +void RegisterVulkanCodegen(); +namespace spirv { +void RegisterVulkanIntrinRules(); +} // namespace spirv +} // namespace codegen +#endif +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { + tvm::backend::vulkan::RegisterTargetKind(); +#ifdef TVM_ENABLE_SPIRV + tvm::codegen::spirv::RegisterVulkanIntrinRules(); + tvm::codegen::RegisterVulkanCodegen(); +#endif +} diff --git a/src/target/vulkan/spirv_support.cc b/src/backend/vulkan/codegen/spirv_support.cc similarity index 100% rename from src/target/vulkan/spirv_support.cc rename to src/backend/vulkan/codegen/spirv_support.cc diff --git a/src/target/vulkan/spirv_support.h b/src/backend/vulkan/codegen/spirv_support.h similarity index 100% rename from src/target/vulkan/spirv_support.h rename to src/backend/vulkan/codegen/spirv_support.h diff --git a/src/target/vulkan/spirv_utils.cc b/src/backend/vulkan/codegen/spirv_utils.cc similarity index 98% rename from src/target/vulkan/spirv_utils.cc rename to src/backend/vulkan/codegen/spirv_utils.cc index 4dd79fdbec1c..11aecf1c43d3 100644 --- a/src/target/vulkan/spirv_utils.cc +++ b/src/backend/vulkan/codegen/spirv_utils.cc @@ -37,8 +37,8 @@ #include #include -#include "../../runtime/vulkan/spirv_shader.h" -#include "../../support/utils.h" +#include "../../../support/utils.h" +#include "../runtime/spirv_shader.h" namespace tvm { namespace codegen { diff --git a/src/target/vulkan/spirv_utils.h b/src/backend/vulkan/codegen/spirv_utils.h similarity index 97% rename from src/target/vulkan/spirv_utils.h rename to src/backend/vulkan/codegen/spirv_utils.h index 03b98ea1b162..7f661cb0304f 100644 --- a/src/target/vulkan/spirv_utils.h +++ b/src/backend/vulkan/codegen/spirv_utils.h @@ -26,7 +26,7 @@ #include #include -#include "../../runtime/vulkan/spirv_shader.h" +#include "../runtime/spirv_shader.h" namespace tvm { namespace codegen { diff --git a/src/target/vulkan/vulkan_fallback_module.cc b/src/backend/vulkan/codegen/vulkan_fallback_module.cc similarity index 95% rename from src/target/vulkan/vulkan_fallback_module.cc rename to src/backend/vulkan/codegen/vulkan_fallback_module.cc index 0977085894b9..c6bc75021db4 100644 --- a/src/target/vulkan/vulkan_fallback_module.cc +++ b/src/backend/vulkan/codegen/vulkan_fallback_module.cc @@ -34,7 +34,7 @@ #include #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" namespace tvm { namespace target { @@ -66,7 +66,7 @@ class VulkanFallbackModuleNode : public ffi::ModuleObj { ffi::Bytes SaveToBytes() const final { // NOTE: serialization format MUST remain byte-identical to - // VulkanModuleNode::SaveToBytes in src/runtime/vulkan/vulkan_module.cc + // VulkanModuleNode::SaveToBytes in src/backend/vulkan/runtime/vulkan_module.cc // (the source of truth). Both produce a kind="vulkan" artifact that // the loader (ffi.Module.load_from_bytes.vulkan, registered only when // USE_VULKAN=ON) deserializes. If the real impl's format changes, @@ -76,7 +76,7 @@ class VulkanFallbackModuleNode : public ffi::ModuleObj { // 3 fields only — the source map is in-memory inspection material and // is NEVER serialized (matches upstream behavior for all backends). // Each value in `smap_` is a self-packed SPIRVShader (flag + data - // vector); see src/runtime/vulkan/spirv_shader.h. + // vector); see src/backend/vulkan/runtime/spirv_shader.h. std::string buffer; support::BytesOutStream stream(&buffer); stream.Write(fmt_); @@ -102,7 +102,7 @@ class VulkanFallbackModuleNode : public ffi::ModuleObj { // Per-kernel payload: kernel-name -> bytes. Each value is a // serialized SPIRVShader (flag + uint32_t data segment); the runtime // (USE_VULKAN=ON) deserializes via the inverse helper in - // src/runtime/vulkan/vulkan_module.cc. Multi-shader uniform + // src/backend/vulkan/runtime/vulkan_module.cc. Multi-shader uniform // Map across all multi-shader backends. ffi::Map smap_; // Format identifier — always "vulkan" today. diff --git a/src/target/vulkan/vulkan_fallback_module.h b/src/backend/vulkan/codegen/vulkan_fallback_module.h similarity index 98% rename from src/target/vulkan/vulkan_fallback_module.h rename to src/backend/vulkan/codegen/vulkan_fallback_module.h index 0113429db982..b75ebae71102 100644 --- a/src/target/vulkan/vulkan_fallback_module.h +++ b/src/backend/vulkan/codegen/vulkan_fallback_module.h @@ -38,8 +38,8 @@ #include -#include "../../runtime/metadata.h" -#include "../../support/env.h" +#include "../../../runtime/metadata.h" +#include "../../../support/env.h" namespace tvm { namespace target { diff --git a/src/runtime/vulkan/README.md b/src/backend/vulkan/runtime/README.md similarity index 100% rename from src/runtime/vulkan/README.md rename to src/backend/vulkan/runtime/README.md diff --git a/src/runtime/vulkan/spirv_shader.h b/src/backend/vulkan/runtime/spirv_shader.h similarity index 100% rename from src/runtime/vulkan/spirv_shader.h rename to src/backend/vulkan/runtime/spirv_shader.h diff --git a/src/runtime/vulkan/thread_map.h b/src/backend/vulkan/runtime/thread_map.h similarity index 100% rename from src/runtime/vulkan/thread_map.h rename to src/backend/vulkan/runtime/thread_map.h diff --git a/src/runtime/vulkan/vulkan_amdrgp.cc b/src/backend/vulkan/runtime/vulkan_amdrgp.cc similarity index 100% rename from src/runtime/vulkan/vulkan_amdrgp.cc rename to src/backend/vulkan/runtime/vulkan_amdrgp.cc diff --git a/src/runtime/vulkan/vulkan_amdrgp.h b/src/backend/vulkan/runtime/vulkan_amdrgp.h similarity index 100% rename from src/runtime/vulkan/vulkan_amdrgp.h rename to src/backend/vulkan/runtime/vulkan_amdrgp.h diff --git a/src/runtime/vulkan/vulkan_buffer.cc b/src/backend/vulkan/runtime/vulkan_buffer.cc similarity index 100% rename from src/runtime/vulkan/vulkan_buffer.cc rename to src/backend/vulkan/runtime/vulkan_buffer.cc diff --git a/src/runtime/vulkan/vulkan_buffer.h b/src/backend/vulkan/runtime/vulkan_buffer.h similarity index 100% rename from src/runtime/vulkan/vulkan_buffer.h rename to src/backend/vulkan/runtime/vulkan_buffer.h diff --git a/src/runtime/vulkan/vulkan_common.cc b/src/backend/vulkan/runtime/vulkan_common.cc similarity index 100% rename from src/runtime/vulkan/vulkan_common.cc rename to src/backend/vulkan/runtime/vulkan_common.cc diff --git a/src/runtime/vulkan/vulkan_common.h b/src/backend/vulkan/runtime/vulkan_common.h similarity index 100% rename from src/runtime/vulkan/vulkan_common.h rename to src/backend/vulkan/runtime/vulkan_common.h diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/backend/vulkan/runtime/vulkan_device.cc similarity index 99% rename from src/runtime/vulkan/vulkan_device.cc rename to src/backend/vulkan/runtime/vulkan_device.cc index f1d3dd2c626e..40608bc3c86b 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/backend/vulkan/runtime/vulkan_device.cc @@ -24,7 +24,7 @@ #include #include -#include "../../support/utils.h" +#include "../../../support/utils.h" #include "vulkan_common.h" #include "vulkan_device.h" #include "vulkan_device_api.h" diff --git a/src/runtime/vulkan/vulkan_device.h b/src/backend/vulkan/runtime/vulkan_device.h similarity index 100% rename from src/runtime/vulkan/vulkan_device.h rename to src/backend/vulkan/runtime/vulkan_device.h diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/backend/vulkan/runtime/vulkan_device_api.cc similarity index 100% rename from src/runtime/vulkan/vulkan_device_api.cc rename to src/backend/vulkan/runtime/vulkan_device_api.cc diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/backend/vulkan/runtime/vulkan_device_api.h similarity index 99% rename from src/runtime/vulkan/vulkan_device_api.h rename to src/backend/vulkan/runtime/vulkan_device_api.h index c39d5754d8cd..e41ea526b3dc 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/backend/vulkan/runtime/vulkan_device_api.h @@ -26,7 +26,7 @@ #include #include -#include "../workspace_pool.h" +#include "../../../runtime/workspace_pool.h" #include "thread_map.h" #include "vulkan/vulkan_core.h" #include "vulkan_device.h" diff --git a/src/runtime/vulkan/vulkan_instance.cc b/src/backend/vulkan/runtime/vulkan_instance.cc similarity index 99% rename from src/runtime/vulkan/vulkan_instance.cc rename to src/backend/vulkan/runtime/vulkan_instance.cc index 92ee82fe1f8a..0b375c7118f9 100644 --- a/src/runtime/vulkan/vulkan_instance.cc +++ b/src/backend/vulkan/runtime/vulkan_instance.cc @@ -24,7 +24,7 @@ #include #include -#include "../../support/utils.h" +#include "../../../support/utils.h" #include "vulkan_common.h" namespace tvm { diff --git a/src/runtime/vulkan/vulkan_instance.h b/src/backend/vulkan/runtime/vulkan_instance.h similarity index 100% rename from src/runtime/vulkan/vulkan_instance.h rename to src/backend/vulkan/runtime/vulkan_instance.h diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/backend/vulkan/runtime/vulkan_module.cc similarity index 92% rename from src/runtime/vulkan/vulkan_module.cc rename to src/backend/vulkan/runtime/vulkan_module.cc index 78de679cb9ca..af44396402a1 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/backend/vulkan/runtime/vulkan_module.cc @@ -21,7 +21,7 @@ * \file vulkan_module.cc * \brief Plugin-only Vulkan runtime module. Built only when USE_VULKAN=ON. * No exported header — codegen-side construction goes through - * src/target/vulkan/vulkan_fallback_module.h:VulkanModuleCreateWithFallback, + * src/backend/vulkan/codegen/vulkan_fallback_module.h:VulkanModuleCreateWithFallback, * which dispatches to "ffi.Module.create.vulkan" registered below * when this file is linked into the build. */ @@ -32,7 +32,7 @@ #include #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" #include "spirv_shader.h" #include "vulkan_wrapped_func.h" @@ -43,7 +43,7 @@ namespace vulkan { /*! * \brief Deserialize a SPIRVShader from ffi::Bytes. * Format: flag (uint32_t) followed by data (vector) — matches - * the SPIRVShader::Save format in src/runtime/vulkan/spirv_shader.h. + * the SPIRVShader::Save format in src/backend/vulkan/runtime/spirv_shader.h. */ static SPIRVShader DeserializeSPIRVShader(const ffi::Bytes& bytes) { support::BytesInStream stream(bytes); @@ -83,7 +83,7 @@ static ffi::Module VulkanModuleLoadFromBytes(const ffi::Bytes& bytes) { TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; // Registry: "ffi.Module.create.vulkan" — codegen-time Vulkan module factory. - // Used by src/target/vulkan/vulkan_fallback_module.h:VulkanModuleCreateWithFallback. + // Used by src/backend/vulkan/codegen/vulkan_fallback_module.h:VulkanModuleCreateWithFallback. // Registry: "ffi.Module.load_from_bytes.vulkan" — disk loader. Only this // (real) module registers a loader; the fallback is codegen-only. refl::GlobalDef() diff --git a/src/runtime/vulkan/vulkan_stream.cc b/src/backend/vulkan/runtime/vulkan_stream.cc similarity index 99% rename from src/runtime/vulkan/vulkan_stream.cc rename to src/backend/vulkan/runtime/vulkan_stream.cc index 49ed530a6102..96d773c06d13 100644 --- a/src/runtime/vulkan/vulkan_stream.cc +++ b/src/backend/vulkan/runtime/vulkan_stream.cc @@ -19,7 +19,7 @@ #include "vulkan_stream.h" -#include "../../support/utils.h" +#include "../../../support/utils.h" #include "vulkan_device.h" namespace tvm { diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/backend/vulkan/runtime/vulkan_stream.h similarity index 100% rename from src/runtime/vulkan/vulkan_stream.h rename to src/backend/vulkan/runtime/vulkan_stream.h diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/backend/vulkan/runtime/vulkan_wrapped_func.cc similarity index 99% rename from src/runtime/vulkan/vulkan_wrapped_func.cc rename to src/backend/vulkan/runtime/vulkan_wrapped_func.cc index f6e17522bffb..d352c01211bd 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/backend/vulkan/runtime/vulkan_wrapped_func.cc @@ -24,7 +24,7 @@ #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" #include "vulkan_device_api.h" namespace tvm { @@ -408,7 +408,7 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, ffi::Bytes VulkanModuleNode::SaveToBytes() const { // NOTE: serialization format MUST remain byte-identical to // target::VulkanFallbackModuleNode::SaveToBytes in - // src/target/vulkan/vulkan_fallback_module.cc. This file is the + // src/backend/vulkan/codegen/vulkan_fallback_module.cc. This file is the // source of truth; the fallback follows. // 3 fields only — the source map is in-memory inspection material and // is NEVER serialized (matches upstream behavior for all backends). diff --git a/src/runtime/vulkan/vulkan_wrapped_func.h b/src/backend/vulkan/runtime/vulkan_wrapped_func.h similarity index 97% rename from src/runtime/vulkan/vulkan_wrapped_func.h rename to src/backend/vulkan/runtime/vulkan_wrapped_func.h index 6e023ad0a956..8eabfd5c3cbf 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.h +++ b/src/backend/vulkan/runtime/vulkan_wrapped_func.h @@ -27,9 +27,9 @@ #include #include -#include "../metadata.h" -#include "../pack_args.h" -#include "../thread_storage_scope.h" +#include "../../../runtime/metadata.h" +#include "../../../runtime/pack_args.h" +#include "../../../runtime/thread_storage_scope.h" #include "spirv_shader.h" #include "vulkan/vulkan_core.h" #include "vulkan_common.h" diff --git a/src/target/webgpu/codegen_webgpu.cc b/src/backend/webgpu/codegen/codegen_webgpu.cc similarity index 98% rename from src/target/webgpu/codegen_webgpu.cc rename to src/backend/webgpu/codegen/codegen_webgpu.cc index 48e4cc87b60e..08c75ed8404b 100644 --- a/src/target/webgpu/codegen_webgpu.cc +++ b/src/backend/webgpu/codegen/codegen_webgpu.cc @@ -36,12 +36,12 @@ #include #include -#include "../../arith/pattern_match.h" -#include "../../runtime/file_utils.h" -#include "../../runtime/metadata.h" -#include "../../runtime/thread_storage_scope.h" -#include "../../support/bytes_io.h" -#include "../build_common.h" +#include "../../../arith/pattern_match.h" +#include "../../../runtime/file_utils.h" +#include "../../../runtime/metadata.h" +#include "../../../runtime/thread_storage_scope.h" +#include "../../../support/bytes_io.h" +#include "../../../target/build_common.h" #include "webgpu_fallback_module.h" namespace tvm { @@ -741,7 +741,7 @@ void CodeGenWebGPU::VisitStmt_(const ContinueNode* op) { //------------------------------------------------- // // The "C++ side" canonical WebGPU module is `WebGPUFallbackModuleNode` in -// src/target/webgpu/webgpu_fallback_module.{h,cc} — there is no native +// src/backend/webgpu/codegen/webgpu_fallback_module.{h,cc} — there is no native // WebGPU runtime in the C++ tree (the real receiver is the wasm runtime // in web/emcc/webgpu_runtime.cc). ffi::Module BuildWebGPU(IRModule mod, Target target) { @@ -783,7 +783,11 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) { std::move(fmap), std::move(source)); } -TVM_FFI_STATIC_INIT_BLOCK() { +void RegisterWebGPUCodegen() { + static bool registered = false; + if (registered) return; + registered = true; + namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("target.build.webgpu", [](IRModule mod, Target target) { return BuildWebGPU(mod, target); }); diff --git a/src/target/webgpu/codegen_webgpu.h b/src/backend/webgpu/codegen/codegen_webgpu.h similarity index 98% rename from src/target/webgpu/codegen_webgpu.h rename to src/backend/webgpu/codegen/codegen_webgpu.h index 061d631e5dc9..4c873ac3db18 100644 --- a/src/target/webgpu/codegen_webgpu.h +++ b/src/backend/webgpu/codegen/codegen_webgpu.h @@ -31,7 +31,7 @@ #include -#include "../source/codegen_c.h" +#include "../../../target/source/codegen_c.h" namespace tvm { namespace codegen { diff --git a/src/target/webgpu/intrin_rule_webgpu.cc b/src/backend/webgpu/codegen/intrin_rule_webgpu.cc similarity index 92% rename from src/target/webgpu/intrin_rule_webgpu.cc rename to src/backend/webgpu/codegen/intrin_rule_webgpu.cc index 14dfd7959146..03815d0a667b 100644 --- a/src/target/webgpu/intrin_rule_webgpu.cc +++ b/src/backend/webgpu/codegen/intrin_rule_webgpu.cc @@ -24,7 +24,7 @@ #include #include -#include "../intrin_rule.h" +#include "../../../target/intrin_rule.h" namespace tvm { namespace codegen { @@ -36,12 +36,16 @@ using tirx::FLowerIntrinsic; struct WebGPUWarpIntrinsic { const Op operator()(DataType t, const Op& orig_op) const { if (orig_op.same_as(builtin::tvm_warp_shuffle())) { - return Op::Get("tirx.webgpu.subgroup_shuffle"); + static const Op& webgpu_subgroup_shuffle_op = Op::Get("tirx.webgpu.subgroup_shuffle"); + return webgpu_subgroup_shuffle_op; } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { - return Op::Get("tirx.webgpu.subgroup_shuffle_up"); + static const Op& webgpu_subgroup_shuffle_up_op = Op::Get("tirx.webgpu.subgroup_shuffle_up"); + return webgpu_subgroup_shuffle_up_op; } else { TVM_FFI_ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); - return Op::Get("tirx.webgpu.subgroup_shuffle_down"); + static const Op& webgpu_subgroup_shuffle_down_op = + Op::Get("tirx.webgpu.subgroup_shuffle_down"); + return webgpu_subgroup_shuffle_down_op; } } }; @@ -56,6 +60,12 @@ static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) { return Call(call->dtype, T()(call->dtype, Downcast(call->op)), webgpu_args); } +void RegisterWebGPUIntrinRules() { + static bool registered = false; + if (registered) return; + registered = true; + + // clang-format off // See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions struct ReturnAbs { @@ -194,6 +204,8 @@ TVM_REGISTER_OP("tirx.webgpu.subgroup_shuffle_down") ffi::String("webgpu.subgroup_shuffle_down"), 10) .set_attr("TGlobalSymbol", "subgroupShuffleDown") .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); + // clang-format on +} } // namespace intrin } // namespace codegen diff --git a/src/backend/webgpu/codegen/register.cc b/src/backend/webgpu/codegen/register.cc new file mode 100644 index 000000000000..49e90245023c --- /dev/null +++ b/src/backend/webgpu/codegen/register.cc @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file register.cc + * \brief WebGPU compiler backend static registration. + */ +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace backend { +namespace webgpu { + +ffi::Map UpdateWebGPUAttrs(ffi::Map target) { + bool subgroups = false; + if (target.count("supports_subgroups")) { + subgroups = Downcast(target.at("supports_subgroups"))->value != 0; + } + + if (target.count("thread_warp_size")) { + int64_t thread_warp_size = Downcast(target.at("thread_warp_size"))->value; + TVM_FFI_ICHECK(subgroups || thread_warp_size <= 1) + << "WebGPU target with thread_warp_size=" << thread_warp_size + << " requires supports_subgroups=true"; + } + + if (subgroups) { + target.Set("thread_warp_size", int64_t(32)); + } + return target; +} + +void RegisterTargetKind() { + namespace refl = tvm::ffi::reflection; + + TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) + .add_attr_option("max_num_threads", refl::DefaultValue(256)) + .add_attr_option("supports_subgroups", refl::DefaultValue(false)) + .add_attr_option("thread_warp_size", refl::DefaultValue(1)) + .set_target_canonicalizer(UpdateWebGPUAttrs) + .set_default_keys({"webgpu", "gpu"}); +} + +} // namespace webgpu +} // namespace backend + +namespace codegen { +void RegisterWebGPUCodegen(); +namespace intrin { +void RegisterWebGPUIntrinRules(); +} // namespace intrin +} // namespace codegen +} // namespace tvm + +TVM_FFI_STATIC_INIT_BLOCK() { + tvm::backend::webgpu::RegisterTargetKind(); + tvm::codegen::intrin::RegisterWebGPUIntrinRules(); + tvm::codegen::RegisterWebGPUCodegen(); +} diff --git a/src/target/webgpu/webgpu_fallback_module.cc b/src/backend/webgpu/codegen/webgpu_fallback_module.cc similarity index 99% rename from src/target/webgpu/webgpu_fallback_module.cc rename to src/backend/webgpu/codegen/webgpu_fallback_module.cc index 574c210f77e0..7646913745b5 100644 --- a/src/target/webgpu/webgpu_fallback_module.cc +++ b/src/backend/webgpu/codegen/webgpu_fallback_module.cc @@ -40,7 +40,7 @@ #include #include -#include "../../support/bytes_io.h" +#include "../../../support/bytes_io.h" namespace tvm { namespace target { diff --git a/src/target/webgpu/webgpu_fallback_module.h b/src/backend/webgpu/codegen/webgpu_fallback_module.h similarity index 98% rename from src/target/webgpu/webgpu_fallback_module.h rename to src/backend/webgpu/codegen/webgpu_fallback_module.h index 58e846c9eca5..0f9c7bf7ea5f 100644 --- a/src/target/webgpu/webgpu_fallback_module.h +++ b/src/backend/webgpu/codegen/webgpu_fallback_module.h @@ -38,8 +38,8 @@ #include -#include "../../runtime/metadata.h" -#include "../../support/env.h" +#include "../../../runtime/metadata.h" +#include "../../../support/env.h" namespace tvm { namespace target { diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index f0b27643b8e1..ccce50274311 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -78,7 +78,7 @@ #include #include -#include "../../runtime/opencl/texture.h" +#include "../../backend/opencl/runtime/texture.h" #include "utils.h" namespace tvm { diff --git a/src/runtime/extra/contrib/clml/clml_runtime.h b/src/runtime/extra/contrib/clml/clml_runtime.h index 5de3fedaaf7a..27286d5b226d 100644 --- a/src/runtime/extra/contrib/clml/clml_runtime.h +++ b/src/runtime/extra/contrib/clml/clml_runtime.h @@ -42,8 +42,8 @@ #include #include +#include "../../../../backend/opencl/runtime/opencl_common.h" #include "../../../file_utils.h" -#include "../../../opencl/opencl_common.h" #include "../../../thread_storage_scope.h" #include "../json/json_node.h" #include "../json/json_runtime.h" diff --git a/src/runtime/extra/contrib/cublas/cublas_json_runtime.cc b/src/runtime/extra/contrib/cublas/cublas_json_runtime.cc index f63d8575e05b..6520753e117b 100644 --- a/src/runtime/extra/contrib/cublas/cublas_json_runtime.cc +++ b/src/runtime/extra/contrib/cublas/cublas_json_runtime.cc @@ -32,7 +32,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" #include "../json/json_node.h" #include "../json/json_runtime.h" #include "cublas_utils.h" diff --git a/src/runtime/extra/contrib/cublas/cublas_utils.cc b/src/runtime/extra/contrib/cublas/cublas_utils.cc index 5050f20998fa..b8c239f13a85 100644 --- a/src/runtime/extra/contrib/cublas/cublas_utils.cc +++ b/src/runtime/extra/contrib/cublas/cublas_utils.cc @@ -25,7 +25,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" namespace tvm { namespace contrib { diff --git a/src/runtime/extra/contrib/cudnn/cudnn_frontend/attention.cc b/src/runtime/extra/contrib/cudnn/cudnn_frontend/attention.cc index 32f33fc739c1..f31a38bd472d 100644 --- a/src/runtime/extra/contrib/cudnn/cudnn_frontend/attention.cc +++ b/src/runtime/extra/contrib/cudnn/cudnn_frontend/attention.cc @@ -27,7 +27,7 @@ #include #include -#include "../../../../cuda/cuda_common.h" +#include "../../../../../backend/cuda/runtime/cuda_common.h" #include "../cudnn_utils.h" namespace tvm { diff --git a/src/runtime/extra/contrib/cudnn/cudnn_utils.h b/src/runtime/extra/contrib/cudnn/cudnn_utils.h index 65ee263fdc4c..81849c0f0c31 100644 --- a/src/runtime/extra/contrib/cudnn/cudnn_utils.h +++ b/src/runtime/extra/contrib/cudnn/cudnn_utils.h @@ -30,7 +30,7 @@ #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" namespace tvm { namespace contrib { diff --git a/src/runtime/extra/contrib/curand/curand.cc b/src/runtime/extra/contrib/curand/curand.cc index 5dd4f2b4aa91..5f8ff0758373 100644 --- a/src/runtime/extra/contrib/curand/curand.cc +++ b/src/runtime/extra/contrib/curand/curand.cc @@ -21,7 +21,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" #include "./helper_cuda_kernels.h" namespace tvm { diff --git a/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh b/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh index b73ab99d07ad..c30b34d0f41a 100644 --- a/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh +++ b/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm100.cuh @@ -25,7 +25,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" // clang-format off #include "cutlass/cutlass.h" diff --git a/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh b/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh index 5ab825a63995..dcb26b9071d5 100644 --- a/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh +++ b/src/runtime/extra/contrib/cutlass/fp16_group_gemm_runner_sm90.cuh @@ -25,7 +25,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" // clang-format off #include "cutlass/cutlass.h" diff --git a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh index 3f7f89ca6df7..9c1962ad98d7 100644 --- a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh +++ b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm100.cuh @@ -24,7 +24,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" // clang-format off #include "cutlass/cutlass.h" diff --git a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh index ee47a5f69283..1955e47759f7 100644 --- a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh +++ b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_gemm_runner_sm90.cuh @@ -24,7 +24,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" // clang-format off #include "cutlass/cutlass.h" diff --git a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh index 0f7ffee5defc..a516f31d3ce9 100644 --- a/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh +++ b/src/runtime/extra/contrib/cutlass/fp8_groupwise_scaled_group_gemm_runner_sm100.cuh @@ -23,7 +23,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" // clang-format off #include "cutlass/cutlass.h" diff --git a/src/runtime/extra/contrib/cutlass/gemm_runner.cuh b/src/runtime/extra/contrib/cutlass/gemm_runner.cuh index 5d876291f00e..58e1c9fbd006 100644 --- a/src/runtime/extra/contrib/cutlass/gemm_runner.cuh +++ b/src/runtime/extra/contrib/cutlass/gemm_runner.cuh @@ -25,7 +25,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" // clang-format off #include "cutlass/cutlass.h" diff --git a/src/runtime/extra/contrib/hipblas/hipblas_json_runtime.cc b/src/runtime/extra/contrib/hipblas/hipblas_json_runtime.cc index f352e184f426..1ab61fba41fb 100644 --- a/src/runtime/extra/contrib/hipblas/hipblas_json_runtime.cc +++ b/src/runtime/extra/contrib/hipblas/hipblas_json_runtime.cc @@ -32,7 +32,7 @@ #include #include -#include "../../../rocm/rocm_common.h" +#include "../../../../backend/rocm/runtime/rocm_common.h" #include "../json/json_node.h" #include "../json/json_runtime.h" #include "hipblas_utils.h" diff --git a/src/runtime/extra/contrib/hipblas/hipblas_utils.cc b/src/runtime/extra/contrib/hipblas/hipblas_utils.cc index 2ea815c676e9..1070891a4fe1 100644 --- a/src/runtime/extra/contrib/hipblas/hipblas_utils.cc +++ b/src/runtime/extra/contrib/hipblas/hipblas_utils.cc @@ -25,7 +25,7 @@ #include #include -#include "../../../rocm/rocm_common.h" +#include "../../../../backend/rocm/runtime/rocm_common.h" namespace tvm { namespace contrib { diff --git a/src/runtime/extra/contrib/nvshmem/dist_gemm.cu b/src/runtime/extra/contrib/nvshmem/dist_gemm.cu index 860cba6b5244..d975e140384d 100644 --- a/src/runtime/extra/contrib/nvshmem/dist_gemm.cu +++ b/src/runtime/extra/contrib/nvshmem/dist_gemm.cu @@ -23,7 +23,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" namespace tvm { namespace runtime { diff --git a/src/runtime/extra/contrib/nvshmem/init.cc b/src/runtime/extra/contrib/nvshmem/init.cc index b25390a613f6..698f3b68024f 100644 --- a/src/runtime/extra/contrib/nvshmem/init.cc +++ b/src/runtime/extra/contrib/nvshmem/init.cc @@ -26,7 +26,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" namespace tvm { namespace runtime { diff --git a/src/runtime/extra/contrib/nvshmem/memory_allocator.cc b/src/runtime/extra/contrib/nvshmem/memory_allocator.cc index e1806e4c4b95..ee9afd0eca3d 100644 --- a/src/runtime/extra/contrib/nvshmem/memory_allocator.cc +++ b/src/runtime/extra/contrib/nvshmem/memory_allocator.cc @@ -24,7 +24,7 @@ #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" #include "../../../memory/pooled_allocator.h" #include "../../disco/utils.h" diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h b/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h index d9e8df9d38e1..4d92afb234d1 100755 --- a/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h +++ b/src/runtime/extra/contrib/tensorrt/tensorrt_calibrator.h @@ -26,7 +26,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" #include "NvInfer.h" namespace tvm { diff --git a/src/runtime/extra/contrib/thrust/thrust.cu b/src/runtime/extra/contrib/thrust/thrust.cu index 7c3930f0c81b..b3613232311a 100644 --- a/src/runtime/extra/contrib/thrust/thrust.cu +++ b/src/runtime/extra/contrib/thrust/thrust.cu @@ -42,7 +42,7 @@ #include #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" namespace tvm { namespace contrib { diff --git a/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc b/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc index c83cba280ab7..109ce6ed37a6 100644 --- a/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc +++ b/src/runtime/extra/disco/cuda_ipc/cuda_ipc_memory.cc @@ -24,7 +24,7 @@ #include #include "../../../../../3rdparty/tensorrt_llm/custom_allreduce_kernels.h" -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" #include "../../../memory/pooled_allocator.h" #include "../nccl/nccl_context.h" diff --git a/src/runtime/extra/disco/nccl/nccl_context.h b/src/runtime/extra/disco/nccl/nccl_context.h index 1434a7c4a2e1..3747fec43487 100644 --- a/src/runtime/extra/disco/nccl/nccl_context.h +++ b/src/runtime/extra/disco/nccl/nccl_context.h @@ -36,11 +36,11 @@ #if TVM_NCCL_RCCL_SWITCH == 0 #include -#include "../../../cuda/cuda_common.h" +#include "../../../../backend/cuda/runtime/cuda_common.h" #else #include -#include "../../../rocm/rocm_common.h" +#include "../../../../backend/rocm/runtime/rocm_common.h" #endif namespace tvm { diff --git a/src/runtime/vm/attn_utils.h b/src/runtime/vm/attn_utils.h index 2ee86bb075b7..7a2c93414c0f 100644 --- a/src/runtime/vm/attn_utils.h +++ b/src/runtime/vm/attn_utils.h @@ -36,7 +36,7 @@ #include #if defined(OPENCL_ENABLE_HOST_PTR) -#include "../opencl/opencl_common.h" +#include "../../backend/opencl/runtime/opencl_common.h" #endif namespace tvm { diff --git a/src/runtime/vm/hexagon/builtin.cc b/src/runtime/vm/hexagon/builtin.cc index c7429975647f..325286f000de 100644 --- a/src/runtime/vm/hexagon/builtin.cc +++ b/src/runtime/vm/hexagon/builtin.cc @@ -27,7 +27,7 @@ #include #include -#include "../../hexagon/hexagon_device_api.h" +#include "../../../backend/hexagon/runtime/hexagon_device_api.h" namespace tvm { namespace runtime { namespace vm { diff --git a/src/s_tir/backend/adreno/inject_texture_alloc.cc b/src/s_tir/backend/adreno/inject_texture_alloc.cc index ef0fe72acd28..9b2b627dd49a 100644 --- a/src/s_tir/backend/adreno/inject_texture_alloc.cc +++ b/src/s_tir/backend/adreno/inject_texture_alloc.cc @@ -27,7 +27,7 @@ #include #include "../../../arith/ir_mutator_with_analyzer.h" -#include "../../../runtime/opencl/texture.h" +#include "../../../backend/opencl/runtime/texture.h" #include "../../../tirx/transform/ir_utils.h" namespace tvm { diff --git a/src/s_tir/backend/adreno/texture_flatten.cc b/src/s_tir/backend/adreno/texture_flatten.cc index 0ef074789652..91cdc0b6e4bf 100644 --- a/src/s_tir/backend/adreno/texture_flatten.cc +++ b/src/s_tir/backend/adreno/texture_flatten.cc @@ -33,7 +33,7 @@ #include #include "../../../arith/ir_visitor_with_analyzer.h" -#include "../../../runtime/opencl/texture.h" +#include "../../../backend/opencl/runtime/texture.h" #include "../../../runtime/thread_storage_scope.h" namespace tvm { diff --git a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc index ac85b92dc63a..b18fa3545ff4 100644 --- a/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc +++ b/src/s_tir/meta_schedule/postproc/rewrite_cooperative_fetch.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include "../utils.h" @@ -89,8 +90,6 @@ bool ParseWarpExecutionAnn(const Schedule& sch, const Instruction& inst) { size_t GetMaxUsedDtypeBytes(SBlock block) { size_t max_bytes = 1; - static auto q_multiply_shift_per_axis = Op::Get("tirx.q_multiply_shift_per_axis"); - static auto q_multiply_shift = Op::Get("tirx.q_multiply_shift"); tirx::PostOrderVisit(block->body, [&](const ffi::ObjectRef& obj) { if (const auto* store = obj.as()) { @@ -98,7 +97,9 @@ size_t GetMaxUsedDtypeBytes(SBlock block) { } else if (const auto* load = obj.as()) { max_bytes = std::max(max_bytes, static_cast(load->dtype.bytes())); } else if (const auto* call = obj.as()) { - if (call->op.same_as(q_multiply_shift_per_axis) || call->op.same_as(q_multiply_shift)) { + static const Op& q_multiply_shift_per_axis_op = Op::Get("tirx.q_multiply_shift_per_axis"); + static const Op& q_multiply_shift_op = Op::Get("tirx.q_multiply_shift"); + if (call->op.same_as(q_multiply_shift_per_axis_op) || call->op.same_as(q_multiply_shift_op)) { // q_multiply_shift uses 64 bit multiply max_bytes = std::max(max_bytes, 8); } diff --git a/src/s_tir/schedule/analysis/analysis.cc b/src/s_tir/schedule/analysis/analysis.cc index 52e5cfe287d1..a9fca36c50fb 100644 --- a/src/s_tir/schedule/analysis/analysis.cc +++ b/src/s_tir/schedule/analysis/analysis.cc @@ -18,6 +18,7 @@ */ #include #include +#include #include #include "../ir_comparator.h" @@ -1357,8 +1358,8 @@ bool HasIfThenElse(const Stmt& stmt) { has_branch = true; } else if (const auto* call = obj.as()) { // Case 3: Call the `if_then_else` operator - static const Op& op_if_then_else = Op::Get("tirx.if_then_else"); - if (call->op.same_as(op_if_then_else)) { + static const Op& if_then_else_op = Op::Get("tirx.if_then_else"); + if (call->op.same_as(if_then_else_op)) { has_branch = true; } } diff --git a/src/s_tir/transform/inject_permuted_layout.cc b/src/s_tir/transform/inject_permuted_layout.cc index 74e843a6e5d0..8ef5051ae0d5 100644 --- a/src/s_tir/transform/inject_permuted_layout.cc +++ b/src/s_tir/transform/inject_permuted_layout.cc @@ -23,6 +23,7 @@ */ #include #include +#include #include #include #include @@ -246,17 +247,6 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { return access_ptr_call; } - // Device intrinsics are registered under both a flat name (the builtin Op) - // and a canonical dotted name (emitted by TVMScript and the tensor - // intrinsics), so compare against both. - static bool IsOp(const Call& call, const Op& compat_op, const char* canonical_name) { - if (call->op.same_as(compat_op)) { - return true; - } - const auto* op_node = call->op.as(); - return op_node != nullptr && op_node->name == canonical_name; - } - PrimExpr VisitExpr_(const CallNode* op) final { // Rewrite from/to shared or shared.dyn to/from local auto call = Downcast(IRMutatorWithAnalyzer::VisitExpr_(op)); @@ -265,11 +255,13 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { return call; } - // Only the legacy intrinsic forms fold the shared memory access into a - // tvm_access_ptr + offset, which must be rewritten here. The non-legacy - // forms address shared memory through BufferLoad (e.g. via address_of), - // which is already handled by the BufferLoad visitor above. - if (IsOp(call, builtin::ptx_ldmatrix_legacy(), "tirx.ptx.ldmatrix_legacy")) { + static const Op& ptx_ldmatrix_op = Op::Get("tirx.ptx.ldmatrix_legacy"); + static const Op& mma_store_op = Op::Get("tirx.mma_store_legacy"); + if (!call->op.same_as(ptx_ldmatrix_op) && !call->op.same_as(mma_store_op)) { + return call; + } + + if (call->op.same_as(ptx_ldmatrix_op)) { // form: T.ptx.ldmatrix_legacy(..., smem_ptr, smem_offset) // smem_ptr: T.tvm_access_ptr(ptype, data, offset, extent, rw_mask) auto access_ptr = call->args[5]; @@ -279,7 +271,7 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { new_call->args.Set(5, new_access_ptr); new_call->args.Set(6, IntImm(smem_offset->dtype, 0)); return call; - } else if (IsOp(call, builtin::mma_store_legacy(), "tirx.mma_store_legacy")) { + } else if (call->op.same_as(mma_store_op)) { // TODO(yixin): mma_store is not fully tested yet // because we will directly store result to Buffer instead of calling mma_store now auto access_ptr = call->args[2]; @@ -287,8 +279,9 @@ class PermutedLayoutInjector : private IRMutatorWithAnalyzer { auto new_call = call.CopyOnWrite(); new_call->args.Set(2, new_access_ptr); return call; + } else { + TVM_FFI_THROW(InternalError) << "Invalid call node: " << call; } - return call; } static constexpr size_t VECTORIZE_FACTOR = 8; diff --git a/src/s_tir/transform/inject_ptx_async_copy.cc b/src/s_tir/transform/inject_ptx_async_copy.cc index 3c84a021b551..3a0f113499f8 100644 --- a/src/s_tir/transform/inject_ptx_async_copy.cc +++ b/src/s_tir/transform/inject_ptx_async_copy.cc @@ -22,6 +22,7 @@ * \file inject_ptx_async_copy.cc */ #include +#include #include #include #include @@ -89,7 +90,8 @@ class PTXAsyncCopyInjector : public StmtMutator { if (predicated) { args.push_back(predicate_value); } - return Evaluate(Call(store->buffer->dtype, tvm::tirx::builtin::ptx_cp_async(), args)); + static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async"); + return Evaluate(Call(store->buffer->dtype, ptx_cp_async_op, args)); } // Predicated load don't support vectorized indexing. @@ -117,7 +119,8 @@ class PTXAsyncCopyInjector : public StmtMutator { return PrimExpr(); }(); if (src_offset.defined() && dst_offset.defined()) { - return Evaluate(Call(store->buffer->dtype, tvm::tirx::builtin::ptx_cp_async(), + static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async"); + return Evaluate(Call(store->buffer->dtype, ptx_cp_async_op, {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes)})); } @@ -146,8 +149,9 @@ class PTXAsyncCopyInjector : public StmtMutator { }(); if (src_offset.defined() && dst_offset.defined()) { + static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async"); return Evaluate( - Call(store->buffer->dtype, tvm::tirx::builtin::ptx_cp_async(), + Call(store->buffer->dtype, ptx_cp_async_op, {store->buffer->data, mul(dst_offset, PrimExpr(index_factor)), load->buffer->data, src_offset, PrimExpr(bytes), predicate_value})); } diff --git a/src/s_tir/transform/inject_ptx_ldg32.cc b/src/s_tir/transform/inject_ptx_ldg32.cc index f02b253b29e8..9dfee105544b 100644 --- a/src/s_tir/transform/inject_ptx_ldg32.cc +++ b/src/s_tir/transform/inject_ptx_ldg32.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -97,7 +98,8 @@ class PTXRewriter : public StmtMutator { new_predicate = BufferLoad(predicate_buffer, {IntImm(DataType::Int(32), 0)}); new_indice = BufferLoad(addr_buffer, {IntImm(DataType::Int(32), 1)}); BufferStore value_store(store->buffer, imm_value, {new_indice}); - Evaluate ptx_load(Call(store->buffer->dtype, tvm::tirx::builtin::ptx_ldg32(), + static const Op& ptx_ldg32_op = Op::Get("tirx.ptx.ldg32"); + Evaluate ptx_load(Call(store->buffer->dtype, ptx_ldg32_op, {store->buffer->data, new_predicate, new_lhs, new_indice})); ffi::Array tmp_seq = {addr_store, local_addr_store, predicate_store, value_store, ptx_load}; diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index f5190bddbce4..8064f7b16475 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -24,11 +24,11 @@ #include #include #include +#include #include #include #include #include -#include #include #include @@ -43,14 +43,6 @@ using namespace tvm::tirx; namespace software_pipeline { -static bool IsOp(const Call& call, const Op& compat_op, const char* canonical_name) { - if (call->op.same_as(compat_op)) { - return true; - } - const auto* op_node = call->op.as(); - return op_node != nullptr && op_node->name == canonical_name; -} - /*! * \brief Create a block and infer the access region with the given body. * @@ -115,12 +107,12 @@ class PipelineOpaqueAccessRewriter { PrimExpr Rewrite(const Call& call) { // Intrinsic calls should be handled explicitly here as they are opaque accesses to // buffer. - static const auto& load_matrix_sync = builtin::tvm_load_matrix_sync(); - static const auto& store_matrix_sync = builtin::tvm_store_matrix_sync(); - static const auto& mma_sync = builtin::tvm_mma_sync(); static const auto& access_ptr = builtin::tvm_access_ptr(); - static const auto& ptx_ldmatrix_legacy = builtin::ptx_ldmatrix_legacy(); - static const auto& ptx_mma_legacy = builtin::ptx_mma_legacy(); + static const Op& load_matrix_sync = Op::Get("tirx.tvm_load_matrix_sync"); + static const Op& store_matrix_sync = Op::Get("tirx.tvm_store_matrix_sync"); + static const Op& mma_sync = Op::Get("tirx.tvm_mma_sync"); + static const Op& ptx_ldmatrix_legacy = Op::Get("tirx.ptx.ldmatrix_legacy"); + static const Op& ptx_mma_legacy = Op::Get("tirx.ptx.mma_legacy"); if (call->op.same_as(load_matrix_sync) || call->op.same_as(store_matrix_sync)) { const Buffer& buffer = buffer_data_to_buffer_.at(Downcast(call->args[0])); auto it = buffer_remap_.find(buffer); @@ -145,9 +137,9 @@ class PipelineOpaqueAccessRewriter { return Call(call->dtype, call->op, new_args, call->attrs, call->span); } else if (call->op.same_as(access_ptr)) { return RewriteBufferAccess(call, {1}); - } else if (IsOp(call, ptx_mma_legacy, "tirx.ptx.mma_legacy")) { + } else if (call->op.same_as(ptx_mma_legacy)) { return RewriteBufferAccess(call, {6, 8, 10}); - } else if (IsOp(call, ptx_ldmatrix_legacy, "tirx.ptx.ldmatrix_legacy")) { + } else if (call->op.same_as(ptx_ldmatrix_legacy)) { return RewriteBufferAccess(call, {3}); } return call; diff --git a/src/s_tir/transform/memhammer_lower_auto_copy.cc b/src/s_tir/transform/memhammer_lower_auto_copy.cc index af805d64f7eb..857b94266196 100644 --- a/src/s_tir/transform/memhammer_lower_auto_copy.cc +++ b/src/s_tir/transform/memhammer_lower_auto_copy.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -568,8 +569,9 @@ class AutoPadder { void VisitStmt_(const SBlockNode* op) final { if (const auto* eval = op->body.as()) { if (const auto* call = eval->value.as()) { - if (call->op == builtin::tvm_load_matrix_sync() || - call->op == builtin::tvm_store_matrix_sync()) { + static const Op& tvm_load_matrix_sync_op = Op::Get("tirx.tvm_load_matrix_sync"); + static const Op& tvm_store_matrix_sync_op = Op::Get("tirx.tvm_store_matrix_sync"); + if (call->op == tvm_load_matrix_sync_op || call->op == tvm_store_matrix_sync_op) { for (const MatchBufferRegion& r : op->match_buffers) { Buffer src_buffer = r->source->buffer; runtime::StorageScope scope = runtime::StorageScope::Create(src_buffer.scope()); diff --git a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc index 5a3b48521873..a8325c4aca18 100644 --- a/src/s_tir/transform/memhammer_tensorcore_rewrite.cc +++ b/src/s_tir/transform/memhammer_tensorcore_rewrite.cc @@ -18,6 +18,7 @@ */ #include +#include #include "./memhammer_rewrite_rule.h" @@ -148,6 +149,7 @@ Stmt RewriteWmmaLoad(Stmt stmt) { /*buffer_type=*/kDefault); ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + static const Op& tvm_load_matrix_sync_op = Op::Get("tirx.tvm_load_matrix_sync"); Stmt wmma_body = SBlockRealize( /*iter_values=*/{}, /*predicate=*/const_true(), @@ -159,7 +161,7 @@ Stmt RewriteWmmaLoad(Stmt stmt) { /*body=*/ Evaluate(Call( /*data=*/runtime::DataType::Handle(), - /*op=*/builtin::tvm_load_matrix_sync(), + /*op=*/tvm_load_matrix_sync_op, { /*0:*/ new_tgt_buffer->data, /*1:*/ 16, @@ -257,6 +259,7 @@ Stmt RewriteWmmaStore(Stmt stmt) { ffi::Array read_region = RelaxIndices(buf_load->indices, src_buffer->shape, var_dom); ffi::Array write_region = RelaxIndices(buf_store->indices, tgt_buffer->shape, var_dom); + static const Op& tvm_store_matrix_sync_op = Op::Get("tirx.tvm_store_matrix_sync"); Stmt wmma_body = SBlockRealize( /*iter_values=*/{}, // /*predicate=*/const_true(), @@ -266,7 +269,7 @@ Stmt RewriteWmmaStore(Stmt stmt) { /*name_hint=*/"wmma_store", Evaluate(Call( /*data=*/runtime::DataType::Handle(), - /*op=*/builtin::tvm_store_matrix_sync(), + /*op=*/tvm_store_matrix_sync_op, {/*0:*/ new_src_buffer->data, /*1:*/ 16, /*2:*/ 16, diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc b/src/s_tir/transform/merge_shared_memory_allocations.cc index d1417943c327..4b61c8994c97 100644 --- a/src/s_tir/transform/merge_shared_memory_allocations.cc +++ b/src/s_tir/transform/merge_shared_memory_allocations.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -486,6 +487,7 @@ class SharedMemoryRewriter : public StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) final { + static const Op& ptx_cp_async_op = Op::Get("tirx.ptx_cp_async"); if (op->op.same_as(builtin::tvm_access_ptr())) { TVM_FFI_ICHECK_EQ(op->args.size(), 5U); DataType dtype = op->args[0].dtype(); @@ -501,7 +503,7 @@ class SharedMemoryRewriter : public StmtExprMutator { return Call(op->dtype, op->op, {op->args[0], scope_stack_.back().merged_buf_var, extra_offset + offset, extent, op->args[4]}); - } else if (op->op.same_as(builtin::ptx_cp_async())) { + } else if (op->op.same_as(ptx_cp_async_op)) { TVM_FFI_ICHECK((op->args.size() == 5U) || (op->args.size() == 6U)); Var buffer = Downcast(op->args[0]); const auto* ptr_type = buffer->type_annotation.as(); diff --git a/src/s_tir/transform/tensorcore_infer_fragment.cc b/src/s_tir/transform/tensorcore_infer_fragment.cc index 7b9bcdf80443..860cdc2cc006 100644 --- a/src/s_tir/transform/tensorcore_infer_fragment.cc +++ b/src/s_tir/transform/tensorcore_infer_fragment.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -46,8 +47,10 @@ class FragmentGetter : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); - if (op->op.same_as(builtin::tvm_load_matrix_sync()) || - op->op.same_as(builtin::tvm_store_matrix_sync())) { + static const Op& tvm_load_matrix_sync_op = Op::Get("tirx.tvm_load_matrix_sync"); + static const Op& tvm_store_matrix_sync_op = Op::Get("tirx.tvm_store_matrix_sync"); + static const Op& tvm_fill_fragment_op = Op::Get("tirx.tvm_fill_fragment"); + if (op->op.same_as(tvm_load_matrix_sync_op) || op->op.same_as(tvm_store_matrix_sync_op)) { // Get shape and layout information from load and store intrinsic TVM_FFI_ICHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var = op->args[0].as(); @@ -82,7 +85,7 @@ class FragmentGetter : public StmtExprVisitor { } fragments[buffer_var] = info; } - } else if (op->op.same_as(builtin::tvm_fill_fragment())) { + } else if (op->op.same_as(tvm_fill_fragment_op)) { // Get shape information from fill intrinsic TVM_FFI_ICHECK_EQ(op->args.size(), 6U); const VarNode* buffer_var = op->args[0].as(); @@ -136,7 +139,9 @@ class FragmentChecker : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { StmtExprVisitor::VisitExpr_(op); // Check shape when calling tvm_mma_sync - if (op->op.same_as(builtin::tvm_mma_sync()) || op->op.same_as(builtin::tvm_bmma_sync())) { + static const Op& tvm_mma_sync_op = Op::Get("tirx.tvm_mma_sync"); + static const Op& tvm_bmma_sync_op = Op::Get("tirx.tvm_bmma_sync"); + if (op->op.same_as(tvm_mma_sync_op) || op->op.same_as(tvm_bmma_sync_op)) { TVM_FFI_ICHECK_EQ(op->args.size(), 8U); const VarNode* buffer_var_d = op->args[0].as(); const VarNode* buffer_var_a = op->args[2].as(); diff --git a/src/target/llvm/intrin_rule_nvptx.cc b/src/target/llvm/intrin_rule_nvptx.cc index 0707a9a78771..d8706a94b181 100644 --- a/src/target/llvm/intrin_rule_nvptx.cc +++ b/src/target/llvm/intrin_rule_nvptx.cc @@ -23,6 +23,7 @@ #ifdef TVM_LLVM_VERSION #include +#include #include #include #include @@ -71,7 +72,7 @@ TVM_REGISTER_OP("tirx.round") using namespace tirx; const CallNode* call = e.as(); TVM_FFI_ICHECK(call != nullptr); - auto nearbyint_op = Op::Get("tirx.nearbyint"); + static const Op& nearbyint_op = Op::Get("tirx.nearbyint"); auto new_call = Call(call->dtype, nearbyint_op, call->args); return DispatchPureExternLibDevice(new_call); }); diff --git a/src/target/tag.cc b/src/target/tag.cc index e0374e831194..74fa65b0e627 100644 --- a/src/target/tag.cc +++ b/src/target/tag.cc @@ -82,17 +82,4 @@ Target TargetTag::AddTag(ffi::String name, ffi::Map confi return Target(config); } -/********** Register Trainium target tags **********/ - -#define TVM_REGISTER_TAG_AWS_TRN1(Name, Cores) \ - TVM_REGISTER_TARGET_TAG(Name).set_config({{"kind", ffi::String("trn")}, \ - {"num-cores", Cores}, \ - {"partition_size", 128}, \ - {"max_sbuf_size_per_partition", 196608}, \ - {"max_psum_size_per_partition", 16384}}); - -TVM_REGISTER_TAG_AWS_TRN1("aws/trn1/trn1.2xlarge", 2); -TVM_REGISTER_TAG_AWS_TRN1("aws/trn1/trn1.32xlarge", 32); -#undef TVM_REGISTER_TAG_AWS_TRN1 - } // namespace tvm diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index cbad63fdaf18..68f1f2e3b7c8 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -101,198 +101,6 @@ ffi::Optional TargetKind::Get(const ffi::String& target_kind_name) { return reg->kind_; } -/********** Utility functions **********/ - -/*! - * \brief Extract a string from the string with the given prefix. - * For example, when `str` is "sm_20" and `prefix` is "sm_". - * This function first checks if `str` starts with `prefix`, - * then return the integer 20 after the `prefix` - * \param str The string to be extracted - * \param prefix The prefix to be checked - * \return A string, the extracted string. "" if the check fails - */ -std::string ExtractStringWithPrefix(const std::string& str, const std::string& prefix) { - if (str.find(prefix) != 0) return ""; - std::size_t pos = prefix.length(); - while (pos < str.length() && (std::isdigit(str[pos]) || std::isalpha(str[pos]))) { - ++pos; - } - return str.substr(prefix.length(), pos - prefix.length()); -} - -/*! - * \brief Using TVM DeviceAPI to detect the device flag - * \param device The device to be detected - * \param flag The device flag to be detected - * \param val The detected value - * \return A boolean indicating if detection succeeds - */ -static bool DetectDeviceFlag(Device device, runtime::DeviceAttrKind flag, ffi::Any* val) { - using runtime::DeviceAPI; - DeviceAPI* api = DeviceAPI::Get(device, true); - // Check if compiled with the corresponding device api - if (api == nullptr) { - return false; - } - // Check if the device exists - api->GetAttr(device, runtime::kExist, val); - int exists = val->cast(); - if (!exists) { - return false; - } - // Get the arch of the device - DeviceAPI::Get(device)->GetAttr(device, flag, val); - return true; -} - -void CheckOrSetAttr(ffi::Map* attrs, const ffi::String& name, - const ffi::String& value) { - auto iter = attrs->find(name); - if (iter == attrs->end()) { - attrs->Set(name, value); - } else { - auto str = (*iter).second.try_cast(); - TVM_FFI_CHECK(str && str.value() == value, ValueError) - << "Expects \"" << name << "\" to be \"" << value << "\", but gets: " << (*iter).second; - } -} - -/********** Target kind attribute updaters **********/ - -/*! - * \brief Update the attributes in the CUDA target. - * \param target The Target to update - * \return The updated attributes - */ -ffi::Map UpdateCUDAAttrs(ffi::Map target) { - // Update -arch=sm_xx - if (target.count("arch")) { - // If -arch has been specified, validate the correctness - ffi::String archStr = Downcast(target.at("arch")); - TVM_FFI_CHECK(support::StartsWith(archStr, "sm_"), ValueError) - << "CUDA target gets an invalid CUDA arch: -arch=" << archStr; - } else { - // Use the compute version of the first CUDA GPU instead - int archInt; - ffi::Any version; - if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) { - LOG(WARNING) << "Unable to detect CUDA version, default to \"-arch=sm_50\" instead"; - archInt = 50; - } else { - archInt = std::stod(version.cast()) * 10 + 0.1; - } - if (archInt >= 90) { - target.Set("arch", ffi::String("sm_") + std::to_string(archInt) + "a"); - } else { - target.Set("arch", ffi::String("sm_") + std::to_string(archInt)); - } - } - return target; -} - -/*! - * \brief Update the attributes in the LLVM NVPTX target. - * \param target The Target to update - * \return The updated attributes - */ -ffi::Map UpdateNVPTXAttrs(ffi::Map target) { - CheckOrSetAttr(&target, "mtriple", "nvptx64-nvidia-cuda"); - // Update -mcpu=sm_xx - if (target.count("mcpu")) { - // If -mcpu has been specified, validate the correctness - ffi::String mcpu = Downcast(target.at("mcpu")); - TVM_FFI_CHECK(support::StartsWith(mcpu, "sm_"), ValueError) - << "NVPTX target gets an invalid CUDA arch: -mcpu=" << mcpu; - } else { - // Use the compute version of the first CUDA GPU instead - int arch; - ffi::Any version; - if (!DetectDeviceFlag({kDLCUDA, 0}, runtime::kComputeVersion, &version)) { - LOG(WARNING) << "Unable to detect CUDA version, default to \"-mcpu=sm_50\" instead"; - arch = 50; - } else { - arch = std::stod(version.cast()) * 10 + 0.1; - } - target.Set("mcpu", ffi::String("sm_") + std::to_string(arch)); - } - return target; -} - -/*! - * \brief Update the attributes in the LLVM ROCm target. - * \param target The Target to update - * \return The updated attributes - */ -ffi::Map UpdateROCmAttrs(ffi::Map target) { - CheckOrSetAttr(&target, "mtriple", "amdgcn-amd-amdhsa-hcc"); - // Update -mcpu=gfx - std::string arch = "gfx900"; - if (target.count("mcpu")) { - ffi::String mcpu = Downcast(target.at("mcpu")); - arch = ExtractStringWithPrefix(mcpu, "gfx"); - TVM_FFI_CHECK(!arch.empty(), ValueError) - << "ROCm target gets an invalid GFX version: -mcpu=" << mcpu; - } else { - ffi::Any val; - if (const auto f_get_rocm_arch = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_get_arch")) { - arch = (*f_get_rocm_arch)().cast(); - } - target.Set("mcpu", ffi::String(arch)); - } - // Update -mattr before ROCm 3.5: - // Before ROCm 3.5 we needed code object v2, starting - // with 3.5 we need v3 (this argument disables v3) - - ffi::Any val; - int version; - if (!DetectDeviceFlag({kDLROCM, 0}, runtime::kApiVersion, &val)) { - LOG(WARNING) << "Unable to detect ROCm version, assuming >= 3.5"; - version = 305; - } else { - version = val.cast(); - } - if (version < 305) { - ffi::Array mattr; - if (target.count("mattr")) { - mattr = Downcast>(target.at("mattr")); - } - mattr.push_back("-code-object-v3"); - target.Set("mattr", mattr); - } - return target; -} - -/*! - * \brief Update WebGPU target attributes for subgroup-enabled lowering. - * Runtime routing on the WebLLM side guarantees subgroup size == 32. - * Runtime routing on the WebLLM side guarantees - * maxComputeInvocationsPerWorkgroup >= 1024. - * This is intentionally constrained for the subgroup-enabled WASM variant. - * When supports_subgroups is true, canonicalize thread_warp_size to 32 so - * TIR lowering can emit subgroup shuffle reductions. - * \param target The Target to update - * \return The updated attributes - */ -ffi::Map UpdateWebGPUAttrs(ffi::Map target) { - bool subgroups = false; - if (target.count("supports_subgroups")) { - subgroups = Downcast(target.at("supports_subgroups"))->value != 0; - } - - if (target.count("thread_warp_size")) { - int64_t thread_warp_size = Downcast(target.at("thread_warp_size"))->value; - TVM_FFI_ICHECK(subgroups || thread_warp_size <= 1) - << "WebGPU target with thread_warp_size=" << thread_warp_size - << " requires supports_subgroups=true"; - } - - if (subgroups) { - target.Set("thread_warp_size", int64_t(32)); - } - return target; -} - /*! * \brief Test Target Parser * \param target The Target to update @@ -362,126 +170,6 @@ TVM_REGISTER_TARGET_KIND("c", kDLCPU) .set_default_keys({"cpu"}) .set_target_canonicalizer(tvm::target::canonicalizer::llvm::Canonicalize); -TVM_REGISTER_TARGET_KIND("cuda", kDLCUDA) - .add_attr_option("mcpu") - .add_attr_option("arch") - .add_attr_option("max_shared_memory_per_block") - .add_attr_option("max_threads_per_block") - .add_attr_option("thread_warp_size", refl::DefaultValue(32)) - .add_attr_option("registers_per_block") - .add_attr_option("l2_cache_size_bytes") - .add_attr_option("max_num_threads", - refl::DefaultValue(1024)) // TODO(@zxybazh): deprecate it - .set_default_keys({"cuda", "gpu"}) - .set_target_canonicalizer(UpdateCUDAAttrs); - -TVM_REGISTER_TARGET_KIND("nvptx", kDLCUDA) - .add_attr_option("mcpu") - .add_attr_option("mtriple") - .add_attr_option("max_num_threads", refl::DefaultValue(1024)) - .add_attr_option("thread_warp_size", refl::DefaultValue(32)) - .set_default_keys({"cuda", "gpu"}) - .set_target_canonicalizer(UpdateNVPTXAttrs); - -TVM_REGISTER_TARGET_KIND("rocm", kDLROCM) - .add_attr_option("mcpu") - .add_attr_option("mtriple") - .add_attr_option>("mattr") - // TODO(masahi): Support querying from a target device - // On RDNA cards, thread_warp_size should be 32 - .add_attr_option("max_num_threads", refl::DefaultValue(256)) - .add_attr_option("max_threads_per_block", refl::DefaultValue(256)) - .add_attr_option("max_shared_memory_per_block", refl::DefaultValue(65536)) - .add_attr_option("thread_warp_size", refl::DefaultValue(64)) - .set_default_keys({"rocm", "gpu"}) - .set_target_canonicalizer(UpdateROCmAttrs); - -TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL) - .add_attr_option("max_threads_per_block", refl::DefaultValue(256)) - .add_attr_option("max_shared_memory_per_block", refl::DefaultValue(16384)) - .add_attr_option("max_num_threads", refl::DefaultValue(256)) - .add_attr_option("thread_warp_size", refl::DefaultValue(1)) - .add_attr_option("texture_spatial_limit", refl::DefaultValue(16384)) - .add_attr_option("texture_depth_limit", refl::DefaultValue(2048)) - // Faced that Qualcomm OpenCL runtime crashed without any error message in - // the case when the number of kernel arguments was pretty big. OpenCL doesn't - // specify any limitations on the number of kernel arguments. max_function_args - // equals to 128 looks like a reasonable number of kernel arguments. - .add_attr_option("max_function_args", refl::DefaultValue(128)) - .add_attr_option("image_base_address_alignment", refl::DefaultValue(64)) - .set_default_keys({"opencl", "gpu"}); - -// The metal has some limitations on the number of input parameters. This is why attribute -// `max_function_args` was introduced. It specifies the maximum number of kernel argumetns. More -// information about this limitation can be found here: -// https://developer.apple.com/documentation/metal/buffers/about_argument_buffers?language=objc -// See also https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf -TVM_REGISTER_TARGET_KIND("metal", kDLMetal) - .add_attr_option("max_num_threads", refl::DefaultValue(256)) - .add_attr_option("max_threads_per_block", refl::DefaultValue(256)) - .add_attr_option("max_shared_memory_per_block", refl::DefaultValue(32768)) - .add_attr_option("thread_warp_size", refl::DefaultValue(16)) - .add_attr_option("max_function_args", refl::DefaultValue(31)) - .set_default_keys({"metal", "gpu"}); - -TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) - .add_attr_option>("mattr") - // Feature support - .add_attr_option("supports_float16") - .add_attr_option("supports_float32", refl::DefaultValue(true)) - .add_attr_option("supports_float64") - .add_attr_option("supports_int8") - .add_attr_option("supports_int16") - .add_attr_option("supports_int32", refl::DefaultValue(true)) - .add_attr_option("supports_int64") - .add_attr_option("supports_8bit_buffer") - .add_attr_option("supports_16bit_buffer") - .add_attr_option("supports_storage_buffer_storage_class") - .add_attr_option("supports_push_descriptor") - .add_attr_option("supports_dedicated_allocation") - .add_attr_option("supports_integer_dot_product") - .add_attr_option("supports_cooperative_matrix") - .add_attr_option("supported_subgroup_operations") - // Physical device limits - .add_attr_option("max_num_threads", refl::DefaultValue(256)) - .add_attr_option("max_threads_per_block", refl::DefaultValue(256)) - .add_attr_option("thread_warp_size", refl::DefaultValue(1)) - .add_attr_option("max_block_size_x") - .add_attr_option("max_block_size_y") - .add_attr_option("max_block_size_z") - .add_attr_option("max_push_constants_size") - .add_attr_option("max_uniform_buffer_range") - .add_attr_option("max_storage_buffer_range") - .add_attr_option("max_per_stage_descriptor_storage_buffer") - .add_attr_option("max_shared_memory_per_block") - // Other device properties - .add_attr_option("device_type") - .add_attr_option("device_name") - .add_attr_option("driver_name") - .add_attr_option("driver_version") - .add_attr_option("vulkan_api_version") - .add_attr_option("max_spirv_version") - // Tags - .set_default_keys({"vulkan", "gpu"}); - -TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU) - .add_attr_option("max_num_threads", refl::DefaultValue(256)) - .add_attr_option("supports_subgroups", refl::DefaultValue(false)) - // thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction returns false, so no - // subgroup ops are emitted. - .add_attr_option("thread_warp_size", refl::DefaultValue(1)) - .set_target_canonicalizer(UpdateWebGPUAttrs) - .set_default_keys({"webgpu", "gpu"}); - -TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) - .add_attr_option>("mattr") - .add_attr_option("mcpu") - .add_attr_option("mtriple") - .add_attr_option>("llvm-options") - .add_attr_option("num-cores") - .add_attr_option("vtcm-capacity") - .set_default_keys({"hexagon", "cpu"}); - TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev); TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break @@ -512,12 +200,6 @@ TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break TVM_REGISTER_TARGET_KIND("test", kDLCPU) // line break .set_target_canonicalizer(TestTargetParser); -TVM_REGISTER_TARGET_KIND("trn", DLDeviceType::kDLTrn) // line break - .add_attr_option("partition_size", 128) - .add_attr_option("max_sbuf_size_per_partition", 196608) - .add_attr_option("max_psum_size_per_partition", 16384) - .add_attr_option("num-cores"); - /********** Registry **********/ TVM_FFI_STATIC_INIT_BLOCK() { diff --git a/src/tirx/analysis/filter_canonical.cc b/src/tirx/analysis/filter_canonical.cc index fbf098cced98..dfefdd51c043 100644 --- a/src/tirx/analysis/filter_canonical.cc +++ b/src/tirx/analysis/filter_canonical.cc @@ -27,9 +27,9 @@ #include #include +#include #include #include -#include namespace tvm { namespace tirx { @@ -45,7 +45,8 @@ bool IsBitwiseAndCall(const CallNode* call) { } bool IsPtxElectSyncCall(const CallNode* call) { - if (call->op.same_as(tirx::builtin::ptx_elect_sync())) return true; + static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx_elect_sync"); + if (call->op.same_as(ptx_elect_sync_op)) return true; if (auto op = call->op.as()) { return op.value()->name == "tirx.ptx.elect_sync"; } diff --git a/src/tirx/ir/data_type_rewriter.cc b/src/tirx/ir/data_type_rewriter.cc index 9b030c560e09..f92969781bdd 100644 --- a/src/tirx/ir/data_type_rewriter.cc +++ b/src/tirx/ir/data_type_rewriter.cc @@ -25,6 +25,7 @@ #include "data_type_rewriter.h" #include +#include #include #include #include @@ -232,7 +233,6 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { Call before = ffi::GetRef(op); PrimExpr e = StmtExprMutator::VisitExpr_(op); op = e.as(); - static const Op& builtin_pow_ = Op::Get("tirx.pow"); TVM_FFI_ICHECK(op != nullptr) << "Expected type to be CallNode" << ", but get " << e->GetTypeKey(); if (op->op.same_as(builtin::shift_right())) { @@ -245,11 +245,14 @@ PrimExpr DataTypeLegalizer::VisitExpr_(const CallNode* op) { return op->args[0] | op->args[1]; } else if (op->op.same_as(builtin::bitwise_xor())) { return op->args[0] ^ op->args[1]; - } else if (op->op.same_as(builtin_pow_)) { + } + static const Op& pow_op = Op::Get("tirx.pow"); + static const Op& clz_op = Op::Get("tirx.clz"); + if (op->op.same_as(pow_op)) { return pow(op->args[0], op->args[1]); } else if (op->op.same_as(builtin::if_then_else())) { return Call(op->dtype, op->op, {op->args[0], op->args[1], op->args[2]}, op->attrs, op->span); - } else if (op->op.same_as(Op::Get("tirx.clz"))) { + } else if (op->op.same_as(clz_op)) { DataType before_dtype = before->args[0]->dtype; DataType after_dtype = op->args[0]->dtype; TVM_FFI_ICHECK((before_dtype.is_int() || before_dtype.is_uint()) && diff --git a/src/tirx/ir/exec_scope.cc b/src/tirx/ir/exec_scope.cc index c885c0251134..1f672f8eb9d2 100644 --- a/src/tirx/ir/exec_scope.cc +++ b/src/tirx/ir/exec_scope.cc @@ -17,6 +17,7 @@ * under the License. */ #include +#include #include #include #include @@ -384,10 +385,11 @@ ffi::Array ResolveCuda(ScopeBinding binding, case ScopeBinding::kKernelCluster: { TVM_FFI_ICHECK_LE(out_dim, 3) << "ValueError: kernel->cluster can only have 3 dimensions for now"; + static const Op& ptx_fetch_register_op = Op::Get("tirx.ptx.fetch_register"); ffi::Array ret; for (int i = 0; i < out_dim; ++i) { ret.push_back(tirx::Call( - DataType::Int(32), builtin::ptx_fetch_register(), + DataType::Int(32), ptx_fetch_register_op, {IntImm(DataType::Int(32), 32), StringImm("clusterid." + std::string(1, 'x' + i))})); } return ret; diff --git a/src/tirx/ir/stmt.cc b/src/tirx/ir/stmt.cc index fa40ffc3894c..adb0bcda3eeb 100644 --- a/src/tirx/ir/stmt.cc +++ b/src/tirx/ir/stmt.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -670,8 +671,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { } PrimExpr TypeAnnotation(DataType dtype, Span span) { - static auto op = Op::Get("tirx.type_annotation"); - return tirx::Call(dtype, op, {}, {}, span); + static const Op& type_annotation_op = Op::Get("tirx.type_annotation"); + return tirx::Call(dtype, type_annotation_op, {}, {}, span); } TVM_TIRX_REGISTER_OP("type_annotation") diff --git a/src/tirx/op/builtin.cc b/src/tirx/op/builtin.cc index c2ad5559d608..4a16e11139f1 100644 --- a/src/tirx/op/builtin.cc +++ b/src/tirx/op/builtin.cc @@ -290,19 +290,6 @@ TIR_DEFINE_BUILTIN_FUNC(tvm_global_barrier_kinit) TIR_DEFINE_BUILTIN_FUNC(tvm_thread_allreduce) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(make_filled_simdgroup_matrix) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(simdgroup_load) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(simdgroup_store) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(simdgroup_multiply_accumulate) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - TIR_DEFINE_BUILTIN_FUNC(cooperative_tensor_fill) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); @@ -409,146 +396,6 @@ TIR_DEFINE_BUILTIN_FUNC(buffer_offset) TIR_DEFINE_BUILTIN_FUNC(print_buffer) .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(timer_init_cuda) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(timer_start_cuda) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(timer_end_cuda) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(timer_finalize_cuda) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_atomic_add) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_thread_fence) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_warpgroup_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_warp_reduce) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_cta_reduce) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_copy_bytes) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_warp_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_cta_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_grid_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_thread_rank) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)); - -// Cluster-wide sync (CUDA thread block clusters) -TIR_DEFINE_BUILTIN_FUNC(cuda_cluster_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_half2float) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_bfloat162float) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_float22half2) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_trap_when_assert_failed) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_runtime_instr_desc) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_half8tofloat8) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_float8tohalf8) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_syncthreads_and) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_syncthreads_or) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_nano_sleep) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_atomic_cas) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_printf) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(cuda_ldg) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)) - .set_num_inputs(2); - -TIR_DEFINE_BUILTIN_FUNC(cuda_get_tmem_addr) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(ptx_exp2).set_attr( - "TCallEffectKind", static_cast(CallEffectKind::kPure)); - -TIR_DEFINE_BUILTIN_FUNC(ptx_rcp).set_attr( - "TCallEffectKind", static_cast(CallEffectKind::kPure)); - -TIR_DEFINE_BUILTIN_FUNC(ptx_any_sync) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)); - -TIR_DEFINE_BUILTIN_FUNC(ptx_reduce3_max_f32) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)); - -TIR_DEFINE_BUILTIN_FUNC(ptx_reduce3_min_f32) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)); - -// PTX scalar / packed floating-point arithmetic, DPS form (writes to *d_addr). -// add/sub/mul: 2 sources, 1 destination. -// fma: 3 sources, 1 destination. -// Modifiers (rounding / ftz / sat) are codegen attrs. -// kOpaque because all four kinds write through the destination pointer. -TIR_DEFINE_BUILTIN_FUNC(ptx_add_f32) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_add_f32x2) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_add_f64) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(ptx_sub_f32) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_sub_f32x2) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_sub_f64) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(ptx_mul_f32) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_mul_f32x2) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_mul_f64) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -TIR_DEFINE_BUILTIN_FUNC(ptx_fma_f32) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_fma_f32x2) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); -TIR_DEFINE_BUILTIN_FUNC(ptx_fma_f64) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kOpaque)); - -// max stays value-returning + kPure (no .sat, not in the add/sub/mul/fma family). -TIR_DEFINE_BUILTIN_FUNC(ptx_max_f32) - .set_attr("TCallEffectKind", static_cast(CallEffectKind::kPure)); } // namespace builtin } // namespace tirx } // namespace tvm diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc index 312c8b79dd95..64f7f575d6a1 100644 --- a/src/tirx/op/op.cc +++ b/src/tirx/op/op.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include #include @@ -85,13 +86,13 @@ Type GetType(const PrimExpr& expr) { } } + static const Op& type_annotation_op = Op::Get("tirx.type_annotation"); if (auto* access = expr.as()) { if (access->op.same_as(builtin::tvm_access_ptr())) { TVM_FFI_ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments"; auto type_annotation = Downcast(access->args[0]); - static auto builtin_op = Op::Get("tirx.type_annotation"); - TVM_FFI_ICHECK(type_annotation->op.same_as(builtin_op)) + TVM_FFI_ICHECK(type_annotation->op.same_as(type_annotation_op)) << "Expected the first argument of builtin tvm_access_ptr() " << "to be a type annotation, but found " << type_annotation->op; return PointerType(PrimType(type_annotation->dtype)); @@ -99,8 +100,7 @@ Type GetType(const PrimExpr& expr) { if (access->op.same_as(builtin::ptr_byte_offset())) { TVM_FFI_ICHECK_EQ(access->args.size(), 3U); auto type_annotation = Downcast(access->args[2]); - static auto builtin_op = Op::Get("tirx.type_annotation"); - TVM_FFI_ICHECK(type_annotation->op.same_as(builtin_op)) + TVM_FFI_ICHECK(type_annotation->op.same_as(type_annotation_op)) << "Expected the third argument of builtin ptr_byte_offset() " << "to be a type annotation, but found " << type_annotation->op; return PointerType(PrimType(type_annotation->dtype)); @@ -907,8 +907,8 @@ PrimExpr pow(PrimExpr x, PrimExpr y, Span span) { } } - static auto op = Op::Get("tirx.pow"); - return tirx::Call(x.dtype(), op, {x, y}, {}, span); + static const Op& pow_op = Op::Get("tirx.pow"); + return tirx::Call(x.dtype(), pow_op, {x, y}, {}, span); } TVM_TIR_REGISTER_PURE_BINARY_OP("pow").set_attr("TVectorizable", true); @@ -928,8 +928,8 @@ PrimExpr abs(PrimExpr x, Span span) { if (fx) { return FloatImm(x.dtype(), std::fabs(fx->value), fx->span); } - static auto op = Op::Get("tirx.fabs"); - return tirx::Call(x.dtype(), op, {x}, {}, span); + static const Op& fabs_op = Op::Get("tirx.fabs"); + return tirx::Call(x.dtype(), fabs_op, {x}, {}, span); } else if (x.dtype().is_uint()) { return x; } else { @@ -952,12 +952,13 @@ PrimExpr isnan(PrimExpr x, Span span) { if (fx) { return make_const(t, std::isnan(fx->value), fx->span); } - static auto op = Op::Get("tirx.isnan"); if (x.dtype().bits() == 16) { - return tirx::Call(t, op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, {}, + static const Op& isnan_op = Op::Get("tirx.isnan"); + return tirx::Call(t, isnan_op, {cast(DataType::Float(32, t.lanes()), std::move(x), span)}, {}, span); } else { - return tirx::Call(t, op, {x}, {}, span); + static const Op& isnan_op = Op::Get("tirx.isnan"); + return tirx::Call(t, isnan_op, {x}, {}, span); } } else { TVM_FFI_THROW(InternalError) << "Data type " << x.dtype() @@ -1044,8 +1045,8 @@ PrimExpr prod(PrimExpr source, ffi::Array rdom, ffi::Array in PrimExpr fmod(PrimExpr x, PrimExpr y, Span span) { BinaryOpMatchTypes(x, y, span); TVM_FFI_ICHECK(x.dtype().is_float()) << "fmod only applies to float"; - static auto op = Op::Get("tirx.fmod"); - return tirx::Call(x.dtype(), op, {x, y}, {}, span); + static const Op& fmod_op = Op::Get("tirx.fmod"); + return tirx::Call(x.dtype(), fmod_op, {x, y}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("fmod"); @@ -1058,8 +1059,8 @@ PrimExpr floor(PrimExpr x, Span span) { using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::floor(fx->value), fx->span); - static auto op = Op::Get("tirx.floor"); - return tirx::Call(x.dtype(), op, {x}, {}, span); + static const Op& floor_op = Op::Get("tirx.floor"); + return tirx::Call(x.dtype(), floor_op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr("TVectorizable", true); @@ -1072,8 +1073,8 @@ PrimExpr ceil(PrimExpr x, Span span) { using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::ceil(fx->value), fx->span); - static auto op = Op::Get("tirx.ceil"); - return tirx::Call(x.dtype(), op, {x}, {}, span); + static const Op& ceil_op = Op::Get("tirx.ceil"); + return tirx::Call(x.dtype(), ceil_op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr("TVectorizable", true); @@ -1086,8 +1087,8 @@ PrimExpr round(PrimExpr x, Span span) { using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span); - static auto op = Op::Get("tirx.round"); - return tirx::Call(x.dtype(), op, {x}, {}, span); + static const Op& round_op = Op::Get("tirx.round"); + return tirx::Call(x.dtype(), round_op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr("TVectorizable", true); @@ -1100,8 +1101,8 @@ PrimExpr nearbyint(PrimExpr x, Span span) { using tirx::FloatImmNode; const FloatImmNode* fx = x.as(); if (fx) return FloatImm(x.dtype(), std::nearbyint(fx->value), fx->span); - static auto op = Op::Get("tirx.nearbyint"); - return tirx::Call(x.dtype(), op, {x}, {}, span); + static const Op& nearbyint_op = Op::Get("tirx.nearbyint"); + return tirx::Call(x.dtype(), nearbyint_op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint"); @@ -1117,8 +1118,8 @@ PrimExpr trunc(PrimExpr x, Span span) { return FloatImm(x.dtype(), (fx->value < 0 ? std::ceil(fx->value) : std::floor(fx->value)), fx->span); } - static auto op = Op::Get("tirx.trunc"); - return tirx::Call(x.dtype(), op, {x}, {}, span); + static const Op& trunc_op = Op::Get("tirx.trunc"); + return tirx::Call(x.dtype(), trunc_op, {x}, {}, span); } TVM_TIR_REGISTER_PURE_UNARY_OP("trunc").set_attr("TVectorizable", true); diff --git a/src/tirx/script/builder/frame.cc b/src/tirx/script/builder/frame.cc index d7dc9a4f91a1..f19da7261039 100644 --- a/src/tirx/script/builder/frame.cc +++ b/src/tirx/script/builder/frame.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. */ +#include #include #include #include @@ -311,7 +312,7 @@ void ComposeOpFrameNode::ExitWithScope() { << stmt; ops.push_back(ffi::GetRef(op_call)); } - auto compose_op_op = tvm::Op::Get("tirx.tile.compose_op"); + static const Op& compose_op_op = Op::Get("tirx.tile.compose_op"); AddToParent(tvm::tirx::TilePrimitiveCall(compose_op_op, ops, workspace, config, dispatch)); } diff --git a/src/tirx/transform/lower_intrin.cc b/src/tirx/transform/lower_intrin.cc index 0b859ef9956f..99cc26464503 100644 --- a/src/tirx/transform/lower_intrin.cc +++ b/src/tirx/transform/lower_intrin.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -73,7 +74,8 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { if (Op::HasAttrMap(pattern)) { attr_maps_.push_back(Op::GetAttrMap(pattern)); if (fma_ == nullptr) { - fma_ = (*attr_maps_.rbegin()).get(Op::Get("tirx.fma"), nullptr); + static const Op& fma_op = Op::Get("tirx.fma"); + fma_ = (*attr_maps_.rbegin()).get(fma_op, nullptr); } } } diff --git a/src/tirx/transform/lower_tvm_builtin.cc b/src/tirx/transform/lower_tvm_builtin.cc index 3b1336515721..c55432b3f187 100644 --- a/src/tirx/transform/lower_tvm_builtin.cc +++ b/src/tirx/transform/lower_tvm_builtin.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -271,7 +272,9 @@ class BuiltinLower : public StmtExprMutator { Stmt alloc_nullptr_check = IfThenElse( Call(DataType::Bool(), builtin::isnullptr(), {op->buffer->data}), throw_last_error); - PrimExpr free_op = Call(DataType::Int(32), Op::Get("tirx.TVMBackendFreeWorkspace"), + static const Op& free_workspace_op = Op::Get("tirx.TVMBackendFreeWorkspace"); + static const Op& alloc_workspace_op = Op::Get("tirx.TVMBackendAllocWorkspace"); + PrimExpr free_op = Call(DataType::Int(32), free_workspace_op, {cast(DataType::Int(32), device_type_.value()), cast(DataType::Int(32), device_id_.value()), op->buffer->data}); Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); @@ -279,12 +282,12 @@ class BuiltinLower : public StmtExprMutator { // Push free to enclosing scope's pending_frees (LIFO ordering preserved). scope_.Current().pending_frees.push_back(free_stmt); - Stmt alloc_bind = Bind(op->buffer->data, - Call(op->buffer->data.dtype(), Op::Get("tirx.TVMBackendAllocWorkspace"), - {cast(DataType::Int(32), device_type_.value()), - cast(DataType::Int(32), device_id_.value()), total_bytes, - IntImm(DataType::Int(32), op->buffer->dtype.code()), - IntImm(DataType::Int(32), op->buffer->dtype.bits())})); + Stmt alloc_bind = + Bind(op->buffer->data, Call(op->buffer->data.dtype(), alloc_workspace_op, + {cast(DataType::Int(32), device_type_.value()), + cast(DataType::Int(32), device_id_.value()), total_bytes, + IntImm(DataType::Int(32), op->buffer->dtype.code()), + IntImm(DataType::Int(32), op->buffer->dtype.bits())})); return SeqStmt({alloc_bind, alloc_nullptr_check}); } @@ -529,8 +532,9 @@ class BuiltinLower : public StmtExprMutator { auto* call_pattern = arg.as(); if (call_pattern && call_pattern->op.same_as(builtin::anylist_getitem())) { // call runtime function to set anylist + static const Op& anylist_set_packed_arg_op = Op::Get("tirx.TVMBackendAnyListSetPackedArg"); prep_seq->emplace_back(Evaluate(Call( - DataType::Int(32), Op::Get("tirx.TVMBackendAnyListSetPackedArg"), + DataType::Int(32), anylist_set_packed_arg_op, {call_pattern->args[0], call_pattern->args[1], args_stack, ConstInt32(stack_offset)}))); } else { DataType api_dtype = APIType(arg.dtype()); @@ -587,7 +591,9 @@ class BuiltinLower : public StmtExprMutator { PrimExpr ret_offset = call->args[3]; auto& prep_seq = prep_seq_stack_.back(); prep_seq.emplace_back(Evaluate(call)); - return Call(DataType::Int(32), Op::Get("tirx.TVMBackendAnyListMoveFromPackedReturn"), + static const Op& anylist_move_from_packed_return_op = + Op::Get("tirx.TVMBackendAnyListMoveFromPackedReturn"); + return Call(DataType::Int(32), anylist_move_from_packed_return_op, {list_handle, list_index, args_stack, ret_offset}); } /*! diff --git a/src/tirx/transform/lower_warp_memory.cc b/src/tirx/transform/lower_warp_memory.cc index 1cf4a3e6300d..9d70c5d2efda 100644 --- a/src/tirx/transform/lower_warp_memory.cc +++ b/src/tirx/transform/lower_warp_memory.cc @@ -30,6 +30,7 @@ #include #include #include +#include #include #include #include @@ -49,18 +50,6 @@ namespace tvm { namespace tirx { -namespace { - -bool IsOp(const CallNode* call, const Op& compat_op, const char* canonical_name) { - if (call->op.same_as(compat_op)) { - return true; - } - const auto* op_node = call->op.as(); - return op_node != nullptr && op_node->name == canonical_name; -} - -} // namespace - // Rewrite Rule // // There is no special warp memory in most GPUs. @@ -129,20 +118,22 @@ class WarpStoreCoeffFinder : private StmtExprVisitor { private: /// Visitor implementation void VisitExpr_(const CallNode* op) final { - if (IsOp(op, builtin::ptx_ldmatrix(), "tirx.ptx.ldmatrix") && - op->args[3].as() == buffer_) { + static const Op& ptx_ldmatrix_op = Op::Get("tirx.ptx.ldmatrix"); + static const Op& mma_fill_op = Op::Get("tirx.mma_fill"); + static const Op& ptx_ldmatrix_legacy_op = Op::Get("tirx.ptx.ldmatrix_legacy"); + static const Op& mma_fill_legacy_op = Op::Get("tirx.mma_fill_legacy"); + if (op->op.same_as(ptx_ldmatrix_op) && op->args[3].as() == buffer_) { UpdatePattern(op->args[4]); - } else if (op->op.same_as(builtin::mma_fill()) && op->args[1].as() == buffer_) { + } else if (op->op.same_as(mma_fill_op) && op->args[1].as() == buffer_) { auto* local_size = op->args[0].as(); TVM_FFI_ICHECK(local_size) << "Integer expected for the first argument of mma_fill"; warp_coeff_ = local_size->value; - } else if (IsOp(op, builtin::ptx_ldmatrix_legacy(), "tirx.ptx.ldmatrix_legacy") && - op->args[3].as() == buffer_) { + } else if (op->op.same_as(ptx_ldmatrix_legacy_op) && op->args[3].as() == buffer_) { // ldmatrix writes the warp buffer; its local_offset carries // ``... + lift(local_size) * tx`` from which the warp coefficient // is derived. UpdatePattern(op->args[4]); - } else if (op->op.same_as(builtin::mma_fill_legacy()) && op->args[1].as() == buffer_) { + } else if (op->op.same_as(mma_fill_legacy_op) && op->args[1].as() == buffer_) { auto* local_size = op->args[0].as(); TVM_FFI_ICHECK(local_size) << "Integer expected for the first argument of mma_fill_legacy"; warp_coeff_ = local_size->value; @@ -308,37 +299,45 @@ class WarpAccessRewriter : protected StmtExprMutator { } PrimExpr VisitExpr_(const CallNode* op) override { - if (IsOp(op, builtin::ptx_mma(), "tirx.ptx.mma")) { + static const Op& ptx_mma_op = Op::Get("tirx.ptx.mma"); + static const Op& ptx_ldmatrix_op = Op::Get("tirx.ptx.ldmatrix"); + static const Op& mma_store_op = Op::Get("tirx.mma_store"); + static const Op& mma_fill_op = Op::Get("tirx.mma_fill"); + static const Op& ptx_mma_legacy_op = Op::Get("tirx.ptx.mma_legacy"); + static const Op& ptx_ldmatrix_legacy_op = Op::Get("tirx.ptx.ldmatrix_legacy"); + static const Op& mma_store_legacy_op = Op::Get("tirx.mma_store_legacy"); + static const Op& mma_fill_legacy_op = Op::Get("tirx.mma_fill_legacy"); + if (op->op.same_as(ptx_mma_op)) { return RewriteIndicesAt(op, {6, 8, 10}); } - if (IsOp(op, builtin::ptx_ldmatrix(), "tirx.ptx.ldmatrix")) { + if (op->op.same_as(ptx_ldmatrix_op)) { return RewriteIndicesAt(op, {3}); } - if (op->op.same_as(builtin::mma_store())) { + if (op->op.same_as(mma_store_op)) { return RewriteIndicesAt(op, {3}); } - if (op->op.same_as(builtin::mma_fill())) { + if (op->op.same_as(mma_fill_op)) { return RewriteIndicesAt(op, {1}); } // Legacy variants: (ptr_var, offset) pairs in apache positions. - if (IsOp(op, builtin::ptx_mma_legacy(), "tirx.ptx.mma_legacy")) { + if (op->op.same_as(ptx_mma_legacy_op)) { return RewriteIndicesAt(op, {6, 8, 10}); } - if (IsOp(op, builtin::ptx_ldmatrix_legacy(), "tirx.ptx.ldmatrix_legacy")) { + if (op->op.same_as(ptx_ldmatrix_legacy_op)) { // args: trans, num, type, local_ptr, local_offset, smem_ptr_call, smem_offset // Only local_ptr is a raw warp buffer Var; smem_ptr is an // access_ptr Call wrapping a shared-scope var. return RewriteIndicesAt(op, {3}); } - if (op->op.same_as(builtin::mma_store_legacy())) { + if (op->op.same_as(mma_store_legacy_op)) { // args: m, n, dst_ptr, src_ptr, src_offset, dst_stride return RewriteIndicesAt(op, {3}); } - if (op->op.same_as(builtin::mma_fill_legacy())) { + if (op->op.same_as(mma_fill_legacy_op)) { // args: local_size, local_ptr, offset return RewriteIndicesAt(op, {1}); } diff --git a/src/tirx/transform/remove_no_op.cc b/src/tirx/transform/remove_no_op.cc index 8ae06ea9a37b..6394eb21980a 100644 --- a/src/tirx/transform/remove_no_op.cc +++ b/src/tirx/transform/remove_no_op.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -224,10 +225,12 @@ class NoOpRemover : public arith::IRMutatorWithAnalyzer { bool HasSideEffect(const PrimExpr& value) { if (ignore_profiler_call_) { if (const CallNode* call = value.as()) { - if (call->op.same_as(builtin::timer_init_cuda()) || - call->op.same_as(builtin::timer_start_cuda()) || - call->op.same_as(builtin::timer_end_cuda()) || - call->op.same_as(builtin::timer_finalize_cuda())) { + static const Op& timer_init_cuda_op = Op::Get("tirx.timer_init_cuda"); + static const Op& timer_start_cuda_op = Op::Get("tirx.timer_start_cuda"); + static const Op& timer_end_cuda_op = Op::Get("tirx.timer_end_cuda"); + static const Op& timer_finalize_cuda_op = Op::Get("tirx.timer_finalize_cuda"); + if (call->op.same_as(timer_init_cuda_op) || call->op.same_as(timer_start_cuda_op) || + call->op.same_as(timer_end_cuda_op) || call->op.same_as(timer_finalize_cuda_op)) { return false; } } diff --git a/src/tirx/transform/tile_primitive_dispatch.cc b/src/tirx/transform/tile_primitive_dispatch.cc index 9639fce1db2d..062b6f030293 100644 --- a/src/tirx/transform/tile_primitive_dispatch.cc +++ b/src/tirx/transform/tile_primitive_dispatch.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include #include @@ -34,7 +35,6 @@ #include #include #include -#include #include #include @@ -125,7 +125,8 @@ class ElectSyncFinder : public StmtExprVisitor { void VisitExpr_(const CallNode* op) final { auto is_canonical_elect_sync = [&]() { - if (op->op.same_as(tirx::builtin::ptx_elect_sync())) return true; + static const Op& ptx_elect_sync_op = Op::Get("tirx.ptx_elect_sync"); + if (op->op.same_as(ptx_elect_sync_op)) return true; if (auto call_op = op->op.as()) { return call_op.value()->name == "tirx.ptx.elect_sync"; } diff --git a/tests/python/s_tir/dlight/test_gpu_matmul_tensorize.py b/tests/python/s_tir/dlight/test_gpu_matmul_tensorize.py index 7d03f49a75a8..d12d713d3bce 100644 --- a/tests/python/s_tir/dlight/test_gpu_matmul_tensorize.py +++ b/tests/python/s_tir/dlight/test_gpu_matmul_tensorize.py @@ -755,7 +755,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha T.reads() T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) - T.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8) + T.metal.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8) for ax3_0 in range(128): for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1): for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): @@ -791,7 +791,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) A_1 = T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) - T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) + T.metal.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) for ax0_0, ax1_0_1 in T.grid(2, 1): with T.sblock("B_reindex_shared_metal.simdgroup_o"): v0_o = T.axis.spatial(1, 0) @@ -801,7 +801,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8]) A_1 = T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) - T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) + T.metal.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) for ax1_2, ax2_2 in T.grid(2, 2): with T.sblock("C_update_o"): v0_o = T.axis.spatial(1, ax0) @@ -813,7 +813,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha A_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) B_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), scope="metal.simdgroup", offset_factor=1) C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) - T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B_1.data, B_1.elem_offset // B_1.strides[0] // 8 * (B_1.strides[0] // 8) + B_1.elem_offset % B_1.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) + T.metal.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B_1.data, B_1.elem_offset // B_1.strides[0] // 8 * (B_1.strides[0] // 8) + B_1.elem_offset % B_1.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2): with T.sblock("C_reindex_pad_metal.simdgroup_o"): v0_o = T.axis.spatial(1, ax0_1) @@ -823,7 +823,7 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha T.writes(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) C_1 = T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1) - T.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False)) + T.metal.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False)) for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2): for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): @@ -900,7 +900,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f T.reads() T.writes(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) - T.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8) + T.metal.make_filled_simdgroup_matrix(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.float32(0), 8, 8) for ax3_0 in range(128): for ax0_1, ax1_ax2_fused_0 in T.grid(1, 1): for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): @@ -936,7 +936,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f T.writes(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) A_1 = T.match_buffer(A_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) C_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) - T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) + T.metal.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(False)) for ax0_0, ax1_0_1 in T.grid(2, 1): with T.sblock("B_reindex_shared_metal.simdgroup_o"): v0_o = T.axis.spatial(1, 0) @@ -946,7 +946,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f T.writes(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8]) A_1 = T.match_buffer(B_reindex_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="shared", offset_factor=1) C_1 = T.match_buffer(B_reindex_shared_metal_simdgroup[v0_o, v2_o * 8:v2_o * 8 + 8, v1_o * 8:v1_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) - T.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) + T.metal.simdgroup_load(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), A_1.data, A_1.elem_offset, A_1.strides[0] * 8, 1), A_1.strides[0], 8, 8, T.bool(True)) for ax1_2, ax2_2 in T.grid(2, 2): with T.sblock("NT_matmul_update_o"): v0_o = T.axis.spatial(1, ax0) @@ -958,7 +958,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f A_1 = T.match_buffer(A_reindex_pad_shared_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v3_o * 8:v3_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) B = T.match_buffer(B_reindex_shared_metal_simdgroup[0, v3_o * 8:v3_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("B_s0", "B_s1"), scope="metal.simdgroup", offset_factor=1) C_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[0, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="metal.simdgroup", offset_factor=1) - T.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B.data, B.elem_offset // B.strides[0] // 8 * (B.strides[0] // 8) + B.elem_offset % B.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) + T.metal.simdgroup_multiply_accumulate(C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8, A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, B.data, B.elem_offset // B.strides[0] // 8 * (B.strides[0] // 8) + B.elem_offset % B.strides[0] // 8, C_1.data, C_1.elem_offset // C_1.strides[0] // 8 * (C_1.strides[0] // 8) + C_1.elem_offset % C_1.strides[0] // 8) for ax0_1, ax1_0_1, ax2_0_1 in T.grid(1, 2, 2): with T.sblock("C_reindex_pad_metal.simdgroup_o"): v0_o = T.axis.spatial(1, ax0_1) @@ -968,7 +968,7 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f T.writes(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8]) A_1 = T.match_buffer(C_reindex_pad_metal_simdgroup[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("A_s0", "A_s1"), scope="metal.simdgroup", offset_factor=1) C_1 = T.match_buffer(C_reindex_pad_shared[v0_o, v1_o * 8:v1_o * 8 + 8, v2_o * 8:v2_o * 8 + 8], (8, 8), "float16", strides=("C_s0", "C_s1"), scope="shared", offset_factor=1) - T.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False)) + T.metal.simdgroup_store(A_1.data, A_1.elem_offset // A_1.strides[0] // 8 * (A_1.strides[0] // 8) + A_1.elem_offset % A_1.strides[0] // 8, T.tvm_access_ptr(T.type_annotation("float16"), C_1.data, C_1.elem_offset, C_1.strides[0] * 8, 2), C_1.strides[0], 8, 8, T.bool(False)) for ax0_1, ax1_ax2_fused_0 in T.grid(1, 2): for ax1_ax2_fused_1 in T.thread_binding(4, thread="threadIdx.z"): for ax1_ax2_fused_2 in T.thread_binding(1, thread="threadIdx.y"): diff --git a/tests/python/tirx-base/test_tir_op_types.py b/tests/python/tirx-base/test_tir_op_types.py index 2ffce7dce8c6..5fb5de5f4a95 100644 --- a/tests/python/tirx-base/test_tir_op_types.py +++ b/tests/python/tirx-base/test_tir_op_types.py @@ -18,6 +18,7 @@ import tvm import tvm.testing from tvm import tirx +from tvm.backend.cuda import op as _cuda_op def test_tir_op_tvm_tuple(): @@ -149,7 +150,7 @@ def test_tir_op_ptx_mma(): buffer_a = tirx.decl_buffer([32], "int4", scope="local") buffer_b = tirx.decl_buffer([16], "uint4", scope="local") buffer_c = tirx.decl_buffer([4], "int32", scope="local") - expr = tirx.ptx_mma_legacy( + expr = _cuda_op.ptx_mma_legacy( "m8n8k32", "row", "col", @@ -172,7 +173,7 @@ def test_tir_op_ptx_mma_sp(): buffer_b = tirx.decl_buffer([16], "uint4", scope="local") buffer_c = tirx.decl_buffer([4], "int32", scope="local") buffer_d = tirx.decl_buffer([1], "uint32", scope="local") - expr = tirx.ptx_mma_sp_legacy( + expr = _cuda_op.ptx_mma_sp_legacy( "m8n8k32", "row", "col", @@ -200,7 +201,7 @@ def test_tir_op_mma_store(): buffer = tirx.decl_buffer( [16, 16], dtype="int32", scope="global", offset_factor=1, strides=[x, y] ) - expr = tirx.mma_store( + expr = _cuda_op.mma_store( "int32", 16, 16, @@ -214,7 +215,7 @@ def test_tir_op_mma_store(): def test_tir_op_mma_fill(): buffer_w = tirx.decl_buffer([16, 8], dtype="int32", scope="warp", offset_factor=1) - expr = tirx.mma_fill("int32", 8, buffer_w.data, buffer_w.elem_offset) + expr = _cuda_op.mma_fill("int32", 8, buffer_w.data, buffer_w.elem_offset) assert expr.op.name == "tirx.mma_fill" @@ -222,7 +223,7 @@ def test_op_ptx_ldmatrix(): buffer_shared = tirx.decl_buffer([16, 16], "float16", scope="shared") buffer_local = tirx.decl_buffer([8], "float16", scope="local") # New API: 4 scatter-form dst handles for .x4.b16 (one per output register). - expr = tirx.ptx_ldmatrix( + expr = _cuda_op.ptx_ldmatrix( False, 4, ".b16", @@ -238,14 +239,14 @@ def test_op_ptx_ldmatrix(): def test_op_ptx_cp_async(): buffer_shared = tirx.decl_buffer([16, 16], "float16", scope="shared") buffer_local = tirx.decl_buffer([8], "float16", scope="local") - expr = tirx.ptx_cp_async_legacy(buffer_shared.data, 0, buffer_local.data, 0, 16) + expr = _cuda_op.ptx_cp_async_legacy(buffer_shared.data, 0, buffer_local.data, 0, 16) assert expr.op.name == "tirx.ptx.cp_async" def test_op_ptx_cp_async_bulk(): buffer_shared = tirx.decl_buffer([16, 16], "float16", scope="shared") buffer_local = tirx.decl_buffer([8], "float16", scope="local") - expr = tirx.ptx_cp_async_bulk("float16", buffer_shared.data, 0, buffer_local.data, 0, 16, 0) + expr = _cuda_op.ptx_cp_async_bulk("float16", buffer_shared.data, 0, buffer_local.data, 0, 16, 0) assert expr.op.name == "tirx.ptx.cp_async_bulk" diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py index 340eb9809493..75faf61366fe 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py @@ -32,12 +32,12 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx -from tvm.tirx.layout import S, TileLayout # Force the fallback dispatch to register before any test compiles a kernel. # Without this import, in fresh pytest workers the `copy/fallback` variant # isn't yet registered when the dispatcher snapshots its registry. -from tvm.tirx.operator.tile_primitive.cuda.copy import fallback as _fallback_module # noqa: F401 +from tvm.tirx.cuda.operator.tile_primitive.copy import fallback as _fallback_module # noqa: F401 +from tvm.tirx.layout import S, TileLayout def _round_trip_shapes_and_threads(): diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py index 86a33b940f9d..676d8d95ae5f 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py @@ -253,7 +253,7 @@ def copy_sync(A_ptr: T.handle, B_ptr: T.handle) -> None: def _align( g_layout, g_shape, s_layout, s_shape, elem_bits, thread_cnt, g_region=None, s_region=None ): - from tvm.tirx.operator.tile_primitive.cuda.copy._common import align_layouts_gs + from tvm.tirx.cuda.operator.tile_primitive.copy._common import align_layouts_gs target = tvm.target.Target("cuda") if g_region is None: diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_swizzle_iter.py b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_swizzle_iter.py index c2a5a73fb5f7..1231b338ba08 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_swizzle_iter.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_swizzle_iter.py @@ -37,12 +37,12 @@ import tvm from tvm.tirx import Var as _TirVar -from tvm.tirx.expr import IntImm as _IntImm -from tvm.tirx.layout import ComposeLayout, S, SwizzleLayout, TileLayout -from tvm.tirx.operator.tile_primitive.cuda.copy._swizzle_iter import ( +from tvm.tirx.cuda.operator.tile_primitive.copy._swizzle_iter import ( get_swizzle, try_recognize, ) +from tvm.tirx.expr import IntImm as _IntImm +from tvm.tirx.layout import ComposeLayout, S, SwizzleLayout, TileLayout # ---------------------------------------------------------------------------- # Pure-Python reference: SwizzleLayout's Apply, plus the proof's formula. @@ -312,7 +312,7 @@ def test_recognize_linear_iter_pure_case_1d(): of the swizzle period 2^(p+at+sw) (pure Case 1.D, swizzle has no XOR effect). The iter is stored as a LinearIter (no bit decomposition). """ - from tvm.tirx.operator.tile_primitive.cuda.copy._swizzle_iter import ( + from tvm.tirx.cuda.operator.tile_primitive.copy._swizzle_iter import ( _BitIter, _LinearIter, ) @@ -352,7 +352,7 @@ def test_emit_mixed_linear_bit_correctness(): """Brute-force: for a mixed (LinearIter outer, BitIter inner) pattern, emit_iter_offset's prediction must equal the actual swizzle output for every (tid, k) — including the non-pow2 outer extent's coord 2.""" - from tvm.tirx.operator.tile_primitive.cuda.copy._swizzle_iter import ( + from tvm.tirx.cuda.operator.tile_primitive.copy._swizzle_iter import ( _LinearIter, ) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py index 3e3070e8994f..5493fe0c28e6 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_dsmem.py @@ -31,9 +31,9 @@ from tvm.script import tirx as T from tvm.script.tirx import tile as Tx from tvm.tirx import IntImm, Var +from tvm.tirx.cuda.operator.tile_primitive.copy_async.dsmem import copy_dsmem_impl from tvm.tirx.exec_scope import ExecScope from tvm.tirx.layout import S, TileLayout -from tvm.tirx.operator.tile_primitive.cuda.copy_async.dsmem import copy_dsmem_impl from tvm.tirx.operator.tile_primitive.dispatch_context import DispatchContext from tvm.tirx.operator.tile_primitive.dispatcher import DispatchFail from tvm.tirx.operator.tile_primitive.ops import CopyAsync diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py index 036bd786a24a..3cdf31efd864 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_smem_tmem.py @@ -31,8 +31,8 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx +from tvm.tirx.cuda.operator.tile_primitive.tma_utils import SwizzleMode, mma_shared_layout from tvm.tirx.layout import R, S, TCol, TileLayout, TLane -from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode, mma_shared_layout T_LAY_BASIC = TileLayout(S[(32, 16) : (1 @ TLane, 1 @ TCol)] + R[4 : 32 @ TLane]) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py index 933b866bdb64..1b0455e27234 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tma.py @@ -27,13 +27,13 @@ from tvm.script import tirx as T from tvm.script.tirx import tile as Tx from tvm.tirx import IntImm, StringImm, Var -from tvm.tirx.exec_scope import ExecScope -from tvm.tirx.layout import S, TileLayout -from tvm.tirx.operator.tile_primitive.cuda.tma_utils import ( +from tvm.tirx.cuda.operator.tile_primitive.tma_utils import ( mma_atom_layout, mma_atom_shape, mma_shared_layout, ) +from tvm.tirx.exec_scope import ExecScope +from tvm.tirx.layout import S, TileLayout from tvm.tirx.operator.tile_primitive.dispatch_context import DispatchContext from tvm.tirx.operator.tile_primitive.ops import CopyAsync from tvm.tirx.stmt import DeclBuffer @@ -95,7 +95,7 @@ def _make_tma_call( """ from tvm.ir import Range from tvm.tirx import Var - from tvm.tirx.operator.tile_primitive.cuda.copy_async.tma import copy_tma_impl + from tvm.tirx.cuda.operator.tile_primitive.copy_async.tma import copy_tma_impl from tvm.tirx.stmt import BufferRegion g_buf = tvm.tirx.decl_buffer(g_shape, dtype, "A", layout=gmem_layout) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py index 3aa02bb5e2f0..c20df63bebf0 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py @@ -23,10 +23,10 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx -from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg, tx, warpid -from tvm.tirx.operator.tile_primitive.cuda.layout_utils import ( +from tvm.tirx.cuda.operator.tile_primitive.layout_utils import ( cast_layout_supported_for_local as _cast_layout_supported_for_local, ) +from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg, tx, warpid @pytest.mark.parametrize( @@ -1038,10 +1038,10 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: def test_cast_layout_partition_and_validation(): """Partition table (simplified): partition structure and _cast_layout_supported_for_local.""" - from tvm.tirx.layout import Axis, Iter - from tvm.tirx.operator.tile_primitive.cuda.layout_utils import ( + from tvm.tirx.cuda.operator.tile_primitive.layout_utils import ( get_layout_thread_local_partition as _get_layout_thread_local_partition, ) + from tvm.tirx.layout import Axis, Iter m_axis = Axis.get("m") @@ -1148,7 +1148,7 @@ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None: def test_cast_joint_decomposition_extents_order(): """Test joint decomposition uses thread dims in layout order with correct extents.""" - from tvm.tirx.operator.tile_primitive.cuda.layout_utils import ( + from tvm.tirx.cuda.operator.tile_primitive.layout_utils import ( get_layout_thread_local_partition as _get_layout_thread_local_partition, ) diff --git a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py index 8c32bbe04839..e0a270e7091a 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py @@ -32,14 +32,14 @@ from tvm.ir.type import PointerType, PrimType from tvm.script import tirx as T from tvm.script.tirx import tile as Tx -from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout -from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg -from tvm.tirx.operator.tile_primitive.cuda.gemm_async import sf_tmem_layout -from tvm.tirx.operator.tile_primitive.cuda.tma_utils import ( +from tvm.tirx.cuda.operator.tile_primitive.gemm_async import sf_tmem_layout +from tvm.tirx.cuda.operator.tile_primitive.tma_utils import ( mma_atom_layout, mma_atom_shape, mma_shared_layout, ) +from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout +from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg # --------------------------------------------------------------------------- # Shared test helpers @@ -1960,5 +1960,89 @@ def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3) +@pytest.mark.parametrize("k_lo,k_hi", [(0, 16), (0, 32), (16, 32), (16, 48), (32, 64)]) +def test_gemm_tcgen05_contiguous_kslice_partial_k(k_lo, k_hi): + """A slice on the *contiguous* (K) axis of a swizzled gemm_async operand must + compute the correct partial-K product, not silently use full K. + + The operand buffer is 128B-swizzled (contiguous atom = 64 elems for fp16) and + the gemm operand is sliced to K=[lo:hi] on that axis. The descriptor is + anchored on the buffer's physical swizzle while K_iters covers only the slice, + so the MMA accumulates exactly k in [lo, hi) -- enabling fine K-major split-K. + Any MMA_K(16)-aligned [lo:hi] is supported. + """ + from tvm.tirx.cuda.operator.tile_primitive.tma_utils import SwizzleMode + + M, N, K_alloc = 128, 128, 64 + dtype = "float16" + A_shape, B_shape, C_shape = (M, K_alloc), (N, K_alloc), (M, N) + A_layout = mma_shared_layout(dtype, SwizzleMode.SWIZZLE_128B_ATOM, A_shape) + B_layout = mma_shared_layout(dtype, SwizzleMode.SWIZZLE_128B_ATOM, B_shape) + total_bytes = (M * K_alloc + N * K_alloc) * 2 + + # fmt: off + @T.prim_func + def gemm_async(A_ptr: T.handle, B_ptr: T.handle, C_ptr: T.handle) -> None: + A = T.match_buffer(A_ptr, A_shape, dtype) + B = T.match_buffer(B_ptr, B_shape, dtype) + C = T.match_buffer(C_ptr, C_shape, "float32") + T.device_entry() + warp_id = T.warp_id([4]) + wg_id = T.warpgroup_id([1]) + tid_in_wg = T.thread_id_in_wg([128]) + A_smem = T.alloc_buffer(A_shape, dtype, scope="shared", layout=A_layout) + B_smem = T.alloc_buffer(B_shape, dtype, scope="shared", layout=B_layout) + tmem_addr = T.alloc_shared([1], "uint32") + tma_mbar = T.alloc_shared([1], "uint64") + mma_mbar = T.alloc_shared([1], "uint64") + if tid_in_wg == 0: + T.ptx.mbarrier.init(tma_mbar.ptr_to([0]), 1) + T.ptx.mbarrier.init(mma_mbar.ptr_to([0]), 1) + T.ptx.fence.proxy_async("shared::cta") + T.cuda.cta_sync() + if warp_id == 0: + T.ptx.tcgen05.alloc(T.address_of(tmem_addr), n_cols=128, cta_group=1) + T.cuda.cta_sync() + tmem = T.decl_buffer((128, N), "float32", scope="tmem", allocated_addr=tmem_addr[0], layout=TileLayout(S[(128, N) : (1 @ TLane, 1 @ TCol)])) # noqa: E501 + if tid_in_wg == 0: + tma_args = T.meta_var({"dispatch": "tma", "mbar": tma_mbar.ptr_to([0])}) + Tx.copy_async(A_smem[0:M, 0:K_alloc], A[0:M, 0:K_alloc], **tma_args) + Tx.copy_async(B_smem[0:N, 0:K_alloc], B[0:N, 0:K_alloc], **tma_args) + T.ptx.mbarrier.arrive.expect_tx(tma_mbar.ptr_to([0]), total_bytes) + T.ptx.mbarrier.try_wait(tma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + if tid_in_wg == 0: + # Contiguous-axis K slice [k_lo:k_hi] -> must accumulate only that K range. + Tx.gemm_async(tmem[0:128, 0:N], A_smem[0:M, k_lo:k_hi], B_smem[0:N, k_lo:k_hi], dispatch="tcgen05") # noqa: E501 + T.ptx.tcgen05.commit(mma_mbar.ptr_to([0]), cta_group=1) + T.ptx.mbarrier.try_wait(mma_mbar.ptr_to([0]), 0) + T.cuda.cta_sync() + T.ptx.tcgen05.fence.after_thread_sync() + C_reg = T.alloc_local(N, dtype="float32") + C_view = C_reg.view(128, N, layout=TileLayout(S[(128, N) : (1@axis_tid_in_wg, 1)])) + if wg_id == 0: + Tx.wg.copy_async(C_view[:, :], tmem[0:128, 0:N]) + T.ptx.tcgen05.wait.ld() + T.cuda.cta_sync() + Tx.copy(C[tid_in_wg, 0:N], C_reg[:]) + if warp_id == 0: + T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1) + T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=128, cta_group=1) + # fmt: on + + dev = tvm.cuda(0) + np.random.seed(0) + with tvm.target.Target("cuda"): + mod = tvm.compile(tvm.IRModule({"main": gemm_async}), target="cuda", tir_pipeline="tirx") + A_np = np.random.randn(*A_shape).astype(dtype) + B_np = np.random.randn(*B_shape).astype(dtype) + C_np = np.zeros(C_shape, "float32") + A_t, B_t, C_t = (tvm.runtime.tensor(x, dev) for x in (A_np, B_np, C_np)) + mod["main"](A_t, B_t, C_t) + # Reference: accumulate only k in [k_lo, k_hi). + C_ref = A_np[:, k_lo:k_hi].astype("float32") @ B_np[:, k_lo:k_hi].astype("float32").T + np.testing.assert_allclose(C_t.numpy(), C_ref, atol=1e-2, rtol=1e-2) + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py index 9aba8b4316dd..67cc1e0bd6fa 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py @@ -43,14 +43,14 @@ import tvm.testing from tvm.script import tirx as T from tvm.script.tirx import tile as Tx -from tvm.tirx.layout import S, SwizzleLayout, TileLayout # Helpers exposed by the dispatcher module for direct algorithm tests. -from tvm.tirx.operator.tile_primitive.cuda.permute_layout.warp_xor_swizzle import ( +from tvm.tirx.cuda.operator.tile_primitive.permute_layout.warp_xor_swizzle import ( _bank_free, _check_bijection, _choose_xor_k, ) +from tvm.tirx.layout import S, SwizzleLayout, TileLayout # --------------------------------------------------------------------------- # Algorithm-only tests (no CUDA needed). diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py index 448b856d9aca..b5e8a6554a64 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_compose_op_trn.py @@ -86,7 +86,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": activation_reduce}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) assert_structural_equal(mod["main"], expected) @@ -128,7 +128,7 @@ def expected(): # fmt: off with target: mod = tvm.IRModule({"main": activation_reduce}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) assert_structural_equal(mod["main"], expected) @@ -170,7 +170,7 @@ def expected(): # fmt: off with target: mod = tvm.IRModule({"main": activation_reduce}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) assert_structural_equal(mod["main"], expected) @@ -218,7 +218,7 @@ def expected(): # fmt: off with target: mod = tvm.IRModule({"main": activation_reduce}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) assert_structural_equal(mod["main"], expected) @@ -401,7 +401,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": tensor_scalar_reduce}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) assert_structural_equal(mod["main"], expected) @@ -601,7 +601,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": unary_reduce}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -690,7 +690,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": activation_reduce}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py index 4be47a7ed147..dab83844bb31 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_copy_trn.py @@ -229,7 +229,7 @@ def expected(): with target: mod = tvm.IRModule({"main": copy}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -276,7 +276,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": copy}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -759,7 +759,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": copy}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -831,7 +831,7 @@ def expected(A_ptr: T.handle): # fmt: on with target: mod = tvm.IRModule({"main": copy}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py index 18beb0390638..8397627a88c4 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_gemm_trn.py @@ -283,7 +283,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": gemm}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) @@ -530,7 +530,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": gemm}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_private_alloc_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_private_alloc_trn.py index 14c0f5dea795..e8acb3931dc4 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_private_alloc_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_private_alloc_trn.py @@ -21,7 +21,7 @@ from tvm.script import tirx as T from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout -from tvm.tirx.transform.trn import TrnPrivateBufferAlloc +from tvm.tirx.trn.transform import TrnPrivateBufferAlloc target = tvm.target.Target("aws/trn1/trn1.2xlarge") diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py index ef8146b76286..36da370d10e8 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_reduction_trn.py @@ -188,7 +188,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": reduction}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) assert_structural_equal(mod["main"], expected) @@ -234,7 +234,7 @@ def expected(): # fmt: on with target: mod = tvm.IRModule({"main": reduction}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) mod = tvm.tirx.transform.StmtSimplify()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py b/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py index db6e968b36a3..a774b4c9e447 100644 --- a/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py +++ b/tests/python/tirx/operator/tile_primitive/trn/test_unary_trn.py @@ -239,7 +239,7 @@ def expected(): # fmt: off with target: mod = tvm.IRModule({"main": unary}) - mod = tvm.tirx.transform.trn.TrnPrivateBufferAlloc()(mod) + mod = tvm.tirx.trn.transform.TrnPrivateBufferAlloc()(mod) mod = tvm.tirx.transform.LowerTIRx()(mod) assert_structural_equal(mod["main"], expected) diff --git a/tests/python/tirx/test_alloc_pool.py b/tests/python/tirx/test_alloc_pool.py index 0aadb260fa0f..41ea560ed706 100644 --- a/tests/python/tirx/test_alloc_pool.py +++ b/tests/python/tirx/test_alloc_pool.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Tests for tvm.tirx.lang.alloc_pool validation.""" +"""Tests for CUDA allocation pool validation.""" import pytest -from tvm.tirx.lang.alloc_pool import _validate_mma_alloc_shape -from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode +from tvm.tirx.cuda.lang.alloc_pool import _validate_mma_alloc_shape +from tvm.tirx.cuda.operator.tile_primitive.tma_utils import SwizzleMode # --------------------------------------------------------------------------- # alloc_mma shape validation: bad inputs raise actionable ValueError instead of diff --git a/tests/python/tirx/test_layout.py b/tests/python/tirx/test_layout.py index 1666d616e663..e3711cb00cd2 100644 --- a/tests/python/tirx/test_layout.py +++ b/tests/python/tirx/test_layout.py @@ -29,6 +29,11 @@ from tvm.script.ir_builder import IRBuilder from tvm.script.ir_builder import tirx as Tx_builder from tvm.tirx import Var +from tvm.tirx.cuda.operator.tile_primitive.tma_utils import ( + SwizzleMode, + mma_shared_layout, + tma_shared_layout, +) from tvm.tirx.layout import ( Axis, ComposeLayout, @@ -48,11 +53,6 @@ wgid, wid_in_wg, ) -from tvm.tirx.operator.tile_primitive.cuda.tma_utils import ( - SwizzleMode, - mma_shared_layout, - tma_shared_layout, -) def test_axis(): diff --git a/tests/python/tirx/test_op_namespace_cleanup.py b/tests/python/tirx/test_op_namespace_cleanup.py index 0bbfcff3e86d..40965d339b46 100644 --- a/tests/python/tirx/test_op_namespace_cleanup.py +++ b/tests/python/tirx/test_op_namespace_cleanup.py @@ -16,6 +16,10 @@ # under the License. """Tests for TIRx op namespace split between T, T.tile, and device namespaces.""" +import importlib +import sys +import types + import pytest import tvm @@ -144,6 +148,30 @@ def tile_aliases(a: T.handle, b: T.handle): def test_device_intrinsic_namespaces_are_canonical_and_classified(): + from tvm.backend.cuda.script import ( + CUDANamespace as BackendCUDANamespace, + ) + from tvm.backend.cuda.script import ( + NVSHMEMNamespace as BackendNVSHMEMNamespace, + ) + from tvm.backend.cuda.script import ( + PTXNamespace as BackendPTXNamespace, + ) + from tvm.backend.metal.script import MetalNamespace as BackendMetalNamespace + from tvm.backend.trn.script import NKINamespace as BackendNKINamespace + from tvm.tirx.script.builder import ir as builder_ir + + assert isinstance(builder_ir.cuda, BackendCUDANamespace) + assert isinstance(builder_ir.ptx, BackendPTXNamespace) + assert isinstance(builder_ir.nvshmem, BackendNVSHMEMNamespace) + assert isinstance(builder_ir.metal, BackendMetalNamespace) + assert isinstance(builder_ir.nki, BackendNKINamespace) + assert T.cuda is builder_ir.cuda + assert T.ptx is builder_ir.ptx + assert T.nvshmem is builder_ir.nvshmem + assert T.metal is builder_ir.metal + assert T.nki is builder_ir.nki + buffer = tvm.tirx.decl_buffer((1,), "float32") calls = [ T.ptx.elect_sync(), @@ -166,6 +194,99 @@ def test_device_intrinsic_namespaces_are_canonical_and_classified(): assert _op_attr(op_name, "TDeviceIntrinsicNamespace") == namespace +def test_backend_specific_wrappers_are_not_root_exports(): + from tvm.tirx.cuda import op as cuda_op + from tvm.tirx.metal import op as metal_op + from tvm.tirx.trn import op as trn_op + + backend_only_names = [ + "ptx_mma", + "mma_store", + "cuda_thread_fence", + "nvshmem_fence", + "make_filled_simdgroup_matrix", + "simdgroup_load", + "nki_load", + ] + for name in backend_only_names: + assert not hasattr(tvm.tirx.op, name) + assert not hasattr(tvm.tirx, name) + assert not hasattr(T, name) + + assert cuda_op.ptx_mma + assert cuda_op.mma_store + assert cuda_op.cuda_thread_fence + assert cuda_op.nvshmem_fence + assert metal_op.make_filled_simdgroup_matrix + assert metal_op.simdgroup_load + assert trn_op.nki_load + assert hasattr(T, "cuda") + assert hasattr(T, "ptx") + assert hasattr(T, "nvshmem") + assert hasattr(T, "metal") + assert hasattr(T, "nki") + + +def test_backend_load_updates_tirx_alias_and_script_facades(monkeypatch): + from tvm.tirx.script import builder, parser + from tvm.tirx.script.builder import ir as builder_ir + + backend_name = "unit_test_backend" + backend_module_name = f"tvm.backend.{backend_name}" + public_module_name = f"tvm.tirx.{backend_name}" + public_op_module_name = f"{public_module_name}.op" + namespace_name = "unit_test_backend_ns" + register_calls = [] + + class UnitTestNamespace: + pass + + module = types.ModuleType(backend_module_name) + module.__path__ = [] + module.__package__ = backend_module_name + op_module = types.ModuleType(f"{backend_module_name}.op") + op_module.marker = object() + + def register_backend(): + register_calls.append(True) + builder_ir.register_script_namespace(namespace_name, UnitTestNamespace()) + + module.register_backend = register_backend + monkeypatch.setitem(sys.modules, backend_module_name, module) + monkeypatch.setitem(sys.modules, op_module.__name__, op_module) + sys.modules.pop(public_module_name, None) + sys.modules.pop(public_op_module_name, None) + tvm.backend._LOADED_BACKENDS.pop(backend_name, None) + if hasattr(tvm.tirx, backend_name): + delattr(tvm.tirx, backend_name) + + with pytest.raises(ModuleNotFoundError): + importlib.import_module(public_op_module_name) + + try: + assert tvm.backend.load(backend_name) is None + assert tvm.backend.load(backend_name) is None + assert register_calls == [True] + assert tvm.backend.is_loaded(backend_name) + assert getattr(tvm.tirx, backend_name) is module + assert sys.modules[public_module_name] is module + public_op_module = importlib.import_module(public_op_module_name) + assert public_op_module.__tvm_backend_module__ is op_module + assert public_op_module.marker is op_module.marker + + namespace = getattr(builder_ir, namespace_name) + assert isinstance(namespace, UnitTestNamespace) + assert getattr(builder, namespace_name) is namespace + assert getattr(parser, namespace_name) is namespace + assert getattr(T, namespace_name) is namespace + finally: + tvm.backend._LOADED_BACKENDS.pop(backend_name, None) + if hasattr(tvm.tirx, backend_name): + delattr(tvm.tirx, backend_name) + sys.modules.pop(public_module_name, None) + sys.modules.pop(public_op_module_name, None) + + def test_device_intrinsic_printer_roundtrips_canonical_namespaces(): @T.prim_func def device_namespaces(dst: T.handle, src: T.handle): diff --git a/tests/python/tirx/test_printer_tir_namespaces.py b/tests/python/tirx/test_printer_tir_namespaces.py index c79d700c8e01..57c989bd4a32 100644 --- a/tests/python/tirx/test_printer_tir_namespaces.py +++ b/tests/python/tirx/test_printer_tir_namespaces.py @@ -19,6 +19,8 @@ import tvm from tvm import tirx as tir from tvm.script import tirx as T +from tvm.tirx.cuda import op as cuda_op +from tvm.tirx.trn import op as trn_op def _assert_print(obj, expected): @@ -28,27 +30,27 @@ def _assert_print(obj, expected): def test_printer_cuda_namespace_printf(): - node = tir.Evaluate(tir.op.cuda_printf("x=%d", tir.IntImm("int32", 1))) + node = tir.Evaluate(cuda_op.cuda_printf("x=%d", tir.IntImm("int32", 1))) _assert_print(node, 'T.cuda.printf("x=%d", 1)') def test_printer_ptx_namespace_wgmma_commit_group(): - node = tir.Evaluate(tir.op.ptx_wgmma_commit_group()) + node = tir.Evaluate(cuda_op.ptx_wgmma_commit_group()) _assert_print(node, "T.ptx.wgmma.commit_group()") def test_printer_cuda_cluster_sync(): - node = tir.Evaluate(tir.op.cuda_cluster_sync()) + node = tir.Evaluate(cuda_op.cuda_cluster_sync()) _assert_print(node, "T.cuda.cluster_sync()") def test_printer_ptx_namespace_cp_async_wait_group(): - node = tir.Evaluate(tir.op.ptx_cp_async_wait_group(tir.IntImm("int32", 0))) + node = tir.Evaluate(cuda_op.ptx_cp_async_wait_group(tir.IntImm("int32", 0))) _assert_print(node, "T.ptx.cp_async.wait_group(0)") def test_printer_nvshmem_namespace(): - node = tir.Evaluate(tir.op.nvshmem_fence()) + node = tir.Evaluate(cuda_op.nvshmem_fence()) _assert_print(node, "T.nvshmem.fence()") @@ -58,66 +60,66 @@ def test_printer_ptx_more(): _assert_print( # New API: (trans, num, dtype, smem_ptr, *dst_handles). # .x1.b16 has 1 dst register, so 1 dst handle. - tir.op.ptx_ldmatrix(True, 1, ".b16", s, r), + cuda_op.ptx_ldmatrix(True, 1, ".b16", s, r), 's = T.handle()\nr = T.handle()\nT.ptx.ldmatrix(T.bool(True), 1, ".b16", s, r)', ) _assert_print( # New API: (trans, num, dtype, smem_ptr, *src_handles). # .x1.b16 has 1 src register, so 1 src handle. - tir.op.ptx_stmatrix(False, 1, ".b16", s, r), + cuda_op.ptx_stmatrix(False, 1, ".b16", s, r), ( "s = T.handle()\nr = T.handle()\nT.ptx.stmatrix(" 'T.bool(False), 1, ".b16", "m8n8", "shared", s, r)' ), ) - _assert_print(tir.op.ptx_setmaxnreg(True, 64), "T.ptx.setmaxnreg(T.bool(True), 64)") - _assert_print(tir.op.ptx_fetch_register(32, "laneid"), 'T.ptx.fetch_register(32, "laneid")') - _assert_print(tir.op.ptx_wgmma_fence(), "T.ptx.wgmma.fence()") - _assert_print(tir.op.ptx_wgmma_wait_group(0), "T.ptx.wgmma.wait_group(0)") - _assert_print(tir.op.ptx_cp_async_commit_group(), "T.ptx.cp_async.commit_group()") - _assert_print(tir.op.ptx_cp_async_bulk_commit_group(), "T.ptx.cp_async.bulk.commit_group()") + _assert_print(cuda_op.ptx_setmaxnreg(True, 64), "T.ptx.setmaxnreg(T.bool(True), 64)") + _assert_print(cuda_op.ptx_fetch_register(32, "laneid"), 'T.ptx.fetch_register(32, "laneid")') + _assert_print(cuda_op.ptx_wgmma_fence(), "T.ptx.wgmma.fence()") + _assert_print(cuda_op.ptx_wgmma_wait_group(0), "T.ptx.wgmma.wait_group(0)") + _assert_print(cuda_op.ptx_cp_async_commit_group(), "T.ptx.cp_async.commit_group()") + _assert_print(cuda_op.ptx_cp_async_bulk_commit_group(), "T.ptx.cp_async.bulk.commit_group()") _assert_print( - tir.op.ptx_cp_async_bulk_wait_group(0, True), + cuda_op.ptx_cp_async_bulk_wait_group(0, True), "T.ptx.cp_async.bulk.wait_group(0, T.bool(True))", ) - _assert_print(tir.op.ptx_cp_async_mbarrier_arrive(0), "T.ptx.cp_async.mbarrier.arrive(0)") - _assert_print(tir.op.ptx_fence("acq_rel", "gpu"), 'T.ptx.fence("acq_rel", "gpu")') - _assert_print(tir.op.ptx_fence("sc", "cta"), 'T.ptx.fence("sc", "cta")') + _assert_print(cuda_op.ptx_cp_async_mbarrier_arrive(0), "T.ptx.cp_async.mbarrier.arrive(0)") + _assert_print(cuda_op.ptx_fence("acq_rel", "gpu"), 'T.ptx.fence("acq_rel", "gpu")') + _assert_print(cuda_op.ptx_fence("sc", "cta"), 'T.ptx.fence("sc", "cta")') _assert_print( - tir.op.ptx_fence_proxy_async("shared::cta"), 'T.ptx.fence.proxy_async("shared::cta")' + cuda_op.ptx_fence_proxy_async("shared::cta"), 'T.ptx.fence.proxy_async("shared::cta")' ) - _assert_print(tir.op.ptx_fence_proxy_async("global"), 'T.ptx.fence.proxy_async("global")') - _assert_print(tir.op.ptx_fence_mbarrier_init(), "T.ptx.fence.mbarrier_init()") - _assert_print(tir.op.ptx_elect_sync(), "T.ptx.elect_sync()") + _assert_print(cuda_op.ptx_fence_proxy_async("global"), 'T.ptx.fence.proxy_async("global")') + _assert_print(cuda_op.ptx_fence_mbarrier_init(), "T.ptx.fence.mbarrier_init()") + _assert_print(cuda_op.ptx_elect_sync(), "T.ptx.elect_sync()") lane = tir.Var("lane", "int32") _assert_print( - tir.op.selector(lane, tir.op.ptx_elect_sync()), + tir.op.selector(lane, cuda_op.ptx_elect_sync()), "lane = T.int32()\nT.selector(lane, T.ptx.elect_sync())", ) _assert_print( - tir.op.ptx_ld_global_acquire(r, s), + cuda_op.ptx_ld_global_acquire(r, s), "r = T.handle()\ns = T.handle()\nT.ptx.ld_global_acquire(r, s)", ) _assert_print( - tir.op.ptx_map_shared_rank(r, 2), 'r = T.handle()\nT.ptx.mapa(r, 2, "", "u64", "uint64")' + cuda_op.ptx_map_shared_rank(r, 2), 'r = T.handle()\nT.ptx.mapa(r, 2, "", "u64", "uint64")' ) - _assert_print(tir.op.ptx_bar_arrive(0, 128), "T.ptx.bar.arrive(0, 128)") - _assert_print(tir.op.ptx_bar_sync(0, 128), "T.ptx.bar.sync(0, 128)") + _assert_print(cuda_op.ptx_bar_arrive(0, 128), "T.ptx.bar.arrive(0, 128)") + _assert_print(cuda_op.ptx_bar_sync(0, 128), "T.ptx.bar.sync(0, 128)") _assert_print( - tir.op.ptx_tcgen05_alloc(s, 64, 1), "s = T.handle()\nT.ptx.tcgen05.alloc(s, 64, 1)" + cuda_op.ptx_tcgen05_alloc(s, 64, 1), "s = T.handle()\nT.ptx.tcgen05.alloc(s, 64, 1)" ) _assert_print( - tir.op.ptx_tcgen05_dealloc(s, 64, 1), "s = T.handle()\nT.ptx.tcgen05.dealloc(s, 64, 1)" + cuda_op.ptx_tcgen05_dealloc(s, 64, 1), "s = T.handle()\nT.ptx.tcgen05.dealloc(s, 64, 1)" ) d = tir.Var("d", "handle") a = tir.Var("a", "handle") b = tir.Var("b", "handle") _assert_print( - tir.op.ptx_tcgen05_encode_matrix_descriptor(d, a, 1, 2, 0), + cuda_op.ptx_tcgen05_encode_matrix_descriptor(d, a, 1, 2, 0), "d = T.handle()\na = T.handle()\nT.ptx.tcgen05.encode_matrix_descriptor(d, a, 1, 2, 0)", ) _assert_print( - tir.op.ptx_tcgen05_encode_instr_descriptor( + cuda_op.ptx_tcgen05_encode_instr_descriptor( d, d_dtype="f16", a_dtype="f16", @@ -136,7 +138,7 @@ def test_printer_ptx_more(): 'd = T.handle()\nT.ptx.tcgen05.encode_instr_descriptor(d, "f16", "f16", "f16", 16, 16, 16, T.bool(True), T.bool(False), 1, T.bool(False), T.bool(False), T.bool(False), T.bool(False))', # noqa: E501 ) _assert_print( - tir.op.ptx_tcgen05_encode_instr_descriptor_block_scaled( + cuda_op.ptx_tcgen05_encode_instr_descriptor_block_scaled( d, d_dtype="f16", a_dtype="f16", @@ -161,64 +163,64 @@ def test_printer_ptx_more(): 'T.ptx.tcgen05.encode_instr_descriptor_block_scaled(d, "f16", "f16", "f16", "f16", "f16", a, b, 16, 16, 16, T.bool(True), T.bool(False), 1, T.bool(False), T.bool(False), T.bool(True))', # noqa: E501 ) _assert_print( - tir.op.ptx_tcgen05_cp(a, d, shape="64x128b", cta_group=1, multicast="warpx2::02_13"), + cuda_op.ptx_tcgen05_cp(a, d, shape="64x128b", cta_group=1, multicast="warpx2::02_13"), "a = T.handle()\n" "d = T.handle()\n" 'T.ptx.tcgen05.cp(a, d, "64x128b", 1, "warpx2::02_13", "", 0, 0)', ) - _assert_print(tir.op.ptx_tcgen05_shift(a, 1), "a = T.handle()\nT.ptx.tcgen05.shift(a, 1)") + _assert_print(cuda_op.ptx_tcgen05_shift(a, 1), "a = T.handle()\nT.ptx.tcgen05.shift(a, 1)") _assert_print( - tir.op.ptx_tcgen05_ld(a, 0, shape="16x64b", num=1, row=0, col=0, pack=False), + cuda_op.ptx_tcgen05_ld(a, 0, shape="16x64b", num=1, row=0, col=0, pack=False), 'a = T.handle()\nT.ptx.tcgen05.ld(a, 0, 0, "16x64b", 1, T.bool(False), 0)', ) _assert_print( - tir.op.ptx_tcgen05_st(a, 0, shape="16x64b", num=1, row=0, col=0, unpack=False), + cuda_op.ptx_tcgen05_st(a, 0, shape="16x64b", num=1, row=0, col=0, unpack=False), 'a = T.handle()\nT.ptx.tcgen05.st(a, 0, 0, "16x64b", 1, T.bool(False), 0)', ) - _assert_print(tir.op.ptx_tcgen05_wait_ld(), "T.ptx.tcgen05.wait.ld()") - _assert_print(tir.op.ptx_tcgen05_wait_st(), "T.ptx.tcgen05.wait.st()") + _assert_print(cuda_op.ptx_tcgen05_wait_ld(), "T.ptx.tcgen05.wait.ld()") + _assert_print(cuda_op.ptx_tcgen05_wait_st(), "T.ptx.tcgen05.wait.st()") _assert_print( - tir.op.ptx_tcgen05_commit(a, 1, 0), "a = T.handle()\nT.ptx.tcgen05.commit(a, 1, 0)" + cuda_op.ptx_tcgen05_commit(a, 1, 0), "a = T.handle()\nT.ptx.tcgen05.commit(a, 1, 0)" ) _assert_print( - tir.op.ptx_tcgen05_relinquish_alloc_permit(1), "T.ptx.tcgen05.relinquish_alloc_permit(1)" + cuda_op.ptx_tcgen05_relinquish_alloc_permit(1), "T.ptx.tcgen05.relinquish_alloc_permit(1)" ) def test_printer_ptx_mbarrier(): bar = tir.Var("bar", "handle") _assert_print( - tir.op.ptx_mbarrier_init(bar, 32), "bar = T.handle()\nT.ptx.mbarrier.init(bar, 32)" + cuda_op.ptx_mbarrier_init(bar, 32), "bar = T.handle()\nT.ptx.mbarrier.init(bar, 32)" ) - _assert_print(tir.op.ptx_mbarrier_arrive(bar), "bar = T.handle()\nT.ptx.mbarrier.arrive(bar)") + _assert_print(cuda_op.ptx_mbarrier_arrive(bar), "bar = T.handle()\nT.ptx.mbarrier.arrive(bar)") _assert_print( - tir.op.ptx_mbarrier_arrive_expect_tx(bar, 128), + cuda_op.ptx_mbarrier_arrive_expect_tx(bar, 128), "bar = T.handle()\nT.ptx.mbarrier.arrive.expect_tx(bar, 128)", ) _assert_print( - tir.op.ptx_mbarrier_try_wait(bar, 1), "bar = T.handle()\nT.ptx.mbarrier.try_wait(bar, 1)" + cuda_op.ptx_mbarrier_try_wait(bar, 1), "bar = T.handle()\nT.ptx.mbarrier.try_wait(bar, 1)" ) - _assert_print(tir.op.cuda_cluster_sync(), "T.cuda.cluster_sync()") + _assert_print(cuda_op.cuda_cluster_sync(), "T.cuda.cluster_sync()") def test_printer_cuda_more(): p = tir.Var("p", "handle") - _assert_print(tir.op.cuda_thread_fence(), "T.cuda.thread_fence()") - _assert_print(tir.op.cuda_warp_sync(), "T.cuda.warp_sync()") - _assert_print(tir.op.cuda_cta_sync(), "T.cuda.cta_sync()") - _assert_print(tir.op.cuda_grid_sync(), "T.cuda.grid_sync()") - _assert_print(tir.op.cuda_cluster_sync(), "T.cuda.cluster_sync()") - _assert_print(tir.op.cuda_syncthreads_and(1), "T.cuda.syncthreads_and(1)") - _assert_print(tir.op.cuda_syncthreads_or(1), "T.cuda.syncthreads_or(1)") - _assert_print(tir.op.cuda_nano_sleep(100), "T.cuda.nano_sleep(100)") - _assert_print( - tir.op.cuda_atomic_add(p, tir.IntImm("int32", 1)), + _assert_print(cuda_op.cuda_thread_fence(), "T.cuda.thread_fence()") + _assert_print(cuda_op.cuda_warp_sync(), "T.cuda.warp_sync()") + _assert_print(cuda_op.cuda_cta_sync(), "T.cuda.cta_sync()") + _assert_print(cuda_op.cuda_grid_sync(), "T.cuda.grid_sync()") + _assert_print(cuda_op.cuda_cluster_sync(), "T.cuda.cluster_sync()") + _assert_print(cuda_op.cuda_syncthreads_and(1), "T.cuda.syncthreads_and(1)") + _assert_print(cuda_op.cuda_syncthreads_or(1), "T.cuda.syncthreads_or(1)") + _assert_print(cuda_op.cuda_nano_sleep(100), "T.cuda.nano_sleep(100)") + _assert_print( + cuda_op.cuda_atomic_add(p, tir.IntImm("int32", 1)), "p = T.handle()\nT.cuda.atomic_add(p, 1)", ) - _assert_print(tir.op.cuda_atomic_cas(p, 1, 2), "p = T.handle()\nT.cuda.atomic_cas(p, 1, 2)") - _assert_print(tir.op.cuda_ldg(p, "float32"), 'p = T.handle()\nT.cuda.ldg(p, "float32")') + _assert_print(cuda_op.cuda_atomic_cas(p, 1, 2), "p = T.handle()\nT.cuda.atomic_cas(p, 1, 2)") + _assert_print(cuda_op.cuda_ldg(p, "float32"), 'p = T.handle()\nT.cuda.ldg(p, "float32")') _assert_print( - tir.op.cuda_func_call("f", 1, source_code=""), 'T.cuda.func_call("f", 1, source_code="")' + cuda_op.cuda_func_call("f", 1, source_code=""), 'T.cuda.func_call("f", 1, source_code="")' ) @@ -260,48 +262,48 @@ def kernel(): def test_printer_nvshmem_more(): p = tir.Var("p", "handle") - _assert_print(tir.op.nvshmem_my_pe(), "T.nvshmem.my_pe()") - _assert_print(tir.op.nvshmem_n_pes(), "T.nvshmem.n_pes()") + _assert_print(cuda_op.nvshmem_my_pe(), "T.nvshmem.my_pe()") + _assert_print(cuda_op.nvshmem_n_pes(), "T.nvshmem.n_pes()") _assert_print( - tir.op.nvshmem_signal_op(p, 1, "set", 0), + cuda_op.nvshmem_signal_op(p, 1, "set", 0), 'p = T.handle()\nT.nvshmem.signal_op(p, 1, "set", 0)', ) _assert_print( - tir.op.nvshmem_wait_until(p, "eq", 0), + cuda_op.nvshmem_wait_until(p, "eq", 0), 'p = T.handle()\nT.nvshmem.wait_until(p, "eq", 0, "uint64_t")', ) - _assert_print(tir.op.nvshmem_quiet(), "T.nvshmem.quiet()") - _assert_print(tir.op.nvshmem_barrier_all(), "T.nvshmem.barrier_all()") + _assert_print(cuda_op.nvshmem_quiet(), "T.nvshmem.quiet()") + _assert_print(cuda_op.nvshmem_barrier_all(), "T.nvshmem.barrier_all()") _assert_print( - tir.op.nvshmem_getmem_nbi(p, p, 16, 0), + cuda_op.nvshmem_getmem_nbi(p, p, 16, 0), "p = T.handle()\nT.nvshmem.getmem_nbi(p, p, 16, 0)", ) _assert_print( - tir.op.nvshmem_getmem_nbi_warp(p, p, 16, 0), + cuda_op.nvshmem_getmem_nbi_warp(p, p, 16, 0), "p = T.handle()\nT.nvshmem.getmem_nbi.warp(p, p, 16, 0)", ) _assert_print( - tir.op.nvshmem_putmem_nbi_block(p, p, 16, 0), + cuda_op.nvshmem_putmem_nbi_block(p, p, 16, 0), "p = T.handle()\nT.nvshmem.putmem_nbi.block(p, p, 16, 0)", ) _assert_print( - tir.op.nvshmem_putmem_nbi(p, p, 16, 0), + cuda_op.nvshmem_putmem_nbi(p, p, 16, 0), "p = T.handle()\nT.nvshmem.putmem_nbi(p, p, 16, 0)", ) _assert_print( - tir.op.nvshmem_putmem_nbi_warp(p, p, 16, 0), + cuda_op.nvshmem_putmem_nbi_warp(p, p, 16, 0), "p = T.handle()\nT.nvshmem.putmem_nbi.warp(p, p, 16, 0)", ) _assert_print( - tir.op.nvshmem_putmem_signal_nbi(p, p, 16, p, 1, "set", 0), + cuda_op.nvshmem_putmem_signal_nbi(p, p, 16, p, 1, "set", 0), 'p = T.handle()\nT.nvshmem.putmem_signal_nbi(p, p, 16, p, 1, "set", 0)', ) _assert_print( - tir.op.nvshmem_putmem_signal_nbi_warp(p, p, 16, p, 1, "set", 0), + cuda_op.nvshmem_putmem_signal_nbi_warp(p, p, 16, p, 1, "set", 0), 'p = T.handle()\nT.nvshmem.putmem_signal_nbi.warp(p, p, 16, p, 1, "set", 0)', ) _assert_print( - tir.op.nvshmem_putmem_signal_nbi_block(p, p, 16, p, 1, "set", 0), + cuda_op.nvshmem_putmem_signal_nbi_block(p, p, 16, p, 1, "set", 0), 'p = T.handle()\nT.nvshmem.putmem_signal_nbi.block(p, p, 16, p, 1, "set", 0)', ) @@ -312,81 +314,81 @@ def test_printer_nki_namespace(): a0 = A[0] b0 = B[0] _assert_print( - tir.op.nki_load(a0, b0), + trn_op.nki_load(a0, b0), 'A = T.Buffer((1,), "float16")\nB = T.Buffer((1,), "float16")\nT.nki.load(A, B)', ) _assert_print( - tir.op.nki_store(a0, b0), + trn_op.nki_store(a0, b0), 'A = T.Buffer((1,), "float16")\nB = T.Buffer((1,), "float16")\nT.nki.store(A, B)', ) _assert_print( - tir.op.nki_tensor_copy(a0, b0), + trn_op.nki_tensor_copy(a0, b0), 'A = T.Buffer((1,), "float16")\nB = T.Buffer((1,), "float16")\nT.nki.tensor_copy(A, B)', ) _assert_print( - tir.op.nki_matmul(a0, a0, b0), + trn_op.nki_matmul(a0, a0, b0), 'A = T.Buffer((1,), "float16")\n' 'B = T.Buffer((1,), "float16")\n' "T.nki.matmul(A, A, B, T.bool(True))", ) _assert_print( - tir.op.nki_activation(a0, b0, "relu", 0.0, 1.0), + trn_op.nki_activation(a0, b0, "relu", 0.0, 1.0), 'A = T.Buffer((1,), "float16")\n' 'B = T.Buffer((1,), "float16")\n' 'T.nki.activation(A, B, "relu", T.float32(0.0), T.float32(1.0))', ) _assert_print( - tir.op.nki_memset(a0, 0), + trn_op.nki_memset(a0, 0), 'A = T.Buffer((1,), "float16")\nT.nki.memset(A, 0)', ) _assert_print( - tir.op.nki_identity(a0, 1), + trn_op.nki_identity(a0, 1), 'A = T.Buffer((1,), "float16")\nT.nki.identity(A, 1)', ) _assert_print( - tir.op.nki_reciprocal(a0, b0), + trn_op.nki_reciprocal(a0, b0), 'A = T.Buffer((1,), "float16")\nB = T.Buffer((1,), "float16")\nT.nki.reciprocal(A, B)', ) _assert_print( - tir.op.nki_tensorreduce(a0, b0, "sum", False, 0), + trn_op.nki_tensorreduce(a0, b0, "sum", False, 0), 'A = T.Buffer((1,), "float16")\n' 'B = T.Buffer((1,), "float16")\n' 'T.nki.tensorreduce(A, B, "sum", T.bool(False), 0)', ) _assert_print( - tir.op.nki_tensortensor(a0, a0, b0, "add"), + trn_op.nki_tensortensor(a0, a0, b0, "add"), 'A = T.Buffer((1,), "float16")\n' 'B = T.Buffer((1,), "float16")\n' 'T.nki.tensortensor(A, A, B, "add")', ) _assert_print( - tir.op.nki_tensorscalar(a0, a0, 1.0, "mul", False), + trn_op.nki_tensorscalar(a0, a0, 1.0, "mul", False), 'A = T.Buffer((1,), "float16")\n' 'T.nki.tensorscalar(A, A, T.float32(1.0), "mul", T.bool(False))', ) _assert_print( - tir.op.nki_tensorscalar_reduce(a0, a0, 1.0, "mul", "sum", False), + trn_op.nki_tensorscalar_reduce(a0, a0, 1.0, "mul", "sum", False), 'A = T.Buffer((1,), "float16")\n' 'T.nki.tensorscalar_reduce(A, A, T.float32(1.0), "mul", "sum", T.bool(False), T.bool(False))', # noqa: E501 ) _assert_print( - tir.op.nki_scalar_tensor_tensor(a0, a0, 1.0, a0, "add", "add"), + trn_op.nki_scalar_tensor_tensor(a0, a0, 1.0, a0, "add", "add"), 'A = T.Buffer((1,), "float16")\n' 'T.nki.scalar_tensor_tensor(A, A, T.float32(1.0), A, "add", "add", T.bool(False), T.bool(False))', # noqa: E501 ) _assert_print( - tir.op.nki_scalar_tensor_scalar(a0, a0, 1.0, 1.0, "add", "add"), + trn_op.nki_scalar_tensor_scalar(a0, a0, 1.0, 1.0, "add", "add"), 'A = T.Buffer((1,), "float16")\n' 'T.nki.scalar_tensor_scalar(A, A, T.float32(1.0), T.float32(1.0), "add", "add", T.bool(False), T.bool(False))', # noqa: E501 ) _assert_print( - tir.op.nki_activation_reduce(a0, a0, b0, "relu", "sum", 0.0, 1.0), + trn_op.nki_activation_reduce(a0, a0, b0, "relu", "sum", 0.0, 1.0), 'A = T.Buffer((1,), "float16")\n' 'B = T.Buffer((1,), "float16")\n' 'T.nki.activation_reduce(A, A, B, "relu", "sum", T.float32(0.0), T.float32(1.0))', ) _assert_print( - tir.op.nki_affine_select(a0, a0, a0, 1.0), + trn_op.nki_affine_select(a0, a0, a0, 1.0), 'A = T.Buffer((1,), "float16")\nT.nki.affine_select(A, A, A, T.float32(1.0))', ) @@ -397,16 +399,16 @@ def test_printer_ptx_mma_and_wgmma(): a = tir.Var("a", "handle") tir.Var("b", "handle") _assert_print( - tir.op.ptx_mma("m8n8k4", "row", "row", "fp16", "fp16", "fp16", "fp16", [r], [r], [r]), + cuda_op.ptx_mma("m8n8k4", "row", "row", "fp16", "fp16", "fp16", "fp16", [r], [r], [r]), 'r = T.handle()\nT.ptx.mma("m8n8k4", "row", "row", "fp16", "fp16", "fp16", "fp16", 1, 1, 1, 0, T.bool(True), r, r, r, T.bool(False))', # noqa: E501 ) _assert_print( - tir.op.ptx_wgmma_encode_matrix_descriptor(d, a, 1, 1, 0), + cuda_op.ptx_wgmma_encode_matrix_descriptor(d, a, 1, 1, 0), "d = T.handle()\na = T.handle()\nT.ptx.wgmma.encode_matrix_descriptor(d, a, 1, 1, 0)", ) - _assert_print(tir.op.ptx_wgmma_noop_barrier(0), "T.ptx.wgmma.noop_barrier(0)") + _assert_print(cuda_op.ptx_wgmma_noop_barrier(0), "T.ptx.wgmma.noop_barrier(0)") _assert_print( - tir.op.ptx_wgmma_mma_async_ss( + cuda_op.ptx_wgmma_mma_async_ss( d, d, 0, @@ -425,7 +427,7 @@ def test_printer_ptx_mma_and_wgmma(): 'd = T.handle()\nT.ptx.wgmma.mma_async.ss(16, 16, 16, "f16", "f16", T.bool(True), T.bool(False), T.float32(1.0), T.float32(1.0), T.bool(True), d, d, 0, 0)', # noqa: E501 ) _assert_print( - tir.op.ptx_wgmma_mma_async_rs( + cuda_op.ptx_wgmma_mma_async_rs( d, 0, 0, @@ -447,12 +449,12 @@ def test_printer_ptx_mma_and_wgmma(): def test_printer_ptx_cp_async_tensor(): tmap = tir.Var("tm", "handle") _assert_print( - tir.op.ptx_cp_async_bulk_tensor_global_to_cluster(2, tmap, 0, tmap, 0, 1, "", 0, 1, ""), + cuda_op.ptx_cp_async_bulk_tensor_global_to_cluster(2, tmap, 0, tmap, 0, 1, "", 0, 1, ""), "tm = T.handle()\n" 'T.ptx.cp_async.bulk.tensor.g2c(2, tm, 0, tm, 0, 1, T.uint64(0), 0, 0, 1, "")', ) _assert_print( - tir.op.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster( + cuda_op.ptx_cp_async_bulk_tensor_tile_gather4_global_to_cluster( 2, tmap, 0, tmap, 0, 1, "", 0, 1, "" ), "tm = T.handle()\n" @@ -460,15 +462,15 @@ def test_printer_ptx_cp_async_tensor(): '(2, tm, 0, tm, 0, 1, T.uint64(0), 0, 0, 1, "")', ) _assert_print( - tir.op.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch(2, tmap, "", 0, 0, ""), + cuda_op.ptx_cp_async_bulk_tensor_global_to_cluster_prefetch(2, tmap, "", 0, 0, ""), 'tm = T.handle()\nT.ptx.cp_async.bulk.tensor.g2c_prefetch(2, tm, T.uint64(0), 0, 0, 0, "")', ) _assert_print( - tir.op.ptx_cp_async_bulk_tensor_shared_to_global(2, 0, tmap, "", 0, 0, ""), + cuda_op.ptx_cp_async_bulk_tensor_shared_to_global(2, 0, tmap, "", 0, 0, ""), 'tm = T.handle()\nT.ptx.cp_async.bulk.tensor.s2g(2, 0, tm, T.uint64(0), 0, 0, 0, "")', ) _assert_print( - tir.op.ptx_cp_async_bulk_tensor_shared_to_global_reduce(2, 0, tmap, "", "add", 0, 0, ""), + cuda_op.ptx_cp_async_bulk_tensor_shared_to_global_reduce(2, 0, tmap, "", "add", 0, 0, ""), "tm = T.handle()\n" "T.ptx.cp_async.bulk.tensor.s2g_reduce" '(2, 0, tm, T.uint64(0), 0, "add", 0, 0, "")', @@ -479,7 +481,7 @@ def test_printer_ptx_cp_async_call(): sh = tir.Var("sh", "handle") gl = tir.Var("gl", "handle") _assert_print( - tir.op.ptx_cp_async( + cuda_op.ptx_cp_async( sh, gl, 16, cache_hint="", prefetch_size=-1, predicate=-1, fill_mode="" ), 'sh = T.handle()\ngl = T.handle()\nT.ptx.cp_async(sh, gl, 16, T.uint64(0), 0, -1, -1, "")', diff --git a/tests/python/tirx/transform/test_transform_naive_allocator.py b/tests/python/tirx/transform/test_transform_naive_allocator.py index 7d77c6114c00..221bf0e09352 100644 --- a/tests/python/tirx/transform/test_transform_naive_allocator.py +++ b/tests/python/tirx/transform/test_transform_naive_allocator.py @@ -21,7 +21,7 @@ from tvm.script import tirx as T from tvm.script.tirx import tile as Tx from tvm.tirx.layout import F, P, S, TileLayout -from tvm.tirx.transform.trn import TrnNaiveAllocator +from tvm.tirx.trn.transform import TrnNaiveAllocator def test_one_alloc():