From d39af58aaeb9df70f9129bcd09e1d4089f63feba Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 14 Jun 2026 01:57:24 +0000 Subject: [PATCH] [REFACTOR][DataType] Phase out target custom datatype support --- CMakeLists.txt | 6 - cmake/modules/contrib/Posit.cmake | 26 -- docker/Dockerfile.ci_cpu | 4 - docker/Dockerfile.ci_gpu | 4 - include/tvm/tirx/op.h | 7 - include/tvm/tirx/transform.h | 9 - python/tvm/s_tir/pipeline.py | 2 - python/tvm/target/__init__.py | 1 - python/tvm/target/datatype.py | 379 ------------------- python/tvm/tirx/compilation_pipeline.py | 2 - python/tvm/tirx/transform/transform.py | 13 - src/arith/rewrite_simplify.cc | 1 - src/target/datatype/myfloat/myfloat.cc | 144 ------- src/target/datatype/posit/posit-wrapper.cc | 242 ------------ src/target/datatype/registry.cc | 138 ------- src/target/datatype/registry.h | 182 --------- src/target/llvm/codegen_llvm.cc | 7 +- src/tirx/op/op.cc | 25 +- src/tirx/transform/lower_custom_datatypes.cc | 266 ------------- 19 files changed, 8 insertions(+), 1450 deletions(-) delete mode 100644 cmake/modules/contrib/Posit.cmake delete mode 100644 python/tvm/target/datatype.py delete mode 100644 src/target/datatype/myfloat/myfloat.cc delete mode 100644 src/target/datatype/posit/posit-wrapper.cc delete mode 100644 src/target/datatype/registry.cc delete mode 100644 src/target/datatype/registry.h delete mode 100644 src/tirx/transform/lower_custom_datatypes.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index ad99c4c6acba..0eb8c9018495 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -87,7 +87,6 @@ tvm_option(USE_CCACHE "Use ccache if found when invoking compiler" AUTO) # 3rdparty libraries tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") # Contrib library options -tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF) tvm_option(USE_BLAS "The blas library to be linked" none) tvm_option(USE_AMX "Enable Intel AMX" OFF) tvm_option(USE_MKL "MKL root path when use MKL blas" OFF) @@ -356,10 +355,6 @@ tvm_file_glob(GLOB CODEGEN_SRCS list(APPEND COMPILER_SRCS ${CODEGEN_SRCS}) -tvm_file_glob(GLOB DATATYPE_SRCS src/target/datatype/*.cc) -list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) -list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc") - tvm_file_glob(GLOB RUNTIME_SRCS src/runtime/*.cc src/runtime/vm/*.cc @@ -464,7 +459,6 @@ include(cmake/modules/contrib/DNNL.cmake) include(cmake/modules/contrib/AMX.cmake) include(cmake/modules/contrib/CUTLASS.cmake) include(cmake/modules/contrib/Random.cmake) -include(cmake/modules/contrib/Posit.cmake) include(cmake/modules/contrib/Sort.cmake) include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/TensorRT.cmake) diff --git a/cmake/modules/contrib/Posit.cmake b/cmake/modules/contrib/Posit.cmake deleted file mode 100644 index b8d180ee4480..000000000000 --- a/cmake/modules/contrib/Posit.cmake +++ /dev/null @@ -1,26 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -if(USE_BYODT_POSIT) - message(STATUS "Build with contrib.posit") - if (NOT UNIVERSAL_PATH) - message(FATAL_ERROR "Fail to get Universal path") - endif(NOT UNIVERSAL_PATH) - - include_directories(${UNIVERSAL_PATH}/include) - list(APPEND COMPILER_SRCS "src/target/datatype/posit/posit-wrapper.cc") -endif(USE_BYODT_POSIT) diff --git a/docker/Dockerfile.ci_cpu b/docker/Dockerfile.ci_cpu index 8e31b310fe54..e823db54b263 100644 --- a/docker/Dockerfile.ci_cpu +++ b/docker/Dockerfile.ci_cpu @@ -63,10 +63,6 @@ RUN bash /install/ubuntu_install_dnnl.sh COPY install/ubuntu_install_xgboost.sh /install/ubuntu_install_xgboost.sh RUN bash /install/ubuntu_install_xgboost.sh -# BYODT deps -COPY install/ubuntu_install_universal.sh /install/ubuntu_install_universal.sh -RUN bash /install/ubuntu_install_universal.sh - # TensorFlow deps COPY install/ubuntu_install_tensorflow.sh /install/ubuntu_install_tensorflow.sh RUN bash /install/ubuntu_install_tensorflow.sh diff --git a/docker/Dockerfile.ci_gpu b/docker/Dockerfile.ci_gpu index df15215b94f1..2f9139f842cd 100644 --- a/docker/Dockerfile.ci_gpu +++ b/docker/Dockerfile.ci_gpu @@ -115,10 +115,6 @@ RUN bash /install/ubuntu_install_vulkan.sh COPY install/ubuntu_install_xgboost.sh /install/ubuntu_install_xgboost.sh RUN bash /install/ubuntu_install_xgboost.sh -# BYODT deps -COPY install/ubuntu_install_universal.sh /install/ubuntu_install_universal.sh -RUN bash /install/ubuntu_install_universal.sh - # sccache COPY install/ubuntu_install_sccache.sh /install/ubuntu_install_sccache.sh RUN bash /install/ubuntu_install_sccache.sh diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index 202766571278..7a7584aff2da 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -998,13 +998,6 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = Span()) } if (t.is_float() || t.is_bfloat16() || t.is_float8() || t.is_float6() || t.is_float4()) return FloatImm(t, static_cast(value), span); - // For now, we store const scalar values of custom datatypes within doubles; later, during the - // datatypes lowering pass, we will lower the value to its true representation in the format - // specified by the datatype. - // TODO(gus) when do we need to start worrying about doubles not being precise enough? - if (static_cast(t.code()) >= static_cast(DataType::kCustomBegin)) { - return FloatImm(t, static_cast(value), span); - } TVM_FFI_THROW(InternalError) << "cannot make const for type " << t; throw; } diff --git a/include/tvm/tirx/transform.h b/include/tvm/tirx/transform.h index 32a3ea8b2984..e5a754f6c54f 100644 --- a/include/tvm/tirx/transform.h +++ b/include/tvm/tirx/transform.h @@ -153,15 +153,6 @@ TVM_DLL Pass MakePackedAPI(); */ TVM_DLL Pass RemapThreadAxis(ffi::Map axis_map); -/*! - * \brief Lower custom datatypes. - * - * See tvm::datatypes::Registry for more information on adding custom datatypes. - * - * \return The pass. - */ -TVM_DLL Pass LowerCustomDatatypes(); - /*! * \brief Annotate, split, and lower host/device functions. * diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py index fb8310dc2604..df1cd74a216b 100644 --- a/python/tvm/s_tir/pipeline.py +++ b/python/tvm/s_tir/pipeline.py @@ -125,7 +125,6 @@ def finalize_host_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" host_pass_list = [ tirx.transform.LowerTVMBuiltin(), - tirx.transform.LowerCustomDatatypes(), tirx.transform.LowerIntrin(), ] return tvm.ir.transform.Sequential(host_pass_list) @@ -136,7 +135,6 @@ def finalize_device_passes(): # pylint: disable=unused-argument device_pass_list = [ tirx.transform.LowerWarpMemory(), tirx.transform.StmtSimplify(), - tirx.transform.LowerCustomDatatypes(), tirx.transform.LowerIntrin(), ] return tvm.ir.transform.Sequential(device_pass_list) diff --git a/python/tvm/target/__init__.py b/python/tvm/target/__init__.py index 5c6733cf8cbd..7303a0d09745 100644 --- a/python/tvm/target/__init__.py +++ b/python/tvm/target/__init__.py @@ -34,6 +34,5 @@ from .target import Target, TargetKind from .virtual_device import VirtualDevice from .tag import list_tags, register_tag -from . import datatype from . import codegen from . import tag_registry # registers tags on import diff --git a/python/tvm/target/datatype.py b/python/tvm/target/datatype.py deleted file mode 100644 index d7a47836b2f5..000000000000 --- a/python/tvm/target/datatype.py +++ /dev/null @@ -1,379 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ruff: noqa: F821 -"""Bring Your Own Datatypes custom datatype framework - -TODO(@gussmith23 @hypercubestart) link to BYODT docs when they exist""" - -from tvm_ffi import get_global_func -from tvm_ffi import register_global_func as _register_global_func - -import tvm -from tvm.runtime import DataType, convert -from tvm.tirx import call_intrin -from tvm.tirx.expr import ( - BinaryOpExpr as _BinaryOpExpr, -) -from tvm.tirx.expr import ( - Call as _Call, -) -from tvm.tirx.expr import ( - Cast as _Cast, -) -from tvm.tirx.expr import ( - FloatImm as _FloatImm, -) -from tvm.tirx.op import call_pure_extern - - -def register(type_name, type_code): - """Register a custom datatype with the given type name and type code - - Currently, the type code is manually allocated by the user, and the user - must ensure that no two custom types share the same code. Generally, this - should be straightforward, as the user will be manually registering all of - their custom types. - - Example: - - .. code-block:: python - - # Register a dtype named 'posites2' under type code 130. - tvm.target.datatype.register('posites2', 130) - - - Parameters - ---------- - type_name : str - The name of the custom datatype. - - type_code : int - The type's code, which should be >= kCustomBegin. See - include/tvm/runtime/data_type.h. - """ - get_global_func("dtype.register_custom_type")(type_name, type_code) - - -def get_type_name(type_code): - """Get the type name of a custom datatype from the type code. - - Note that this only works for custom datatypes registered with - tvm.target.datatype.register(). It does not work for TVM-native types. - - Example: - - .. code-block:: python - - tvm.target.datatype.register('posites2', 130) - assert tvm.target.datatype.get_type_name(130) == 'posites2' - - Parameters - ---------- - type_code : int - The type code of the custom datatype. - - Returns - ------- - type_name : String - The name of the custom datatype. - - """ - return get_global_func("dtype.get_custom_type_name")(type_code) - - -def get_type_code(type_name): - """Get the type code of a custom datatype from its type name - - Note that this only works for custom datatypes registered with - tvm.target.datatype.register(). It does not work for TVM-native types. - - Example: - - .. code-block:: python - - tvm.target.datatype.register('posites2', 130) - assert tvm.target.datatype.get_type_code('posites2') == 130 - - Parameters - ---------- - type_name : str - The type name - - Returns - ------- - type_code : int - The type code of the custom datatype. - """ - return get_global_func("dtype.get_custom_type_code")(type_name) - - -def get_type_registered(type_code): - """Returns true if a custom datatype is registered under the given type code - - Example: - - .. code-block:: python - - tvm.target.datatype.register('posites2', 130) - assert tvm.target.datatype.get_type_registered(130) - - Parameters - ---------- - type_code: int - The type code - - Returns - ------- - type_registered : bool - True if a custom datatype is registered under this type code, and false - otherwise. - """ - return tvm.runtime._ffi_api._datatype_get_type_registered(type_code) - - -def register_op( - lower_func, op_name, target, src_type_name, dest_type_name=None, intrinsic_name=None -): - """Register a lowering function for a specific operator of a custom datatype - - At build time, Relay must lower operators over custom datatypes into - operators it understands how to compile. For each custom datatype operator - which Relay finds while lowering custom datatypes, Relay expects to find a - user-defined lowering function. Users register their user-defined lowering - functions using this function. - - Users should use create_lower_func to create their lowering function. It - should serve most use-cases. - - Currently, this will work with Casts, intrinsics (e.g. sqrt, sigmoid), and - binary expressions (e.g. Add, Sub, Mul, Div). - - See the LowerCustomDatatypes pass to see how registered functions are used. - - Lowering Functions - ------------------ - TODO(@gussmith23) Get the terminology right here. - Lowering functions take in a Relay node, and should return a semantically - equivalent Relay node which Relay can build. This means that the returned - node should not contain any custom datatypes. Users should likely not need - to define lowering functions by hand -- see the helper function - create_lower_func. - - Parameters - ---------- - lower_func : function - The lowering function to call. See create_lower_func. - - op_name : str - The name of the operation which the function computes, given by its - class name (e.g. Add, LE, Cast, Call). - - target : str - The name of codegen target. - - src_type_name : str - The name of the custom datatype, e.g. posites2 (but not custom[posites2]32). - If op_name is not "Cast", then target type is guaranteed to be the same as src_type_name. - - dest_type_name : str - If op_name is "Cast", then this is required and should be set to the dest datatype of - the argument to the Cast. If op_name is not "Cast", this is unused. - - intrinsic_name : str - If op_name is "Call" and intrinsic_name is not None, then we assume the - op is a Call to an Intrinsic, and intrinsic_name is the intrinsic's - name. - """ - - if op_name == "Cast": - assert dest_type_name is not None - lower_func_name = ( - "tvm.datatype.lower." - + target - + "." - + op_name - + "." - + dest_type_name - + "." - + src_type_name - ) - elif op_name == "Call" and intrinsic_name is not None: - lower_func_name = ( - "tvm.datatype.lower." - + target - + "." - + op_name - + ".intrin." - + intrinsic_name - + "." - + src_type_name - ) - else: - lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." + src_type_name - tvm_ffi.register_global_func(lower_func_name, lower_func) - - -def register_min_func(func, type_name): - """Register the function that returns the minimum representable value of type_name. - - Operators such as max pooling and argmax require the minimum - finite value representable by the datatype the op operating on. - Users can use this function to register a function that returns a TIR expression node - outputting the minimum representable value of their custom data type. - - Users should use create_min_lower_func to create their lowering function. It - should serve most use-cases. - - Note: for special cases when it is known that the custom datatype is representable - by a float, the user can create their own lowering func that returns a FloatImm. - The benefits are allowing optimizations such as rewrites to work as expected on custom - datatypes. - - Parameters - ---------- - func : function - Input is an integer num_bits, should return a TIR expression node that - represents a scalar tensor of type custom[type_name]num_bits with the minimum - representable value. - - type_name : str - The name of the custom datatype, e.g. posites2 (but not custom[posites2]32). - """ - _register_global_func("tvm.datatype.min." + type_name, func) - - -def create_min_lower_func(extern_func_map, type_name): - """Returns a lowering function for getting the minimum value of a custom datatype. - - Parameters - ---------- - extern_func_map : map - A map from bit lengths to the name of the extern "C" function to lower to. - - type_name : string - The name of the custom datatype, e.g. posites2 (but not custom[posites2]32). - """ - - def lower(num_bits): - dtype = f"custom[{type_name}]{num_bits}" - - if num_bits not in extern_func_map: - raise RuntimeError("missing minimum function for {dtype}") - - return call_pure_extern(dtype, extern_func_map[num_bits]) - - return lower - - -def create_lower_func(extern_func_map): - """Returns a function which lowers an operation to a function call. - - Parameters - ---------- - extern_func_map : map - If lowering a Cast, extern_func_map should be a map from tuples of - (src_bit_length, dest_bit_length) to the name of the extern "C" function to lower to. - - Otherwise, for unary and binary ops, it should simply be a map - from bit_length to the name of the extern "C" function to lower to. - """ - - def lower(op): - """ - Takes an op---either a Cast, Call, or a binary op (e.g. an Add) and returns a - call to the specified external function, passing the op's argument - or arguments. The return type of the call depends - on the type of the op: if it is a custom type, then a uint of the same - width as the custom type is returned. Otherwise, the type is - unchanged.""" - dtype = op.dtype - t = DataType(dtype) - if get_type_registered(t.type_code): - dtype = "uint" + str(t.bits) - if t.lanes > 1: - dtype += "x" + str(t.lanes) - - key = t.bits - if isinstance(op, _Cast): - src_bits = DataType(op.value.dtype).bits - key = (src_bits, t.bits) - - if key not in extern_func_map: - raise RuntimeError(f"missing key {key} in extern_func_map for {op}") - - if isinstance(op, _Cast): - return call_pure_extern(dtype, extern_func_map[key], op.value) - if isinstance(op, _FloatImm): - return call_pure_extern(dtype, extern_func_map[key], op.value) - if isinstance(op, _Call): - return call_pure_extern(dtype, extern_func_map[key], *op.args) - if isinstance(op, _BinaryOpExpr): - return call_pure_extern(dtype, extern_func_map[key], op.a, op.b) - - raise RuntimeError(f"lowering unsupported op: {op}") - - return lower - - -def lower_ite(ite_op): - """Lowered if then else function that calls intrinsic if_then_else. - Unlike a function lowered by create_lower_func, this function - calls the tvm intrinsic if_then_else. - - Parameters - ---------- - ite_op : Op - Takes an if then else op and returns a - call to tirx.if_then_else function, passing the op's - arguments. The return type of the call if a uint of the same - width as the custom type is returned. - """ - dtype = ite_op.dtype - t = tvm.DataType(dtype) - assert get_type_registered(t.type_code) - dtype = "uint" + str(t.bits) - if t.lanes > 1: - dtype += "x" + str(t.lanes) - return call_intrin( - dtype, - "tirx.if_then_else", - convert(ite_op.args[0]), - convert(ite_op.args[1]), - convert(ite_op.args[2]), - ) - - -def lower_call_pure_extern(op): - """Lowered call pure extern function that calls intrinsic call_pure_extern. - Unlike a function lowered by create_lower_func, this function - calls the tvm intrinsic call_pure_extern. - - Parameters - ---------- - ite_op : Op - Takes a call_pure_extern op and returns a - call to tirx.call_pure_extern function, passing the op's - arguments. The return type of the call if a uint of the same - width as the custom type is returned. - """ - dtype = op.dtype - t = tvm.DataType(dtype) - assert get_type_registered(t.type_code) - dtype = "uint" + str(t.bits) - if t.lanes > 1: - dtype += "x" + str(t.lanes) - return call_intrin(dtype, "tirx.call_pure_extern", *op.args) diff --git a/python/tvm/tirx/compilation_pipeline.py b/python/tvm/tirx/compilation_pipeline.py index d2847332b4a7..23dee416bbe2 100644 --- a/python/tvm/tirx/compilation_pipeline.py +++ b/python/tvm/tirx/compilation_pipeline.py @@ -103,7 +103,6 @@ def finalize_host_passes(): # pylint: disable=unused-argument """The default finalization passes for TIR backend.""" host_pass_list = [ tirx.transform.LowerTVMBuiltin(), - tirx.transform.LowerCustomDatatypes(), tirx.transform.LowerIntrin(), ] return tvm.ir.transform.Sequential(host_pass_list) @@ -114,7 +113,6 @@ def finalize_device_passes(): # pylint: disable=unused-argument device_pass_list = [ tirx.transform.LowerWarpMemory(), tirx.transform.StmtSimplify(), - tirx.transform.LowerCustomDatatypes(), tirx.transform.LowerIntrin(), ] return tvm.ir.transform.Sequential(device_pass_list) diff --git a/python/tvm/tirx/transform/transform.py b/python/tvm/tirx/transform/transform.py index 72a5b96202d2..ae6b942b66e2 100644 --- a/python/tvm/tirx/transform/transform.py +++ b/python/tvm/tirx/transform/transform.py @@ -245,19 +245,6 @@ def ConvertSSA(): return _ffi_api.ConvertSSA() # type: ignore -def LowerCustomDatatypes(): - """Lower custom datatypes. - - See tvm::datatypes::Registry for more information on adding custom datatypes. - - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.LowerCustomDatatypes() # type: ignore - - def MakePackedAPI(): """Transform the PrimFuncs in the module to a packed func API. diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index bec509188311..5a86cdd15abb 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -34,7 +34,6 @@ #include #include -#include "../target/datatype/registry.h" #include "../tirx/analysis/check_contains.h" #include "conjunctive_normal_form.h" #include "const_fold.h" diff --git a/src/target/datatype/myfloat/myfloat.cc b/src/target/datatype/myfloat/myfloat.cc deleted file mode 100644 index afee8a7c4bf0..000000000000 --- a/src/target/datatype/myfloat/myfloat.cc +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file 3rdparty/byodt/my-custom-datatype.cc - * \brief Example Custom Datatype with the Bring Your Own Datatypes (BYODT) framework. - * This is a toy example that under the hood simulates floats. - * - * Users interested in using the BYODT framework can use this file as a template. - * - * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist? - */ -#include - -#include -#include -#include - -// Custom datatypes are stored as bits in a uint of the appropriate bit length. -// Thus, when TVM calls these C functions, -// the arguments of are uints that need to reinterpreted as your custom datatype. -// -// When returning, your custom datatype needs to be re-wrapped into a uint, -// which can be thought of as just a wrapper for the raw bits that represent your custom datatype. -template -TVM_DLL T Uint32ToCustom32(uint32_t in) { - // This is a helper function to interpret the uint as your custom dataype. - // The following line should be replaced with the appropriate function - // that interprets the bits in `in` and returns your custom datatype - T* custom = reinterpret_cast(&in); - return *custom; -} - -template -TVM_DLL uint32_t Custom32ToUint32(T in) { - // This is a helper function to wrap your custom datatype in a uint. - // the following line should be replaced with the appropriate function - // that converts your custom datatype into a uint - uint32_t* bits = reinterpret_cast(&in); - return *bits; -} - -extern "C" { -TVM_DLL uint32_t MinCustom32() { - // return minimum representable value - float min = std::numeric_limits::lowest(); - return Custom32ToUint32(min); -} - -TVM_DLL float Custom32ToFloat(uint32_t in) { - // cast from custom datatype to float - float custom_datatype = Uint32ToCustom32(in); - // our custom datatype is float, so the following redundant cast to float - // is to remind users to cast their own custom datatype to float - return static_cast(custom_datatype); -} - -TVM_DLL uint32_t FloatToCustom32(float in) { - // cast from float to custom datatype - return Custom32ToUint32(in); -} - -TVM_DLL uint32_t Custom32Add(uint32_t a, uint32_t b) { - // add operation - float acustom = Uint32ToCustom32(a); - float bcustom = Uint32ToCustom32(b); - return Custom32ToUint32(acustom + bcustom); -} - -TVM_DLL uint32_t Custom32Sub(uint32_t a, uint32_t b) { - // subtract - float acustom = Uint32ToCustom32(a); - float bcustom = Uint32ToCustom32(b); - return Custom32ToUint32(acustom - bcustom); -} - -TVM_DLL uint32_t Custom32Mul(uint32_t a, uint32_t b) { - // multiply - float acustom = Uint32ToCustom32(a); - float bcustom = Uint32ToCustom32(b); - return Custom32ToUint32(acustom * bcustom); -} - -TVM_DLL uint32_t Custom32Div(uint32_t a, uint32_t b) { - // divide - float acustom = Uint32ToCustom32(a); - float bcustom = Uint32ToCustom32(b); - return Custom32ToUint32(acustom / bcustom); -} - -TVM_DLL uint32_t Custom32Max(uint32_t a, uint32_t b) { - // max - float acustom = Uint32ToCustom32(a); - float bcustom = Uint32ToCustom32(b); - return Custom32ToUint32(acustom > bcustom ? acustom : bcustom); -} - -TVM_DLL uint32_t Custom32Sqrt(uint32_t a) { - // sqrt - float acustom = Uint32ToCustom32(a); - return Custom32ToUint32(sqrt(acustom)); -} - -TVM_DLL uint32_t Custom32Exp(uint32_t a) { - // exponential - float acustom = Uint32ToCustom32(a); - return Custom32ToUint32(exp(acustom)); -} - -TVM_DLL uint32_t Custom32Log(uint32_t a) { - // log - float acustom = Uint32ToCustom32(a); - return Custom32ToUint32(log(acustom)); -} - -TVM_DLL uint32_t Custom32Sigmoid(uint32_t a) { - // sigmoid - float acustom = Uint32ToCustom32(a); - float one = 1.0f; - return Custom32ToUint32(one / (one + exp(-acustom))); -} - -TVM_DLL uint32_t Custom32Tanh(uint32_t a) { - // tanh - float acustom = Uint32ToCustom32(a); - return Custom32ToUint32(tanh(acustom)); -} -} diff --git a/src/target/datatype/posit/posit-wrapper.cc b/src/target/datatype/posit/posit-wrapper.cc deleted file mode 100644 index e05695e60366..000000000000 --- a/src/target/datatype/posit/posit-wrapper.cc +++ /dev/null @@ -1,242 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file 3rdparty/posit/posit-wrapper.cc - * \brief Wrapper over the Stillwater Universal library for Bring Your Own Datatypes tests - * - * To compile TVM with this file, - * 1. clone the Stillwater Universal repo from here `https://github.com/stillwater-sc/universal`. - * 2. set `SET_BYODT_POSIT` ON and `UNIVERSAL_PATH` as the path to the folder containing Stillwater - * Universal in your CMake file - * - * TODO(@gussmith23 @hypercubestart) Link to BYODT docs when they exist? - */ -#include - -#include - -#include "universal/posit/posit.hpp" -// must go after posit.hpp -#include "universal/posit/math/exponent.hpp" -#include "universal/posit/math/hyperbolic.hpp" -#include "universal/posit/math/logarithm.hpp" -#include "universal/posit/math/sqrt.hpp" -#include "universal/posit/numeric_limits.hpp" - -TVM_DLL sw::unum::posit<8, 2> Uint8ToPosit8es2(uint8_t in) { - sw::unum::bitblock<8> bb; - bb = static_cast(in); - return sw::unum::posit<8, 2>().set(bb); -} - -extern "C" { -TVM_DLL uint8_t Posit8es2toUint8(sw::unum::posit<8, 2> in) { - return static_cast(in.get().to_ullong()); -} - -TVM_DLL uint8_t MinPosit8es2() { - auto min = std::numeric_limits>::lowest(); - return Posit8es2toUint8(min); -} - -TVM_DLL float Posit8es2ToFloat(uint8_t in) { return Uint8ToPosit8es2(in).operator float(); } - -TVM_DLL uint8_t FloatToPosit8es2(float in) { - auto posit = sw::unum::posit<8, 2>(in); - return Posit8es2toUint8(posit); -} - -TVM_DLL uint8_t Posit8es2Add(uint8_t a, uint8_t b) { - return Posit8es2toUint8(Uint8ToPosit8es2(a) + Uint8ToPosit8es2(b)); -} - -TVM_DLL uint8_t Posit8es2Sub(uint8_t a, uint8_t b) { - return Posit8es2toUint8(Uint8ToPosit8es2(a) - Uint8ToPosit8es2(b)); -} - -TVM_DLL uint8_t Posit8es2Mul(uint8_t a, uint8_t b) { - return Posit8es2toUint8(Uint8ToPosit8es2(a) * Uint8ToPosit8es2(b)); -} - -TVM_DLL uint8_t Posit8es2Div(uint8_t a, uint8_t b) { - return Posit8es2toUint8(Uint8ToPosit8es2(a) / Uint8ToPosit8es2(b)); -} - -TVM_DLL uint8_t Posit8es2Max(uint8_t a, uint8_t b) { - auto a_p = Uint8ToPosit8es2(a); - auto b_p = Uint8ToPosit8es2(b); - return Posit8es2toUint8(a_p > b_p ? a_p : b_p); -} - -TVM_DLL uint8_t Posit8es2Sqrt(uint8_t a) { - return Posit8es2toUint8(sw::unum::sqrt(Uint8ToPosit8es2(a))); -} - -TVM_DLL uint8_t Posit8es2Exp(uint8_t a) { - return Posit8es2toUint8(sw::unum::exp(Uint8ToPosit8es2(a))); -} - -TVM_DLL uint8_t Posit8es2Log(uint8_t a) { - return Posit8es2toUint8(sw::unum::log(Uint8ToPosit8es2(a))); -} - -TVM_DLL uint8_t Posit8es2Sigmoid(uint8_t a) { - auto posit_one = sw::unum::posit<8, 2>(1); - return Posit8es2toUint8(posit_one / (sw::unum::exp(-Uint8ToPosit8es2(a)) + posit_one)); -} - -TVM_DLL uint8_t Posit8es2Tanh(uint8_t a) { - return Posit8es2toUint8(sw::unum::tanh(Uint8ToPosit8es2(a))); -} -} - -TVM_DLL sw::unum::posit<16, 2> Uint16ToPosit16es2(uint16_t in) { - sw::unum::bitblock<16> bb; - bb = static_cast(in); - return sw::unum::posit<16, 2>().set(bb); -} - -extern "C" { -TVM_DLL uint16_t Posit16es2toUint16(sw::unum::posit<16, 2> in) { - return static_cast(in.get().to_ullong()); -} - -TVM_DLL uint8_t MinPosit16es2() { - auto min = std::numeric_limits>::lowest(); - return Posit16es2toUint16(min); -} - -TVM_DLL float Posit16es2ToFloat(uint16_t in) { return Uint16ToPosit16es2(in).operator float(); } - -TVM_DLL uint16_t FloatToPosit16es2(float in) { - auto posit = sw::unum::posit<16, 2>(in); - return Posit16es2toUint16(posit); -} - -TVM_DLL uint16_t Posit16es2Add(uint16_t a, uint16_t b) { - return Posit16es2toUint16(Uint16ToPosit16es2(a) + Uint16ToPosit16es2(b)); -} - -TVM_DLL uint16_t Posit16es2Sub(uint16_t a, uint16_t b) { - return Posit16es2toUint16(Uint16ToPosit16es2(a) - Uint16ToPosit16es2(b)); -} - -TVM_DLL uint16_t Posit16es2Mul(uint16_t a, uint16_t b) { - return Posit16es2toUint16(Uint16ToPosit16es2(a) * Uint16ToPosit16es2(b)); -} - -TVM_DLL uint16_t Posit16es2Div(uint16_t a, uint16_t b) { - return Posit16es2toUint16(Uint16ToPosit16es2(a) / Uint16ToPosit16es2(b)); -} - -TVM_DLL uint16_t Posit16es2Max(uint16_t a, uint16_t b) { - auto a_p = Uint16ToPosit16es2(a); - auto b_p = Uint16ToPosit16es2(b); - return Posit16es2toUint16(a_p > b_p ? a_p : b_p); -} - -TVM_DLL uint16_t Posit16es2Sqrt(uint16_t a) { - return Posit16es2toUint16(sw::unum::sqrt(Uint16ToPosit16es2(a))); -} - -TVM_DLL uint16_t Posit16es2Exp(uint16_t a) { - return Posit16es2toUint16(sw::unum::exp(Uint16ToPosit16es2(a))); -} - -TVM_DLL uint16_t Posit16es2Log(uint16_t a) { - return Posit16es2toUint16(sw::unum::log(Uint16ToPosit16es2(a))); -} - -TVM_DLL uint16_t Posit16es2Sigmoid(uint16_t a) { - auto posit_one = sw::unum::posit<16, 2>(1); - return Posit16es2toUint16(posit_one / (sw::unum::exp(-Uint16ToPosit16es2(a)) + posit_one)); -} - -TVM_DLL uint16_t Posit16es2Tanh(uint16_t a) { - return Posit16es2toUint16(sw::unum::tanh(Uint16ToPosit16es2(a))); -} -} - -TVM_DLL sw::unum::posit<32, 2> Uint32ToPosit32es2(uint32_t in) { - sw::unum::bitblock<32> bb; - bb = static_cast(in); - return sw::unum::posit<32, 2>().set(bb); -} - -extern "C" { -TVM_DLL uint32_t Posit32es2ToUint32(sw::unum::posit<32, 2> in) { - return static_cast(in.get().to_ullong()); -} - -TVM_DLL uint8_t MinPosit32es2() { - auto min = std::numeric_limits>::lowest(); - return Posit32es2ToUint32(min); -} - -TVM_DLL float Posit32es2ToFloat(uint32_t in) { return Uint32ToPosit32es2(in).operator float(); } - -TVM_DLL uint32_t FloatToPosit32es2(float in) { - auto posit = sw::unum::posit<32, 2>(in); - return Posit32es2ToUint32(posit); -} - -TVM_DLL uint32_t Posit32es2Add(uint32_t a, uint32_t b) { - return Posit32es2ToUint32(Uint32ToPosit32es2(a) + Uint32ToPosit32es2(b)); -} - -TVM_DLL uint32_t Posit32es2Sub(uint32_t a, uint32_t b) { - return Posit32es2ToUint32(Uint32ToPosit32es2(a) - Uint32ToPosit32es2(b)); -} - -TVM_DLL uint32_t Posit32es2Mul(uint32_t a, uint32_t b) { - return Posit32es2ToUint32(Uint32ToPosit32es2(a) * Uint32ToPosit32es2(b)); -} - -TVM_DLL uint32_t Posit32es2Div(uint32_t a, uint32_t b) { - return Posit32es2ToUint32(Uint32ToPosit32es2(a) / Uint32ToPosit32es2(b)); -} - -TVM_DLL uint32_t Posit32es2Max(uint32_t a, uint32_t b) { - auto a_p = Uint32ToPosit32es2(a); - auto b_p = Uint32ToPosit32es2(b); - return Posit32es2ToUint32(a_p > b_p ? a_p : b_p); -} - -TVM_DLL uint32_t Posit32es2Sqrt(uint32_t a) { - return Posit32es2ToUint32(sw::unum::sqrt(Uint32ToPosit32es2(a))); -} - -TVM_DLL uint32_t Posit32es2Exp(uint32_t a) { - return Posit32es2ToUint32(sw::unum::exp(Uint32ToPosit32es2(a))); -} - -TVM_DLL uint32_t Posit32es2Log(uint32_t a) { - return Posit32es2ToUint32(sw::unum::log(Uint32ToPosit32es2(a))); -} - -TVM_DLL uint32_t Posit32es2Sigmoid(uint32_t a) { - auto posit_one = sw::unum::posit<32, 2>(1); - return Posit32es2ToUint32(posit_one / (posit_one + sw::unum::exp(-Uint32ToPosit32es2(a)))); -} - -TVM_DLL uint32_t Posit32es2Tanh(uint32_t a) { - return Posit32es2ToUint32(sw::unum::tanh(Uint32ToPosit32es2(a))); -} -} diff --git a/src/target/datatype/registry.cc b/src/target/datatype/registry.cc deleted file mode 100644 index 9d6459df6cce..000000000000 --- a/src/target/datatype/registry.cc +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include "registry.h" - -#include -#include -#include - -namespace tvm { -namespace datatype { - -using ffi::Any; -using ffi::PackedArgs; - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def_packed("dtype.register_custom_type", - [](ffi::PackedArgs args, ffi::Any* ret) { - datatype::Registry::Global()->Register( - args[0].cast(), static_cast(args[1].cast())); - }) - .def_packed("dtype.get_custom_type_code", - [](ffi::PackedArgs args, ffi::Any* ret) { - *ret = datatype::Registry::Global()->GetTypeCode(args[0].cast()); - }) - .def_packed("dtype.get_custom_type_name", - [](ffi::PackedArgs args, ffi::Any* ret) { - *ret = Registry::Global()->GetTypeName(args[0].cast()); - }) - .def_packed("runtime._datatype_get_type_registered", [](ffi::PackedArgs args, ffi::Any* ret) { - *ret = Registry::Global()->GetTypeRegistered(args[0].cast()); - }); -} - -Registry* Registry::Global() { - static Registry inst; - return &inst; -} - -void Registry::Register(const std::string& type_name, uint8_t type_code) { - TVM_FFI_ICHECK(type_code >= DataType::kCustomBegin) - << "Please choose a type code >= DataType::kCustomBegin for custom types"; - code_to_name_[type_code] = type_name; - name_to_code_[type_name] = type_code; -} - -uint8_t Registry::GetTypeCode(const std::string& type_name) { - TVM_FFI_ICHECK(name_to_code_.find(type_name) != name_to_code_.end()) - << "Type name " << type_name << " not registered"; - return name_to_code_[type_name]; -} - -std::string Registry::GetTypeName(uint8_t type_code) { - TVM_FFI_ICHECK(code_to_name_.find(type_code) != code_to_name_.end()) - << "Type code " << static_cast(type_code) << " not registered"; - return code_to_name_[type_code]; -} - -std::optional GetCastLowerFunc(const std::string& target, uint8_t type_code, - uint8_t src_type_code) { - std::ostringstream ss; - ss << "tvm.datatype.lower."; - ss << target << "."; - ss << "Cast" - << "."; - - if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { - ss << datatype::Registry::Global()->GetTypeName(type_code); - } else { - ss << ffi::details::DLDataTypeCodeAsCStr(static_cast(type_code)); - } - - ss << "."; - - if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) { - ss << datatype::Registry::Global()->GetTypeName(src_type_code); - } else { - ss << ffi::details::DLDataTypeCodeAsCStr(static_cast(src_type_code)); - } - return tvm::ffi::Function::GetGlobal(ss.str()); -} - -std::optional GetMinFunc(uint8_t type_code) { - std::ostringstream ss; - ss << "tvm.datatype.min."; - ss << datatype::Registry::Global()->GetTypeName(type_code); - return tvm::ffi::Function::GetGlobal(ss.str()); -} - -std::optional GetFloatImmLowerFunc(const std::string& target, - uint8_t type_code) { - std::ostringstream ss; - ss << "tvm.datatype.lower."; - ss << target; - ss << ".FloatImm."; - ss << datatype::Registry::Global()->GetTypeName(type_code); - return tvm::ffi::Function::GetGlobal(ss.str()); -} - -std::optional GetIntrinLowerFunc(const std::string& target, - const std::string& name, uint8_t type_code) { - std::ostringstream ss; - ss << "tvm.datatype.lower."; - ss << target; - ss << ".Call.intrin."; - ss << name; - ss << "."; - ss << datatype::Registry::Global()->GetTypeName(type_code); - return tvm::ffi::Function::GetGlobal(ss.str()); -} - -uint64_t ConvertConstScalar(uint8_t type_code, double value) { - std::ostringstream ss; - ss << "tvm.datatype.convertconstscalar.float."; - ss << datatype::Registry::Global()->GetTypeName(type_code); - auto make_const_scalar_func = tvm::ffi::Function::GetGlobal(ss.str()); - return (*make_const_scalar_func)(value).cast(); -} - -} // namespace datatype -} // namespace tvm diff --git a/src/target/datatype/registry.h b/src/target/datatype/registry.h deleted file mode 100644 index 363494e0fda4..000000000000 --- a/src/target/datatype/registry.h +++ /dev/null @@ -1,182 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#ifndef TVM_TARGET_DATATYPE_REGISTRY_H_ -#define TVM_TARGET_DATATYPE_REGISTRY_H_ - -#include - -#include -#include - -namespace tvm { -namespace datatype { - -/*! - * \brief Registry for custom datatypes. - * - * Adding custom datatypes currently requires two steps: - * 1. Register the datatype with the registry via a call to - * datatype::Registry::Register. This can also be done in Python - * directly---see the TVM globals registered in the corresponding .cc file. - * Currently, user should manually choose a type name and a type code, - * ensuring that neither conflict with existing types. - * 2. Register the lowering functions needed to - * lower the custom datatype. In general, these will look like: - * For Casts: tvm.datatype.lower..Cast.. - * Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from - * float to myfloat. - * For intrinsic Calls: tvm.datatype.lower..Call.intrin.. - * Example: tvm.datatype.lower.llvm.Call.intrin.sqrt.myfloat - * For other ops: tvm.datatype.lower... - * Examples: tvm.datatype.lower.llvm.Add.myfloat - * tvm.datatype.lower.llvm.FloatImm.posit - */ -class Registry { - public: - /*! - * \brief Get the global custom datatype registry singleton - */ - static Registry* Global(); - - /*! - * \brief Register custom datatype - * Register a custom datatype with the given type name and type code. Currently, the type code is - * manually allocated by the user, and the user must ensure that no two custom types share the - * same code. Generally, this should be straightforward, as the user will be manually registering - * all of their custom types. - * \param type_name The name of the type, e.g. "posites2" - * \param type_code The type code, which should be greater than TVMArgTypeCode::kTVMExtEnd - */ - void Register(const std::string& type_name, uint8_t type_code); - - /*! - * \brief Get type code from type name - * \param type_name The type name - * \return The type code - */ - uint8_t GetTypeCode(const std::string& type_name); - - /*! - * \brief Get type name from type code - * \param type_code The type code - * \return The type name - */ - std::string GetTypeName(uint8_t type_code); - - /*! - * \brief Get bool representing whether type is registered, given the type code - * \param type_code The type code - * \return bool representing whether the type is registered - */ - inline bool GetTypeRegistered(uint8_t type_code) { - return code_to_name_.find(type_code) != code_to_name_.end(); - } - - /*! - * \brief Get bool representing whether type is registered, given the type name - * \param type_name The type name - * \return bool representing whether the type is registered - */ - inline bool GetTypeRegistered(std::string type_name) { - return name_to_code_.find(type_name) != name_to_code_.end(); - } - - private: - // TODO(gus) is there a typedef for the code? - std::unordered_map code_to_name_; - std::unordered_map name_to_code_; -}; - -/*! - * \brief Convert scalar value to a custom datatype format - * \param type_code The custom datatype to convert to, specified by type code - * \param value The floating point value to convert - * \return The value, encoded in the bits of a uint64_t - */ -uint64_t ConvertConstScalar(uint8_t type_code, double value); - -/*! - * \brief Get a function returning the minimum value for a datatype. - * \param type_code The datatype - * \return Function which takes the width of the datatype and returns the min value - */ -std::optional GetMinFunc(uint8_t type_code); - -/*! - * \brief Get lowering function for Cast ops - * \param target The target we are lowering to, e.g. "llvm" - * \param type_code The datatype being cast to - * \param src_type_code The datatype being cast from - * \return Lowering function for Cast ops for the provided target, type, and source type - */ -std::optional GetCastLowerFunc(const std::string& target, uint8_t type_code, - uint8_t src_type_code); - -/*! - * \brief Get lowering function for FloatImms - * \param target The target we are lowering to, e.g. "llvm" - * \param type_code The datatype of the FloatImm - * \return Lowering function for FloatImms for the provided target and type - */ -std::optional GetFloatImmLowerFunc(const std::string& target, - uint8_t type_code); - -/*! - * \brief Get lowering function for intrinsic Calls/pure intrinsic Calls - * \param target The target we are lowering to, e.g. "llvm" - * \param type_code The datatype of the Call - * \param name The intrinsic name - * \return Lowering function for intrinsic Calls for the provided target and type - */ -std::optional GetIntrinLowerFunc(const std::string& target, - const std::string& name, uint8_t type_code); - -/*! - * \brief Get lowering function for other ops - * \param target The target we are lowering to, e.g. "llvm" - * \param type_code The datatype of the op - * \return Lowering function for other ops for the provided target and type - */ -#define DEFINE_GET_LOWER_FUNC_(OP) \ - inline std::optional Get##OP##LowerFunc(const std::string& target, \ - uint8_t type_code) { \ - return tvm::ffi::Function::GetGlobal("tvm.datatype.lower." + target + "." #OP "." + \ - datatype::Registry::Global()->GetTypeName(type_code)); \ - } - -DEFINE_GET_LOWER_FUNC_(Add) -DEFINE_GET_LOWER_FUNC_(Sub) -DEFINE_GET_LOWER_FUNC_(Mul) -DEFINE_GET_LOWER_FUNC_(Div) -DEFINE_GET_LOWER_FUNC_(Mod) -DEFINE_GET_LOWER_FUNC_(Min) -DEFINE_GET_LOWER_FUNC_(Max) -DEFINE_GET_LOWER_FUNC_(EQ) -DEFINE_GET_LOWER_FUNC_(NE) -DEFINE_GET_LOWER_FUNC_(LT) -DEFINE_GET_LOWER_FUNC_(LE) -DEFINE_GET_LOWER_FUNC_(GT) -DEFINE_GET_LOWER_FUNC_(GE) -// Later changes may need to add more lowering functions as we support workloads with more ops. - -} // namespace datatype -} // namespace tvm - -#endif // TVM_TARGET_DATATYPE_REGISTRY_H_ diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 97422bf9edfa..88a28ebccb5f 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -597,11 +597,10 @@ llvm::Type* CodeGenLLVM::GetLLVMType(const Type& type) const { if (auto* ptr = type.as()) { return DTypeToLLVMType(ptr->dtype); } else if (auto* ptr = type.as()) { - // LLVM IR doesn't allow void*, nor do we require custom datatypes - // to have LLVM equivalents, so we need to recognize these - // patterns explicitly. + // LLVM IR doesn't allow void*, so pointer element types that do not + // have an LLVM scalar equivalent need explicit handling. if (auto* primtype = ptr->element_type.as()) { - if (primtype->dtype.is_void() || primtype->dtype.code() >= DataType::kCustomBegin) { + if (primtype->dtype.is_void()) { return t_void_p_; } } else if (ptr->element_type->IsInstance()) { diff --git a/src/tirx/op/op.cc b/src/tirx/op/op.cc index 64f7f575d6a1..5cf896e4fd60 100644 --- a/src/tirx/op/op.cc +++ b/src/tirx/op/op.cc @@ -35,7 +35,6 @@ #include // Centralized header for constant folders. #include "../../arith/const_fold.h" -#include "../../target/datatype/registry.h" #include "../analysis/check_contains.h" namespace tvm { @@ -211,22 +210,16 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { // NOLINT(*) } else { rhs = cast(ltype, rhs); } - } else if (!ltype.is_float() && - (rtype.is_float() || datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { + } else if (!ltype.is_float() && rtype.is_float()) { // Cast int->float when the other operand is a float lhs = cast(rtype, lhs); - } else if ((ltype.is_float() || datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && - !rtype.is_float()) { + } else if (ltype.is_float() && !rtype.is_float()) { // Cast int->float when the other operand is a float rhs = cast(ltype, rhs); - } else if (!ltype.is_bfloat16() && - (rtype.is_bfloat16() || - datatype::Registry::Global()->GetTypeRegistered(rtype.code()))) { + } else if (!ltype.is_bfloat16() && rtype.is_bfloat16()) { // Cast int->bfloat16 when the other operand is a bfloat16 lhs = cast(rtype, lhs); - } else if ((ltype.is_bfloat16() || - datatype::Registry::Global()->GetTypeRegistered(ltype.code())) && - !rtype.is_bfloat16()) { + } else if (ltype.is_bfloat16() && !rtype.is_bfloat16()) { // Cast int->bfloat16 when the other operand is a bfloat16 rhs = cast(ltype, rhs); } else if (!ltype.is_float8() && rtype.is_float8()) { @@ -369,15 +362,7 @@ PrimExpr max_value(const DataType& dtype, Span span) { PrimExpr min_value(const DataType& dtype, Span span) { using namespace tirx; TVM_FFI_ICHECK_EQ(dtype.lanes(), 1); - if (datatype::Registry::Global()->GetTypeRegistered(dtype.code())) { - // TODO(tkonolige): need to convert all registered min functions to use the span. - auto f = datatype::GetMinFunc(dtype.code()); - TVM_FFI_ICHECK(f) << "No minimum function registered for custom dtype " - << (unsigned int)dtype.code(); - // TODO(@hypercubestart) Document this change (and others associated with the overflowing - // floatimm min bug) - return (*f)(dtype.bits()).cast(); - } else if (dtype.is_int()) { + if (dtype.is_int()) { if (dtype.bits() == 64) { return IntImm(dtype, std::numeric_limits::lowest(), span); } else if (dtype.bits() < 64) { diff --git a/src/tirx/transform/lower_custom_datatypes.cc b/src/tirx/transform/lower_custom_datatypes.cc deleted file mode 100644 index d23cfef4fbf5..000000000000 --- a/src/tirx/transform/lower_custom_datatypes.cc +++ /dev/null @@ -1,266 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -/*! - * \file tvm/src/pass/lower_custom_datatypes.cc - * \brief Pass for lowering custom datatypes - */ - -#include -#include -#include -#include -#include -#include -#include - -#include "../../target/datatype/registry.h" - -namespace tvm { -namespace tirx { - -/*! - * \brief Helper mutator to implement lowering of custom datatypes. - * - * Lowering datatypes works as follows: for every expression containing a custom - * datatype, we search for a global (registered by the implementer of the custom - * datatype) for lowering this type of expression, and uses it to lower the - * expression. - */ -class CustomDatatypesLowerer : public StmtExprMutator { - public: - explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} - - PrimExpr VisitExpr_(const CastNode* op) final { - auto type_code = op->dtype.code(); - auto src_type_code = op->value.dtype().code(); - // If either datatype is a registered custom datatype, we must lower. - bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || - datatype::Registry::Global()->GetTypeRegistered(src_type_code); - PrimExpr expr = StmtExprMutator::VisitExpr_(op); - if (to_be_lowered) { - auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); - TVM_FFI_ICHECK(lower) << "Cast lowering function for target " << target_ - << " destination type " << static_cast(type_code) - << " source type " << static_cast(src_type_code) - << " not found"; - return (*lower)(expr).cast(); - } - return expr; - } - - PrimExpr VisitExpr_(const FloatImmNode* imm) final { - auto type_code = imm->dtype.code(); - auto e = ffi::GetRef(imm); - if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { - auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); - TVM_FFI_ICHECK(lower) << "FloatImm lowering function for target " << target_ << " type " - << static_cast(type_code) << " not found"; - return (*lower)(e).cast(); - } - return e; - } - - PrimExpr VisitExpr_(const VarNode* op) final { - Var var = ffi::GetRef(op); - - auto itr = var_remap_.find(var); - if (itr != var_remap_.end()) { - return itr->second; - } else { - return var; - } - } - - Stmt VisitStmt_(const AllocBufferNode* op) final { - bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(op->buffer->dtype.code()); - - if (to_be_lowered) { - auto new_allocate_type = DataType::UInt(op->buffer->dtype.bits(), op->buffer->dtype.lanes()); - auto new_buffer_var = - Var(op->buffer->data->name_hint, PointerType(PrimType(new_allocate_type))); - var_remap_[op->buffer->data] = new_buffer_var; - } - Stmt stmt = StmtExprMutator::VisitStmt_(op); - op = stmt.as(); - - Buffer new_buf = GetRemappedBuffer(op->buffer); - if (!new_buf.same_as(op->buffer)) { - auto node = Downcast(stmt); - node.CopyOnWrite()->buffer = new_buf; - return node; - } - return stmt; - } - - Stmt VisitStmt_(const DeclBufferNode* op) final { - auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - return VisitBufferAccess(std::move(node)); - } - - PrimExpr VisitExpr_(const BufferLoadNode* op) final { - auto node = Downcast(StmtExprMutator::VisitExpr_(op)); - auto modified = VisitBufferAccess(node); - - // Not needed for BufferStoreNode, so we can't just call - // LegalizeDtype() in VisitBufferAccess. - if (node.same_as(modified)) { - return node; - - } else { - auto writer = modified.CopyOnWrite(); - writer->LegalizeDType(); - return modified; - } - } - - Stmt VisitStmt_(const BufferStoreNode* op) final { - auto node = Downcast(StmtExprMutator::VisitStmt_(op)); - return VisitBufferAccess(std::move(node)); - } - - template - Node VisitBufferAccess(Node node) { - Buffer new_buf = GetRemappedBuffer(node->buffer); - if (!new_buf.same_as(node->buffer)) { - auto writer = node.CopyOnWrite(); - writer->buffer = new_buf; - } - - return node; - } - - Buffer GetRemappedBuffer(Buffer buf) { - auto key = buf; - auto cache_it = buf_remap_.find(key); - if (cache_it != buf_remap_.end()) { - return cache_it->second; - } - - bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(buf->dtype.code()); - - if (to_be_lowered) { - auto new_load_type = DataType::UInt(buf->dtype.bits()); - auto writer = buf.CopyOnWrite(); - writer->dtype = new_load_type; - - auto var_it = var_remap_.find(buf->data); - if (var_it != var_remap_.end()) { - writer->data = var_it->second; - } - } - - buf_remap_[key] = buf; - return buf; - } - - Stmt VisitStmt_(const AttrStmtNode* op) final { - Stmt ret = StmtExprMutator::VisitStmt_(op); - op = ret.as(); - // Due to legacy reasons, some attr node can contain - // information(e.g. alignment) of buffer variables. - // remap these vars when needed - // TODO(tvm-team): remove the rewriting once the buffer var - // attrs are being refactored into the corresponding definition node - if (auto var_node = op->node.as()) { - auto it = var_remap_.find(var_node.value()); - if (it != var_remap_.end()) { - return AttrStmt(it->second, op->attr_key, op->value, op->body); - } - } - return ret; - } - - PrimExpr VisitExpr_(const CallNode* call) final { - bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(call->dtype.code()); - PrimExpr expr = StmtExprMutator::VisitExpr_(call); - call = expr.as(); - if (to_be_lowered) { - auto op = call->op.as(); - TVM_FFI_ICHECK(op != nullptr) << "Lowering non-intrinsic Calls not implemented"; - auto lower = datatype::GetIntrinLowerFunc(target_, op->name, call->dtype.code()); - TVM_FFI_ICHECK(lower) << "Intrinsic lowering function for target " << target_ - << ", intrinsic name " << op->name << ", type " - << static_cast(call->dtype.code()) << " not found"; - return (*lower)(expr).cast(); - } - return expr; - } - -#define TVM_DEFINE_MUTATE_CUSTOM_DTYPE(OP, NodeName) \ - PrimExpr VisitExpr_(const NodeName* op) final { \ - auto type_code = op->dtype.code(); \ - bool to_be_lowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ - PrimExpr expr = StmtExprMutator::VisitExpr_(op); \ - op = expr.as(); \ - if (to_be_lowered) { \ - auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ - TVM_FFI_ICHECK(lower) << #OP " lowering function for target " << target_ << " type " \ - << static_cast(type_code) << " not found"; \ - return (*lower)(expr).cast(); \ - } \ - return expr; \ - } - - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Add, AddNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Sub, SubNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mul, MulNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Div, DivNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Mod, ModNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Min, MinNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(Max, MaxNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(EQ, EQNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(NE, NENode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LT, LTNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(LE, LENode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GT, GTNode); - TVM_DEFINE_MUTATE_CUSTOM_DTYPE(GE, GENode); - // Later changes may need to add more mutate functions as we support workloads with more ops. - -#undef TVM_DEFINE_MUTATE_CUSTOM_DTYPE - - private: - std::string target_; - // remap buffer vars - std::unordered_map var_remap_; - std::unordered_map buf_remap_; -}; - -namespace transform { - -Pass LowerCustomDatatypes() { - auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { - auto* n = f.CopyOnWrite(); - auto target = f->GetAttr(tvm::attr::kTarget); - TVM_FFI_ICHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; - - n->body = CustomDatatypesLowerer(target.value()->kind->name)(std::move(n->body)); - return f; - }; - return CreatePrimFuncPass(pass_func, 0, "tirx.LowerCustomDatatypes", {}); -} - -TVM_FFI_STATIC_INIT_BLOCK() { - namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("tirx.transform.LowerCustomDatatypes", LowerCustomDatatypes); -} - -} // namespace transform - -} // namespace tirx -} // namespace tvm