From e75e11a12dd123fda266b488959c7c51b66efef8 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Thu, 11 Jun 2026 12:25:33 -0400 Subject: [PATCH 01/21] [CMAKE] Upgrade TVM build baseline to C++20 Move TVM's core C++ and CUDA dialects to C++20 and update first-party helper build paths that compile TVM-generated C++ code. This keeps the codebase ready for C++20 native dependencies while leaving vendored third-party projects to manage their own minimum standards. --- CMakeLists.txt | 4 ++-- apps/android_rpc/app/src/main/jni/Application.mk | 2 +- apps/hexagon_api/CMakeLists.txt | 6 +++--- apps/hexagon_launcher/CMakeLists.txt | 4 ++-- apps/hexagon_launcher/cmake/android/CMakeLists.txt | 2 +- apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt | 2 +- cmake/utils/FindLLVM.cmake | 4 ++-- docs/install/from_source.rst | 13 ++++++++----- jvm/native/linux-x86_64/pom.xml | 2 +- jvm/native/osx-x86_64/pom.xml | 2 +- python/tvm/contrib/cutlass/build.py | 2 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 2 +- python/tvm/relax/frontend/nn/extern.py | 6 +++--- python/tvm/rpc/minrpc.py | 2 +- python/tvm/support/emcc.py | 2 +- .../python/relax/test_frontend_nn_extern_module.py | 2 +- web/Makefile | 2 +- 17 files changed, 31 insertions(+), 28 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e25f10e7f13..00632f8d71d9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -439,9 +439,9 @@ include(cmake/utils/CCache.cmake) include(CheckCXXCompilerFlag) set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CUDA_STANDARD_REQUIRED ON) -set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD 20) # Module rules include(cmake/modules/CUDA.cmake) diff --git a/apps/android_rpc/app/src/main/jni/Application.mk b/apps/android_rpc/app/src/main/jni/Application.mk index a7996548eb4d..bc410416ce9f 100644 --- a/apps/android_rpc/app/src/main/jni/Application.mk +++ b/apps/android_rpc/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) 
APP_ABI ?= armeabi-v7a arm64-v8a x86 x86_64 mips APP_STL := c++_shared -APP_CPPFLAGS += -DTVM4J_ANDROID=1 -std=c++17 -Oz -frtti +APP_CPPFLAGS += -DTVM4J_ANDROID=1 -std=c++20 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/hexagon_api/CMakeLists.txt b/apps/hexagon_api/CMakeLists.txt index 62dca9d4e644..fc2be3a61a5d 100644 --- a/apps/hexagon_api/CMakeLists.txt +++ b/apps/hexagon_api/CMakeLists.txt @@ -45,7 +45,7 @@ ExternalProject_Add(x86_tvm_runtime_rpc "-DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER}" "-DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER}" "-DUSE_HEXAGON_TOOLCHAIN=${USE_HEXAGON_TOOLCHAIN}" - "-DCMAKE_CXX_STANDARD=17" + "-DCMAKE_CXX_STANDARD=20" "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DTVM_FFI_USE_THREADS=OFF" "-DTVM_FFI_USE_DL_LIBS=OFF" @@ -81,7 +81,7 @@ ExternalProject_Add(android_tvm_runtime_rpc "-DANDROID_ABI=${ANDROID_ABI}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - "-DCMAKE_CXX_STANDARD=17" + "-DCMAKE_CXX_STANDARD=20" "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DTVM_FFI_USE_THREADS=OFF" "-DTVM_FFI_USE_DL_LIBS=OFF" @@ -135,7 +135,7 @@ ExternalProject_Add(hexagon_tvm_runtime_rpc "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_EXTERNAL_LIBS=${USE_HEXAGON_EXTERNAL_LIBS}" "-DHEXAGON_EXTERNAL_LIBS_SHA=${HEXAGON_EXTERNAL_LIBS_SHA}" - "-DCMAKE_CXX_STANDARD=17" + "-DCMAKE_CXX_STANDARD=20" "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DTVM_FFI_USE_THREADS=OFF" "-DTVM_FFI_USE_DL_LIBS=OFF" diff --git a/apps/hexagon_launcher/CMakeLists.txt b/apps/hexagon_launcher/CMakeLists.txt index c08e743a2592..b42bdc324bd6 100644 --- a/apps/hexagon_launcher/CMakeLists.txt +++ b/apps/hexagon_launcher/CMakeLists.txt @@ -44,7 +44,7 @@ ExternalProject_Add(android_launcher_binaries "-DCMAKE_TOOLCHAIN_FILE=${USE_ANDROID_TOOLCHAIN}" "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" "-DANDROID_ABI=${ANDROID_ABI}" - "-DCMAKE_CXX_STANDARD=17" + "-DCMAKE_CXX_STANDARD=20" 
"-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" INSTALL_COMMAND "" @@ -65,7 +65,7 @@ ExternalProject_Add(hexagon_launcher_binaries CMAKE_ARGS "-DCMAKE_C_COMPILER=${USE_HEXAGON_TOOLCHAIN}/bin/hexagon-clang" "-DCMAKE_CXX_COMPILER=${USE_HEXAGON_TOOLCHAIN}/bin/hexagon-clang++" - "-DCMAKE_CXX_STANDARD=17" + "-DCMAKE_CXX_STANDARD=20" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" "-DUSE_CUSTOM_LOGGING=ON" diff --git a/apps/hexagon_launcher/cmake/android/CMakeLists.txt b/apps/hexagon_launcher/cmake/android/CMakeLists.txt index 0846ce786909..e58f87767d19 100644 --- a/apps/hexagon_launcher/cmake/android/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/android/CMakeLists.txt @@ -72,7 +72,7 @@ ExternalProject_Add(android_tvm_runtime CMAKE_ARGS "-DANDROID_ABI=${ANDROID_ABI}" "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" - "-DCMAKE_CXX_STANDARD=17" + "-DCMAKE_CXX_STANDARD=20" "-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}" "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" diff --git a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt index a0557307ba50..4686fc9d5849 100644 --- a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt @@ -83,7 +83,7 @@ ExternalProject_Add(static_hexagon_tvm_runtime "-DBUILD_STATIC_RUNTIME=ON" "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" - "-DCMAKE_CXX_STANDARD=17" + "-DCMAKE_CXX_STANDARD=20" "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake index 2bf229eca756..1f54ded1d5e9 100644 --- a/cmake/utils/FindLLVM.cmake +++ b/cmake/utils/FindLLVM.cmake @@ -254,9 +254,9 @@ macro(find_llvm use_llvm) # compiler-appropriate form so the probe works under MSVC as well. 
if(NOT CMAKE_CXX_STANDARD) if(MSVC) - set(CMAKE_REQUIRED_FLAGS "/std:c++17") + set(CMAKE_REQUIRED_FLAGS "/std:c++20") else() - set(CMAKE_REQUIRED_FLAGS "-std=c++17") + set(CMAKE_REQUIRED_FLAGS "-std=c++20") endif() endif() check_cxx_source_compiles(" diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index a970bf5c1e9e..c3ad4da37c56 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -35,11 +35,14 @@ Apache TVM requires the following dependencies: - CMake (>= 3.24.0) - LLVM (recommended >= 15) - Git -- A recent C++ compiler supporting C++ 17, at the minimum - - GCC 7.1 - - Clang 5.0 - - Apple Clang 9.3 - - Visual Studio 2019 (v16.7) +- A recent C++ compiler supporting C++ 20, at the minimum + - GCC 10 + - Clang 10 + - Apple Clang 14 + - Visual Studio 2022 + Optional dependencies that use newer C++20 standard library facilities, such + as ``std::format``, may require a newer standard library (for example GCC 13 + or newer on Linux). - Python (>= 3.10) - (Optional) Conda (Strongly Recommended) diff --git a/jvm/native/linux-x86_64/pom.xml b/jvm/native/linux-x86_64/pom.xml index 31e120bc58cd..9a29c64e5bae 100644 --- a/jvm/native/linux-x86_64/pom.xml +++ b/jvm/native/linux-x86_64/pom.xml @@ -114,7 +114,7 @@ under the License. - -std=c++17 + -std=c++20 -I../../../include diff --git a/jvm/native/osx-x86_64/pom.xml b/jvm/native/osx-x86_64/pom.xml index 4f9d70d60dc1..c6133d925d28 100644 --- a/jvm/native/osx-x86_64/pom.xml +++ b/jvm/native/osx-x86_64/pom.xml @@ -115,7 +115,7 @@ under the License. 
- -std=c++17 + -std=c++20 -I../../../include diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 4ff3f0812a3b..93ccf578236e 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -76,7 +76,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): "-Xcompiler=-fno-strict-aliasing", "-Xcompiler=-fvisibility=hidden", "-O3", - "-std=c++17", + "-std=c++20", f"-I{cutlass_include}", f"-I{cutlass_util_include}", f"-I{cutlass_attention_include}", diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 477c1ee44953..bf6671b30b68 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -395,7 +395,7 @@ def __init__(self, cuda_arch, cutlass_path, binary_prefix): self.cuda_arch = cuda_arch self.binary_prefix = binary_prefix self.cutlass = cutlass_path - self.cflags = f"-I{cutlass_path}/include -I{cutlass_path}/tools/util/include -O3 -std=c++17" + self.cflags = f"-I{cutlass_path}/include -I{cutlass_path}/tools/util/include -O3 -std=c++20" self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" self.cflags += ( f" -gencode=arch=compute_{cuda_arch},code=[sm_{cuda_arch},compute_{cuda_arch}]" diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index e424554367b4..b467b6271b9f 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -327,7 +327,7 @@ def get_compile_options( ) -> list[str]: """Returns the default compile options depending on `source_format`, including the default inlcude paths w.r.t. `tvm_home()`, and by default, - it uses "-O3" and "-std=c++17". + it uses "-O3" and "-std=c++20". 
Parameters ---------- @@ -350,13 +350,13 @@ def get_compile_options( host_flags = [ "-c", # generate object file "-O3", - "-std=c++17", + "-std=c++20", ] elif source_format == "cu": host_flags = [ "-c", # generate object file "-O3", - "-std=c++17", + "-std=c++20", # Enable `-fPIC` for the host compiler "-Xcompiler=-fPIC", ] diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py index d46f2a2faf80..4c1132af9c63 100644 --- a/python/tvm/rpc/minrpc.py +++ b/python/tvm/rpc/minrpc.py @@ -70,7 +70,7 @@ def with_minrpc(compile_func, server="posix_popen_server"): runtime_dir = os.path.abspath(os.path.dirname(runtime_path)) tvm_ffi_dir = os.path.abspath(os.path.dirname(tvm_ffi_path)) - options = ["-std=c++17"] + options = ["-std=c++20"] # Make sure the rpath to the libtvm_runtime is set so we can do local tests. # Note that however, this approach won't work on remote. # Always recommend to link statically. diff --git a/python/tvm/support/emcc.py b/python/tvm/support/emcc.py index 9bd6d24036d5..944ee8424281 100644 --- a/python/tvm/support/emcc.py +++ b/python/tvm/support/emcc.py @@ -85,7 +85,7 @@ def create_tvmjs_wasm(output, objects, options=None, cc="emcc", libs=None): """ cmd = [cc] cmd += ["-O3"] - cmd += ["-std=c++17"] + cmd += ["-std=c++20"] cmd += ["--no-entry"] # NOTE: asynctify conflicts with wasm-exception # so we temp disable exception handling for now diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py index dba87c3fde36..e504b649d044 100644 --- a/tests/python/relax/test_frontend_nn_extern_module.py +++ b/tests/python/relax/test_frontend_nn_extern_module.py @@ -136,7 +136,7 @@ def _compile_cc(src: Path, dst: Path): cmd += ["-I", include_path] cmd += [ "-c", - "-std=c++17", + "-std=c++20", "-fPIC", "-o", str(dst), diff --git a/web/Makefile b/web/Makefile index 7f802b5a2152..2e4326a8b58a 100644 --- a/web/Makefile +++ b/web/Makefile @@ -28,7 +28,7 @@ all: dist/wasm/tvmjs_runtime.wasm 
dist/wasm/tvmjs_runtime.wasi.js src/tvmjs_runt EMCC = emcc -EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++20 -Wno-ignored-attributes EMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\ -s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js\ From 9ee251f08502318da0357ce87b5c618c92a23f94 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Thu, 11 Jun 2026 12:34:52 -0400 Subject: [PATCH 02/21] [CMAKE] Keep downstream helper standards unchanged Limit the C++20 baseline update to TVM's core CMake configuration and docs. Leave app, JVM, web, and Python helper compile flags at their existing C++17 settings so those downstream build surfaces can migrate separately. --- apps/android_rpc/app/src/main/jni/Application.mk | 2 +- apps/hexagon_api/CMakeLists.txt | 6 +++--- apps/hexagon_launcher/CMakeLists.txt | 4 ++-- apps/hexagon_launcher/cmake/android/CMakeLists.txt | 2 +- apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt | 2 +- jvm/native/linux-x86_64/pom.xml | 2 +- jvm/native/osx-x86_64/pom.xml | 2 +- python/tvm/contrib/cutlass/build.py | 2 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 2 +- python/tvm/relax/frontend/nn/extern.py | 6 +++--- python/tvm/rpc/minrpc.py | 2 +- python/tvm/support/emcc.py | 2 +- tests/python/relax/test_frontend_nn_extern_module.py | 2 +- web/Makefile | 2 +- 14 files changed, 19 insertions(+), 19 deletions(-) diff --git a/apps/android_rpc/app/src/main/jni/Application.mk b/apps/android_rpc/app/src/main/jni/Application.mk index bc410416ce9f..a7996548eb4d 100644 --- a/apps/android_rpc/app/src/main/jni/Application.mk +++ b/apps/android_rpc/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) APP_ABI ?= armeabi-v7a arm64-v8a x86 x86_64 mips APP_STL := c++_shared -APP_CPPFLAGS += -DTVM4J_ANDROID=1 -std=c++20 -Oz -frtti +APP_CPPFLAGS += -DTVM4J_ANDROID=1 -std=c++17 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 
endif diff --git a/apps/hexagon_api/CMakeLists.txt b/apps/hexagon_api/CMakeLists.txt index fc2be3a61a5d..62dca9d4e644 100644 --- a/apps/hexagon_api/CMakeLists.txt +++ b/apps/hexagon_api/CMakeLists.txt @@ -45,7 +45,7 @@ ExternalProject_Add(x86_tvm_runtime_rpc "-DCMAKE_C_COMPILER_LAUNCHER=${CMAKE_C_COMPILER_LAUNCHER}" "-DCMAKE_CXX_COMPILER_LAUNCHER=${CMAKE_CXX_COMPILER_LAUNCHER}" "-DUSE_HEXAGON_TOOLCHAIN=${USE_HEXAGON_TOOLCHAIN}" - "-DCMAKE_CXX_STANDARD=20" + "-DCMAKE_CXX_STANDARD=17" "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DTVM_FFI_USE_THREADS=OFF" "-DTVM_FFI_USE_DL_LIBS=OFF" @@ -81,7 +81,7 @@ ExternalProject_Add(android_tvm_runtime_rpc "-DANDROID_ABI=${ANDROID_ABI}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" - "-DCMAKE_CXX_STANDARD=20" + "-DCMAKE_CXX_STANDARD=17" "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DTVM_FFI_USE_THREADS=OFF" "-DTVM_FFI_USE_DL_LIBS=OFF" @@ -135,7 +135,7 @@ ExternalProject_Add(hexagon_tvm_runtime_rpc "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_EXTERNAL_LIBS=${USE_HEXAGON_EXTERNAL_LIBS}" "-DHEXAGON_EXTERNAL_LIBS_SHA=${HEXAGON_EXTERNAL_LIBS_SHA}" - "-DCMAKE_CXX_STANDARD=20" + "-DCMAKE_CXX_STANDARD=17" "-DTVM_FFI_USE_LIBBACKTRACE=OFF" "-DTVM_FFI_USE_THREADS=OFF" "-DTVM_FFI_USE_DL_LIBS=OFF" diff --git a/apps/hexagon_launcher/CMakeLists.txt b/apps/hexagon_launcher/CMakeLists.txt index b42bdc324bd6..c08e743a2592 100644 --- a/apps/hexagon_launcher/CMakeLists.txt +++ b/apps/hexagon_launcher/CMakeLists.txt @@ -44,7 +44,7 @@ ExternalProject_Add(android_launcher_binaries "-DCMAKE_TOOLCHAIN_FILE=${USE_ANDROID_TOOLCHAIN}" "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" "-DANDROID_ABI=${ANDROID_ABI}" - "-DCMAKE_CXX_STANDARD=20" + "-DCMAKE_CXX_STANDARD=17" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" INSTALL_COMMAND "" @@ -65,7 +65,7 @@ ExternalProject_Add(hexagon_launcher_binaries CMAKE_ARGS "-DCMAKE_C_COMPILER=${USE_HEXAGON_TOOLCHAIN}/bin/hexagon-clang" 
"-DCMAKE_CXX_COMPILER=${USE_HEXAGON_TOOLCHAIN}/bin/hexagon-clang++" - "-DCMAKE_CXX_STANDARD=20" + "-DCMAKE_CXX_STANDARD=17" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" "-DUSE_CUSTOM_LOGGING=ON" diff --git a/apps/hexagon_launcher/cmake/android/CMakeLists.txt b/apps/hexagon_launcher/cmake/android/CMakeLists.txt index e58f87767d19..0846ce786909 100644 --- a/apps/hexagon_launcher/cmake/android/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/android/CMakeLists.txt @@ -72,7 +72,7 @@ ExternalProject_Add(android_tvm_runtime CMAKE_ARGS "-DANDROID_ABI=${ANDROID_ABI}" "-DANDROID_PLATFORM=${ANDROID_PLATFORM}" - "-DCMAKE_CXX_STANDARD=20" + "-DCMAKE_CXX_STANDARD=17" "-DCMAKE_TOOLCHAIN_FILE=${CMAKE_TOOLCHAIN_FILE}" "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" diff --git a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt index 4686fc9d5849..a0557307ba50 100644 --- a/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt +++ b/apps/hexagon_launcher/cmake/hexagon/CMakeLists.txt @@ -83,7 +83,7 @@ ExternalProject_Add(static_hexagon_tvm_runtime "-DBUILD_STATIC_RUNTIME=ON" "-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}" "-DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}" - "-DCMAKE_CXX_STANDARD=20" + "-DCMAKE_CXX_STANDARD=17" "-DUSE_HEXAGON=ON" "-DUSE_HEXAGON_ARCH=${USE_HEXAGON_ARCH}" "-DUSE_HEXAGON_SDK=${USE_HEXAGON_SDK}" diff --git a/jvm/native/linux-x86_64/pom.xml b/jvm/native/linux-x86_64/pom.xml index 9a29c64e5bae..31e120bc58cd 100644 --- a/jvm/native/linux-x86_64/pom.xml +++ b/jvm/native/linux-x86_64/pom.xml @@ -114,7 +114,7 @@ under the License. - -std=c++20 + -std=c++17 -I../../../include diff --git a/jvm/native/osx-x86_64/pom.xml b/jvm/native/osx-x86_64/pom.xml index c6133d925d28..4f9d70d60dc1 100644 --- a/jvm/native/osx-x86_64/pom.xml +++ b/jvm/native/osx-x86_64/pom.xml @@ -115,7 +115,7 @@ under the License. 
- -std=c++20 + -std=c++17 -I../../../include diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 93ccf578236e..4ff3f0812a3b 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -76,7 +76,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False): "-Xcompiler=-fno-strict-aliasing", "-Xcompiler=-fvisibility=hidden", "-O3", - "-std=c++20", + "-std=c++17", f"-I{cutlass_include}", f"-I{cutlass_util_include}", f"-I{cutlass_attention_include}", diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index bf6671b30b68..477c1ee44953 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -395,7 +395,7 @@ def __init__(self, cuda_arch, cutlass_path, binary_prefix): self.cuda_arch = cuda_arch self.binary_prefix = binary_prefix self.cutlass = cutlass_path - self.cflags = f"-I{cutlass_path}/include -I{cutlass_path}/tools/util/include -O3 -std=c++20" + self.cflags = f"-I{cutlass_path}/include -I{cutlass_path}/tools/util/include -O3 -std=c++17" self.cflags += " -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1" self.cflags += ( f" -gencode=arch=compute_{cuda_arch},code=[sm_{cuda_arch},compute_{cuda_arch}]" diff --git a/python/tvm/relax/frontend/nn/extern.py b/python/tvm/relax/frontend/nn/extern.py index b467b6271b9f..e424554367b4 100644 --- a/python/tvm/relax/frontend/nn/extern.py +++ b/python/tvm/relax/frontend/nn/extern.py @@ -327,7 +327,7 @@ def get_compile_options( ) -> list[str]: """Returns the default compile options depending on `source_format`, including the default inlcude paths w.r.t. `tvm_home()`, and by default, - it uses "-O3" and "-std=c++20". + it uses "-O3" and "-std=c++17". 
Parameters ---------- @@ -350,13 +350,13 @@ def get_compile_options( host_flags = [ "-c", # generate object file "-O3", - "-std=c++20", + "-std=c++17", ] elif source_format == "cu": host_flags = [ "-c", # generate object file "-O3", - "-std=c++20", + "-std=c++17", # Enable `-fPIC` for the host compiler "-Xcompiler=-fPIC", ] diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py index 4c1132af9c63..d46f2a2faf80 100644 --- a/python/tvm/rpc/minrpc.py +++ b/python/tvm/rpc/minrpc.py @@ -70,7 +70,7 @@ def with_minrpc(compile_func, server="posix_popen_server"): runtime_dir = os.path.abspath(os.path.dirname(runtime_path)) tvm_ffi_dir = os.path.abspath(os.path.dirname(tvm_ffi_path)) - options = ["-std=c++20"] + options = ["-std=c++17"] # Make sure the rpath to the libtvm_runtime is set so we can do local tests. # Note that however, this approach won't work on remote. # Always recommend to link statically. diff --git a/python/tvm/support/emcc.py b/python/tvm/support/emcc.py index 944ee8424281..9bd6d24036d5 100644 --- a/python/tvm/support/emcc.py +++ b/python/tvm/support/emcc.py @@ -85,7 +85,7 @@ def create_tvmjs_wasm(output, objects, options=None, cc="emcc", libs=None): """ cmd = [cc] cmd += ["-O3"] - cmd += ["-std=c++20"] + cmd += ["-std=c++17"] cmd += ["--no-entry"] # NOTE: asynctify conflicts with wasm-exception # so we temp disable exception handling for now diff --git a/tests/python/relax/test_frontend_nn_extern_module.py b/tests/python/relax/test_frontend_nn_extern_module.py index e504b649d044..dba87c3fde36 100644 --- a/tests/python/relax/test_frontend_nn_extern_module.py +++ b/tests/python/relax/test_frontend_nn_extern_module.py @@ -136,7 +136,7 @@ def _compile_cc(src: Path, dst: Path): cmd += ["-I", include_path] cmd += [ "-c", - "-std=c++20", + "-std=c++17", "-fPIC", "-o", str(dst), diff --git a/web/Makefile b/web/Makefile index 2e4326a8b58a..7f802b5a2152 100644 --- a/web/Makefile +++ b/web/Makefile @@ -28,7 +28,7 @@ all: dist/wasm/tvmjs_runtime.wasm 
dist/wasm/tvmjs_runtime.wasi.js src/tvmjs_runt EMCC = emcc -EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++20 -Wno-ignored-attributes +EMCC_CFLAGS = $(INCLUDE_FLAGS) -O3 -std=c++17 -Wno-ignored-attributes EMCC_LDFLAGS = --no-entry -s WASM_BIGINT=1 -s ALLOW_MEMORY_GROWTH=1 -s STANDALONE_WASM=1\ -s ERROR_ON_UNDEFINED_SYMBOLS=0 --pre-js emcc/preload.js\ From 6cf6ecb7f05fd708a33b89604ca6e90578352630 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 06:39:26 -0400 Subject: [PATCH 03/21] [CMAKE] Fix C++20 warning errors Replace deprecated C++ traits and make captures explicit so the C++20 baseline builds cleanly under -Werror. --- include/tvm/tirx/op.h | 3 ++- src/relax/transform/run_codegen.cc | 2 +- src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc | 4 ++-- src/runtime/extra/disco/protocol.h | 4 +++- src/runtime/rpc/rpc_endpoint.cc | 4 +++- src/s_tir/transform/compact_buffer_region.cc | 2 ++ src/s_tir/transform/inject_software_pipeline.cc | 2 +- src/target/llvm/codegen_params.cc | 3 ++- 8 files changed, 16 insertions(+), 8 deletions(-) diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index 60b292bbb265..231b129b94be 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -825,7 +825,8 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) { * \param span The location of this operation in the source. */ template ::value>::type> + typename = typename std::enable_if::value && + std::is_trivial::value>::type> inline PrimExpr make_const(DataType t, ValueType value, Span span = Span()); /*! * \brief Make a const zero expr. 
diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index bc99196169e7..efd90d6696d7 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -155,7 +155,7 @@ class CodeGenRunner : ExprMutator { if (opt_codegen) { auto ext_symbol = GetExtSymbol(func); size_t count = 0; - PostOrderVisit(func->body, [=, &count](Expr e) { + PostOrderVisit(func->body, [=, this, &count](Expr e) { if (e->IsInstance()) { // Make sure to pick a unique name auto name = ext_symbol + "_" + opt_codegen.value() + "_const_" + std::to_string(count++); diff --git a/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc index df38f960d294..a7cf1ec2b318 100644 --- a/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc +++ b/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc @@ -162,7 +162,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { conv_dtype, false, &best_algo); int algo = best_algo.cast(); - std::function op_exec = [=]() { + std::function op_exec = [=, this]() { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); @@ -223,7 +223,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { auto runner = tvm::contrib::CuDNNSDPARunner::Create(); runner->Init(batch, seq_len, num_heads, num_kv_heads, head_size, head_size_v, scale, dtype, layout); - return [=]() { + return [=, this]() { auto qkv = GetInput(node, 0); auto workspace = const_cast(GetInput(node, 1)); auto out = const_cast(data_entry_[EntryID(outputs_[0])]); diff --git a/src/runtime/extra/disco/protocol.h b/src/runtime/extra/disco/protocol.h index 25662051dcb4..a26b3060bc2a 100644 --- a/src/runtime/extra/disco/protocol.h +++ b/src/runtime/extra/disco/protocol.h @@ -28,6 +28,7 @@ #include #include +#include #include #include @@ -78,7 +79,8 @@ struct DiscoProtocol { /*!\ brief Arena used by RPCReference to allocate POD memory */ template T* 
ArenaAlloc(int count) { - static_assert(std::is_pod::value, "need to be trival"); + static_assert(std::is_standard_layout::value && std::is_trivial::value, + "need to be trivial"); return arena_.template allocate_(count); } diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index a6950117d611..0402430251e5 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -312,7 +313,8 @@ class RPCEndpoint::EventHandler : public support::Stream { template T* ArenaAlloc(int count) { - static_assert(std::is_pod::value, "need to be trival"); + static_assert(std::is_standard_layout::value && std::is_trivial::value, + "need to be trivial"); return arena_.template allocate_(count); } diff --git a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc index d02e90701696..640a18d594cf 100644 --- a/src/s_tir/transform/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -129,6 +129,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } private: + using StmtExprVisitor::VisitBufferDef; + struct BufferAccessInfo { /*! \brief The buffer. 
*/ Buffer buffer; diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index d9da151f392f..5099f66cd030 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -707,7 +707,7 @@ class PipelineRewriter : public StmtExprMutator { } } - auto wait_count = [=, &ana_normalized]() { + auto wait_count = [=, this, &ana_normalized]() { auto sum = PrimExpr(0); for (auto producer_head : producer_head_per_commit) { if (producer_head && ana_normalized->CanProve(producer_head.value() >= 0)) { diff --git a/src/target/llvm/codegen_params.cc b/src/target/llvm/codegen_params.cc index fccc92a22830..6d8684a87eda 100644 --- a/src/target/llvm/codegen_params.cc +++ b/src/target/llvm/codegen_params.cc @@ -61,7 +61,8 @@ struct LLVMConstantGetter::value>> static llvm::Constant* getElement(llvm::Type* ty, T t) { return llvm::ConstantFP::get(ty, t); } }; -template ::value>> +template ::value && std::is_trivial::value>> void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t num_elements, std::vector* elements) { elements->resize(num_elements, nullptr); From ca968ba3dafdaa8318a22fc178ba685075e3f30f Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 19:50:04 -0400 Subject: [PATCH 04/21] [DOCS] Fix from source dependency list formatting --- docs/install/from_source.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index c3ad4da37c56..392fdc5cfc5e 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -40,6 +40,7 @@ Apache TVM requires the following dependencies: - Clang 10 - Apple Clang 14 - Visual Studio 2022 + Optional dependencies that use newer C++20 standard library facilities, such as ``std::format``, may require a newer standard library (for example GCC 13 or newer on Linux). 
From a98a694a060b448c6d50f2d55558204cc955df57 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Wed, 3 Jun 2026 16:48:20 -0400 Subject: [PATCH 05/21] [ARITH] Add optional Z3-backed proving to Analyzer Add an optional Z3 SMT solver backend to tvm::arith::Analyzer for stronger integer arithmetic proving. The integration is guarded by a new USE_Z3 CMake option (default OFF). When enabled, Analyzer::CanProve runs the existing analysis path first and only falls back to Z3 when the existing analyzers cannot prove the predicate. When disabled, a stub implementation keeps the C++ and Python APIs available without Z3. --- CMakeLists.txt | 2 + cmake/modules/contrib/Z3.cmake | 76 +++ include/tvm/arith/analyzer.h | 125 ++++- python/tvm/arith/analyzer.py | 40 ++ src/arith/analyzer.cc | 25 +- src/arith/rewrite_simplify.cc | 10 +- src/arith/rewrite_simplify.h | 2 +- src/target/z3/z3_prover_off.cc | 38 ++ src/target/z3/z3_prover_on.cc | 788 ++++++++++++++++++++++++++++ tests/python/arith/test_arith_z3.py | 92 ++++ 10 files changed, 1188 insertions(+), 10 deletions(-) create mode 100644 cmake/modules/contrib/Z3.cmake create mode 100644 src/target/z3/z3_prover_off.cc create mode 100644 src/target/z3/z3_prover_on.cc create mode 100644 tests/python/arith/test_arith_z3.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 00632f8d71d9..b57ef919feb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF) tvm_option(USE_BLAS "The blas library to be linked" none) tvm_option(USE_AMX "Enable Intel AMX" OFF) +tvm_option(USE_Z3 "Build with Z3 SMT solver support" OFF) tvm_option(USE_MKL "MKL root path when use MKL blas" OFF) tvm_option(USE_DNNL "Enable DNNL codegen" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF) @@ -460,6 +461,7 @@ include(cmake/modules/contrib/CUTLASS.cmake) include(cmake/modules/contrib/Random.cmake) 
include(cmake/modules/contrib/Posit.cmake) include(cmake/modules/contrib/Sort.cmake) +include(cmake/modules/contrib/Z3.cmake) include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/NNAPI.cmake) diff --git a/cmake/modules/contrib/Z3.cmake b/cmake/modules/contrib/Z3.cmake new file mode 100644 index 000000000000..eef62e4cfcd8 --- /dev/null +++ b/cmake/modules/contrib/Z3.cmake @@ -0,0 +1,76 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +if(NOT USE_Z3) + list(APPEND COMPILER_SRCS src/target/z3/z3_prover_off.cc) + return() +endif() + +find_package(Z3 QUIET) +set(Z3_PYTHON_RESULT 1) + +if(NOT Z3_FOUND) + find_package(Python3 COMPONENTS Interpreter QUIET) + if(Python3_EXECUTABLE) + execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])" + OUTPUT_VARIABLE Z3_PYTHON_PACKAGE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE Z3_PYTHON_RESULT + ) + endif() + + if(Z3_PYTHON_RESULT EQUAL 0 AND NOT Z3_PYTHON_PACKAGE_DIR STREQUAL "") + find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS "${Z3_PYTHON_PACKAGE_DIR}/include") + find_library( + Z3_LIBRARY + NO_DEFAULT_PATH + NAMES z3 libz3 + PATHS "${Z3_PYTHON_PACKAGE_DIR}/bin" "${Z3_PYTHON_PACKAGE_DIR}/lib" + "${Z3_PYTHON_PACKAGE_DIR}/lib64" + ) + endif() +endif() + +if(TARGET z3::libz3 OR TARGET Z3::libz3) + if(TARGET z3::libz3) + set(Z3_TARGET z3::libz3) + else() + set(Z3_TARGET Z3::libz3) + endif() + get_target_property(Z3_TARGET_INCLUDE_DIRS ${Z3_TARGET} INTERFACE_INCLUDE_DIRECTORIES) + if(Z3_TARGET_INCLUDE_DIRS) + include_directories(SYSTEM ${Z3_TARGET_INCLUDE_DIRS}) + endif() + list(APPEND TVM_LINKER_LIBS ${Z3_TARGET}) +elseif(Z3_FOUND OR (Z3_INCLUDE_DIR AND Z3_LIBRARY)) + if(NOT Z3_INCLUDE_DIR AND Z3_CXX_INCLUDE_DIRS) + set(Z3_INCLUDE_DIR ${Z3_CXX_INCLUDE_DIRS}) + endif() + if(NOT Z3_LIBRARY AND Z3_LIBRARIES) + set(Z3_LIBRARY ${Z3_LIBRARIES}) + endif() + if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY) + message(FATAL_ERROR "USE_Z3 is ON, but Z3 include directory or library was not found.") + endif() + include_directories(SYSTEM ${Z3_INCLUDE_DIR}) + list(APPEND TVM_LINKER_LIBS ${Z3_LIBRARY}) +else() + message(FATAL_ERROR "USE_Z3 is ON, but Z3 was not found. 
Install Z3 or PyPI z3-solver.") +endif() + +list(APPEND COMPILER_SRCS src/target/z3/z3_prover_on.cc) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 924cc299270a..7de9213f8114 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -299,7 +300,7 @@ class RewriteSimplifier { * * \return an exit function that must be called to cleanup the constraint can be nullptr. */ - TVM_DLL std::function EnterConstraint(const PrimExpr& constraint); + TVM_DLL std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); /*! \brief Flags to enable more computationally-intensive simplifications * @@ -588,6 +589,103 @@ class IntSetAnalyzer { Impl* impl_; }; +class Z3Prover { + public: + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_range The range of allowed values for this var. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param expr The bound expression. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! + * \brief Whether can we prove expr is always true. + * + * \param expr The expression. + * \return Whether we can prove it. + */ + TVM_DLL bool CanProve(const PrimExpr& expr); + + /*! + * \brief Update the internal state to enter constraint. + * + * \param constraint A constraint expression. + * \param is_assume Whether the constraint comes from an assumption. + * \return an exit function that must be called to cleanup the constraint can be nullptr. 
+ */ + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); + + /*! + * \brief Get the SMTLIB2 representation of the current context. + * + * \param expr The optional expression to check. + * \return The SMTLIB2 string. + */ + ffi::String GetSMTLIB2(const ffi::Optional expr); + + /*! + * \brief Get statistics about Z3 prover. + * + * \return The statistics string. + */ + ffi::String GetStats(); + + /*! + * \brief Set timeout in milliseconds for Z3 prover. + * + * \param timeout_ms The timeout in milliseconds. + */ + void SetTimeoutMs(unsigned timeout_ms); + + /*! + * \brief Set resource limitation for Z3 prover. + * + * \param rlimit the resource limitation. + */ + void SetRLimit(unsigned rlimit); + + /*! + * \brief Get the Z3 model for the given expression if satisfiable. + * + * \param expr The expression to get the model for. + * \return The model as a string. + */ + ffi::String GetModel(const PrimExpr& expr); + + /*! + * \brief Count the number of integer values that satisfy the current constraints. + * + * This method uses Z3's model enumeration to count how many distinct values of + * the given variable satisfy all current constraints. + * + * \param var The variable to count satisfying values for. + * \param max_count Maximum number of solutions to enumerate. + * \param min_consecutive Minimum consecutive count requirement. + * \return The number of distinct values that satisfy the constraints, or a negative error code. + */ + TVM_DLL int64_t CountSatisfyingValues(const Var& var, int64_t max_count = 2048, + int64_t min_consecutive = 1); + + private: + friend class Analyzer; + explicit Z3Prover(AnalyzerObj* parent); + TVM_DLL ~Z3Prover(); + void CopyFrom(const Z3Prover& other); + class Impl; + Impl* impl_; +}; + /*! * \brief Analyzer that contains bunch of sub-analyzers. * @@ -612,6 +710,8 @@ class TVM_DLL AnalyzerObj : public ffi::Object { IntSetAnalyzer int_set; /*! 
\brief sub-analyzer transitive comparisons */ TransitiveComparisonAnalyzer transitive_comparisons; + /*! \brief sub-analyzer using Z3 */ + Z3Prover z3_prover; /*! \brief constructor */ AnalyzerObj(); /*! @@ -810,7 +910,16 @@ class ConstraintContext { * \param constraint The constraint to be applied. */ ConstraintContext(const Analyzer& analyzer, PrimExpr constraint) - : analyzer_(analyzer), constraint_(constraint) {} + : ConstraintContext(analyzer, std::move(constraint), false) {} + /*! + * \brief Construct a constraint context. + * \param analyzer The analyzer whose context is updated. The context + * keeps a reference to the analyzer while the scope is active. + * \param constraint The constraint to be applied. + * \param is_assume Whether the constraint comes from an assumption. + */ + ConstraintContext(const Analyzer& analyzer, PrimExpr constraint, bool is_assume) + : analyzer_(analyzer), constraint_(std::move(constraint)), is_assume_(is_assume) {} /*! * \brief Construct a constraint context from a borrowed analyzer object. * \param analyzer The borrowed analyzer object. @@ -819,7 +928,15 @@ class ConstraintContext { * This overload is for internal callers that already operate on AnalyzerObj*. */ ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint) - : ConstraintContext(ffi::GetRef(analyzer), std::move(constraint)) {} + : ConstraintContext(ffi::GetRef(analyzer), std::move(constraint), false) {} + /*! + * \brief Construct a constraint context from a borrowed analyzer object. + * \param analyzer The borrowed analyzer object. + * \param constraint The constraint to be applied. + * \param is_assume Whether the constraint comes from an assumption. + */ + ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint, bool is_assume) + : ConstraintContext(ffi::GetRef(analyzer), std::move(constraint), is_assume) {} // enter the scope. void EnterWithScope(); // exit the scope. @@ -830,6 +947,8 @@ class ConstraintContext { PrimExpr constraint_; /*! 
\brief functions to be called in recovery */ std::vector> recovery_functions_; + /*! \brief Whether the constraint comes from an assumption. */ + bool is_assume_; }; } // namespace arith diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 0aa6a75eba4a..39fa30d64180 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -128,6 +128,46 @@ class Analyzer(Object): def __init__(self): self.__init_handle_by_constructor__(_ffi_api.Analyzer) + def get_smtlib2(self, expr: tirx.PrimExpr = None) -> str: + """Get the current Z3 problem in SMT-LIB2 format. + + Parameters + ---------- + expr : Optional[PrimExpr] + The expression to prove. If provided, its negation is added to the problem. + """ + return _ffi_api.AnalyzerGetSMTLIB2(self, expr) + + def set_z3_timeout_ms(self, timeout_ms: int) -> None: + """Set Z3 timeout in milliseconds. + + Parameters + ---------- + timeout_ms : int + The timeout in milliseconds. + """ + _ffi_api.AnalyzerSetZ3TimeoutMs(self, timeout_ms) + + def set_z3_rlimit(self, rlimit: int) -> None: + """Set Z3 resource limit. + + Parameters + ---------- + rlimit : int + The resource limit. + """ + _ffi_api.AnalyzerSetZ3RLimit(self, rlimit) + + def get_z3_stats(self) -> str: + """Get Z3 solver statistics. + + Returns + ------- + stats : str + The Z3 statistics. + """ + return _ffi_api.AnalyzerGetZ3Stats(self) + def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound: """Find constant integer bound for expr. 
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index cc3c73bb6207..cb209b0a9dfb 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -39,7 +39,8 @@ AnalyzerObj::AnalyzerObj() modular_set(this), rewrite_simplify(this), canonical_simplify(this), - int_set(this) {} + int_set(this), + z3_prover(this) {} void AnalyzerObj::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { PrimExpr new_expr = expr; @@ -52,6 +53,7 @@ void AnalyzerObj::Bind(const Var& var, const PrimExpr& expr, bool allow_override this->canonical_simplify.Update(var, new_expr, allow_override); this->int_set.Update(var, this->int_set(new_expr), allow_override); this->transitive_comparisons.Bind(var, expr, allow_override); + this->z3_prover.Bind(var, expr, allow_override); } void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override) { @@ -62,6 +64,7 @@ void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override) this->const_int_bound.Bind(var, range, allow_override); this->int_set.Bind(var, range, allow_override); this->transitive_comparisons.Bind(var, range, allow_override); + this->z3_prover.Bind(var, range, allow_override); } // skip modular_set // skip rewrite simplify @@ -128,9 +131,11 @@ void ConstraintContext::EnterWithScope() { // entering the scope. 
recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_)); - recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_)); + recovery_functions_.push_back( + analyzer_->rewrite_simplify.EnterConstraint(constraint_, is_assume_)); recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_, is_assume_)); } void ConstraintContext::ExitWithScope() { @@ -231,6 +236,10 @@ bool AnalyzerObj::CanProve(const PrimExpr& expr, ProofStrength strength) { } } + if (z3_prover.CanProve(simplified)) { + return true; + } + } return false; } @@ -334,6 +343,18 @@ TVM_FFI_STATIC_INIT_BLOCK() { return static_cast( analyzer->transitive_comparisons.TryCompare(lhs, rhs, propagate_inequalities)); }) + .def("arith.AnalyzerGetSMTLIB2", + [](Analyzer analyzer, ffi::Optional expr) { + return analyzer->z3_prover.GetSMTLIB2(expr); + }) + .def("arith.AnalyzerSetZ3TimeoutMs", [](Analyzer analyzer, int64_t timeout_ms) { + analyzer->z3_prover.SetTimeoutMs(static_cast(timeout_ms)); + }) + .def("arith.AnalyzerSetZ3RLimit", [](Analyzer analyzer, int64_t rlimit) { + analyzer->z3_prover.SetRLimit(static_cast(rlimit)); + }) + .def("arith.AnalyzerGetZ3Stats", + [](Analyzer analyzer) { return analyzer->z3_prover.GetStats(); }) .def("arith.AnalyzerGetEnabledExtensions", [](Analyzer analyzer) { return static_cast(analyzer->rewrite_simplify.GetEnabledExtensions()); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 2120aaa1a859..d0cffbf7950b 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -526,13 +526,14 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { return ret; } -std::function 
RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { +std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint, + bool is_assume) { size_t old_literal_size = literal_constraints_.size(); // we will compare the already simplified result with the constraint, // so simplify the constraint as well PrimExpr new_constraint = operator()(constraint); for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) { - if (SideEffect(subconstraint) <= CallEffectKind::kPure) { + if (is_assume || SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); PrimExpr negation; if (subconstraint.dtype().is_bool()) { @@ -2440,8 +2441,9 @@ void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_ impl_->Update(var, info, allow_override); } -std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) { - return impl_->EnterConstraint(constraint); +std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint, + bool is_assume) { + return impl_->EnterConstraint(constraint, is_assume); } void RewriteSimplifier::SetEnabledExtensions(Extension flags) { diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index b42b73336a27..026deee72bec 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -117,7 +117,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CastNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override; - std::function EnterConstraint(const PrimExpr& constraint); + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); /*! 
\brief Enable an optional extension or extensions * diff --git a/src/target/z3/z3_prover_off.cc b/src/target/z3/z3_prover_off.cc new file mode 100644 index 000000000000..98278ffbf39d --- /dev/null +++ b/src/target/z3/z3_prover_off.cc @@ -0,0 +1,38 @@ +#include +#include +#include + +#include "tvm/arith/analyzer.h" +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" + +namespace tvm::arith { + +using namespace tirx; +using namespace ffi; + +class Z3Prover::Impl {}; + +TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return false; } +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) {} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + return []() {}; +} +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + return "; Z3 Prover is disabled."; +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) {} +void Z3Prover::SetRLimit(unsigned rlimit) {} +ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return "; Z3 Prover is disabled."; } +TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, + int64_t min_consecutive) { + return -1; // Z3 disabled, return error +} + +void Z3Prover::CopyFrom(const Z3Prover& other) {} +ffi::String Z3Prover::GetStats() { return "; Z3 Prover is disabled."; } +Z3Prover::Z3Prover(AnalyzerObj*) : impl_(nullptr) {} +TVM_DLL Z3Prover::~Z3Prover() {} + +} // namespace tvm::arith diff --git a/src/target/z3/z3_prover_on.cc b/src/target/z3/z3_prover_on.cc new file mode 100644 index 000000000000..08b1c359605a --- /dev/null +++ b/src/target/z3/z3_prover_on.cc @@ -0,0 +1,788 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tvm/arith/analyzer.h" +#include "tvm/ffi/cast.h" 
+#include "tvm/ffi/object.h" +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" +#include "tvm/runtime/data_type.h" +#include "z3++.h" + +namespace tvm::arith { + +using namespace tirx; +using namespace ffi; + +namespace { + +struct Namespace { + std::unordered_set used_names; + /// @brief Get a new name that is not used before + /// This function is used to generate z3 variable names + /// + /// Z3 may deduplicate variables with the same name, which + /// causes issues when different TVM variables are mapped to + /// the same z3 variable. + /// + /// This function generates unique names by appending + /// suffixes to the original expression string representation. + /// + /// such as : "x", "x$1", "x$2", ... + std::string GetNewName(const PrimExpr& expr) { + std::stringstream ss; + ss << expr; + auto name = ss.str(); + if (used_names.count(name) == 0) { + used_names.insert(name); + return name; + } + int idx = 1; + std::string check_name = name + "$" + std::to_string(idx); + while (used_names.count(check_name)) { + idx++; + check_name = name + "$" + std::to_string(idx); + } + used_names.insert(check_name); + return check_name; + } +}; + +} // namespace + +class Z3Prover::Impl : ExprFunctor { + public: + using Base = ExprFunctor; + using Self = Z3Prover::Impl; + + AnalyzerObj* analyzer; + /// @brief Z3 context, a shared ptr, because tilelang want to copy the Analyzer + // We use a thread_local static Z3 context so all analyzers within the same thread + // can share a common context, because Z3 initialization is slow on some CPUs + // (e.g., AMD EPYC 7502 32-Core). Using thread_local ensures thread safety. 
+ inline static thread_local std::shared_ptr ctx{new z3::context()}; + + /// @brief Z3 solver instance + z3::solver solver{*ctx}; + + /// @brief Memoization table from visited expressions to their z3 expressions + std::unordered_map memo_; + + bool is_assume = false; + + /// @brief Namespace for variable naming + Namespace ns; + + /// @brief Timeout in milliseconds + unsigned timeout_ms{UINT_MAX}; + + /// @brief Z3 resource limit (rlimit) + unsigned rlimit{UINT_MAX}; + + /// @brief Create a z3 solver with custom options + static z3::solver CreateSolver(z3::context& ctx) { + z3::solver solver(ctx); + // here we disable model generation to speed up the solving process + solver.set("model", false); + // ensure deterministic behavior + solver.set("random_seed", (unsigned)42); + return solver; + } + + Impl(AnalyzerObj* parent) : analyzer(parent) { + scope_stack_.push_back({}); + solver = CreateSolver(*ctx); + // a default timeout of 5ms was considered, but the call below is commented out + // Z3's implementation of timeout, when setting timeout T ms, it will stop at T - 1 ms + // SetTimeoutMs(5); + // use rlimit, not timeout to ensure deterministic behavior + SetRLimit(1e4); + } + + /// @brief Create a free z3 constant for a PrimExprNode; integer constants are bounded by the dtype range + z3::expr Create(const PrimExprNode* op) { + auto ref = ffi::GetRef(op); + auto dtype = op->dtype; + std::string name = ns.GetNewName(ref); + /// TVM max_val can't handle uint64 max correctly, so we special case it here + if (dtype.is_bool()) { + return ctx->bool_const(name.c_str()); + } else { + z3::expr e = ctx->int_const(name.c_str()); + if (dtype.is_uint() && dtype.bits() == 64) { + solver.add(ctx->int_val(0) <= e && e <= ctx->int_val((uint64_t)UINT64_MAX)); + } else { + auto min_val = Downcast(min_value(dtype))->value; + auto max_val = Downcast(max_value(dtype))->value; + solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val)); + } + return e; + } + } + + struct Scope { + enum Kind { + BindValue, + BindRange, + Constraint, + } kind; + Var var; + PrimExpr value; + PrimExpr min; + PrimExpr extent; + PrimExpr constraint; + }; + + ///
@brief scope_stack memorizes existing constraint and bindings + /// to generate SMTLIB2 representation with comments + std::vector> scope_stack_; + + /// @brief Enter a constraint scope and return the function that exits it + std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false) { + scope_stack_.push_back({}); + scope_stack_.back().push_back( + Scope{Scope::Constraint, Var(), PrimExpr(), PrimExpr(), PrimExpr(), constraint}); + solver.push(); + this->is_assume = is_assume; + solver.add(VisitBool(constraint)); + this->is_assume = false; + auto side_effect_exprs = std::move(side_effect_exprs_); + side_effect_exprs_.clear(); + if (is_assume) { + return [this, side_effect_exprs]() { + solver.pop(); + for (const auto& expr : side_effect_exprs) { + memo_.erase(expr); + } + scope_stack_.pop_back(); + }; + } else { + for (const auto& expr : side_effect_exprs) { + memo_.erase(expr); + } + return [this]() { + solver.pop(); + scope_stack_.pop_back(); + }; + } + } + + /// @brief Check trivial bad cases, return true if the expr is a bad case + /// Z3 prover may take a long time to initialize (at least 200us), + /// This optimization can speed up 30% of the test cases in our unit tests + bool CheckTrivilBadCases(const PrimExpr& expr) { + if (IsFreeNode(expr)) { + return true; + } + auto checkTrivilCmp = [this](const PrimExpr& lhs, const PrimExpr& rhs) { + if (IsFreeNode(lhs) && rhs->IsInstance()) { + return true; + } + if (IsFreeNode(rhs) && lhs->IsInstance()) { + return true; + } + if (IsFreeNode(lhs) && IsFreeNode(rhs)) { + return true; + } + // cast('xxx', free_var) == constant + if (auto cast = lhs.as()) { + if (IsFreeNode(cast->value) && rhs->IsInstance()) { + return true; + } + } + // constant == cast('xxx', free_var) + if (auto cast = rhs.as()) { + if (IsFreeNode(cast->value) && lhs->IsInstance()) { + return true; + } + } + return false; + }; + if (auto eq = expr.as()) { + auto lhs = eq->a; + auto rhs = eq->b; + return checkTrivilCmp(lhs, rhs); + } else if (auto ne = expr.as()) { +
auto lhs = ne->a; + auto rhs = ne->b; + return checkTrivilCmp(lhs, rhs); + } + return false; + } + + /// @brief Check if the expression can be proved, i.e. its negation is unsat under the current solver state + bool CanProve(const PrimExpr& expr) { + if (CheckTrivilBadCases(expr)) return false; + if (!IsValidDType(expr->dtype)) return false; + z3::expr_vector constr(*ctx); + constr.push_back(!ConvertBool(expr)); + auto result = solver.check(constr); + constr.pop_back(); + return result == z3::unsat; + } + + /// @brief Bind a variable to a value + /// NOTE(review): allow_override is not consulted here and memo_.emplace keeps any existing entry; confirm this is intended + void Bind(const Var& var, const PrimExpr& value, bool allow_override = false) { + if (!IsValidDType(var->dtype)) return; + scope_stack_.back().push_back(Scope{Scope::BindValue, var, value}); + // the binding is added unconditionally (there is no purity check here), + // because non-pure parts are handled by creating free variables in VisitExpr + memo_.emplace(var, ConvertInt(value)); + } + + /// @brief Bind a variable to a range + void Bind(const Var& var, const Range& range, bool allow_override = false) { + if (!IsValidDType(var->dtype)) return; + scope_stack_.back().push_back( + Scope{Scope::BindRange, var, PrimExpr(), range->min, range->extent}); + // 1. Create a placeholder for the var, and save it in the memo + // if the var is overridden later, we can just update the memo, and the old placeholder will + // be ignored + auto var_expr = Create(var.as()); + memo_.emplace(var, var_expr); + + // 2.
Add constraint on the placeholder + // when min_expr >= max_expr, the range is empty, which is undefined behavior + // instead of adding an unsat constraint, we just skip the range constraint to leave it a + // free var + if (tirx::is_const_int(range->min) && tirx::is_const_int(range->min + range->extent)) { + int64_t min_value = *tirx::as_const_int(range->min); + int64_t max_value = *tirx::as_const_int(range->min + range->extent); + if (min_value < max_value) { + solver.add(ctx->int_val(min_value) <= var_expr); + solver.add(var_expr < ctx->int_val(max_value)); + } + } else { + solver.add(ConvertBool(range->extent <= 0 || + (range->min <= var && var < range->min + range->extent))); + } + } + + void CopyFrom(const Self& other_) { + // 1. create a new solver + // because this->solver depends on this->ctx + // we need to destroy the old solver, and create a new one depending on other_.ctx + solver = CreateSolver(*other_.ctx); + // 2. copy the context + // the context is a shared_ptr, we can just copy the pointer + ctx = other_.ctx; + // 3. copy other objects + ns = other_.ns; + for (auto& item : other_.memo_) { + memo_.emplace(item.first, item.second); + } + for (auto a : other_.solver.assertions()) { + solver.add(a); + } + // 4. copy timeout options + // but other solver options are not copied + SetTimeoutMs(other_.timeout_ms); + SetRLimit(other_.rlimit); + // 5.
copy the scope stack, which containing comments for SMTLIB2 generation + scope_stack_ = other_.scope_stack_; + } + + /// @brief Set timeout in milliseconds + void SetTimeoutMs(unsigned timeout_ms) { + this->timeout_ms = timeout_ms; + solver.set("timeout", timeout_ms); + } + + /// @brief Set the Z3 resource limit (rlimit) + void SetRLimit(unsigned rlimit) { + this->rlimit = rlimit; + solver.set("rlimit", rlimit); + } + + /// @brief Get the SMTLIB2 representation of the current solver state + ffi::String GetSMTLIB2() { + std::stringstream ss; + ss << "(set-option :timeout " << timeout_ms << ")\n"; + AddScopeDebugMsg(ss); + ss << solver.to_smt2(); + return ss.str(); + } + + void AddScopeDebugMsg(std::ostream& ss) { + for (const auto& scope : scope_stack_) { + ss << "; Entering Scope\n"; + for (const auto& s : scope) { + switch (s.kind) { + case Scope::Constraint: + ss << "; constraint: " << s.constraint << "\n"; + break; + case Scope::BindValue: + ss << "; bind value: " << s.var << " = " << s.value << "\n"; + break; + case Scope::BindRange: + ss << "; bind range: " << s.var << " in [" << s.min << ", " << s.min + s.extent + << ")\n"; + break; + } + } + } + } + + /// @brief Get the SMTLIB2 representation of the current solver state plus the negation of the expr + /// we are trying to prove (added in a temporary push/pop frame) + ffi::String GetSMTLIB2(const PrimExpr& expr) { + std::stringstream ss; + ss << "(set-option :timeout " << timeout_ms << ")\n"; + AddScopeDebugMsg(ss); + ss << "; Trying to prove: " << expr << "\n"; + solver.push(); + solver.add(!ConvertBool(expr)); + ss << solver.to_smt2(); + solver.pop(); + return ss.str(); + } + + /// @brief Get the statistics of the solver + ffi::String GetStats() { + std::stringstream ss; + ss << solver.statistics(); + return ss.str(); + } + + ffi::String GetModel(const PrimExpr& expr) { + solver.set("model", true); + solver.push(); + solver.add(!ConvertBool(expr)); + auto result = solver.check(); + ffi::String model_str; + if (result == z3::sat) { + z3::model m = solver.get_model(); + std::map
model_map; + for (unsigned i = 0; i < m.size(); i++) { + z3::func_decl d = m[i]; + model_map.emplace(d.name().str(), m.get_const_interp(d)); + } + std::stringstream ss; + for (const auto& [k, v] : model_map) { + ss << " " << k << " = " << v << "\n"; + } + model_str = ss.str(); + } + solver.pop(); + solver.set("model", false); + return model_str; + } + + /*! + * \brief Count the number of distinct integer values satisfying current constraints. + * + * Uses Z3's model enumeration (AllSAT pattern) to count solutions: + * 1. Find a satisfying assignment + * 2. Add a blocking clause to exclude it + * 3. Repeat until UNSAT + * + * \param var The variable to count values for + * \param max_count Safety limit on enumeration + * \param min_consecutive Minimum consecutive count requirement (0 to disable) + * \return Number of satisfying values, -1 on error, -2 if min_consecutive constraint not met + */ + int64_t CountSatisfyingValues(const Var& var, int64_t max_count, int64_t min_consecutive = 1) { + if (!IsValidDType(var->dtype)) { + return -1; + } + + solver.set("model", true); + solver.push(); + + // Convert the TVM variable to Z3 expression + z3::expr z3_var = VisitInt(var); + + int64_t count = 0; + std::vector found_values; + + while (count < max_count) { + auto result = solver.check(); + if (result != z3::sat) { + break; // No more solutions + } + + z3::model m = solver.get_model(); + z3::expr val_expr = m.eval(z3_var, true); + + // Extract the integer value from Z3 expression + int64_t val; + if (val_expr.is_numeral()) { + val = val_expr.get_numeral_int64(); + } else { + // If we can't get a concrete value, stop enumeration + break; + } + + found_values.push_back(val); + count++; + + // Add blocking clause: var != val (exclude this solution) + solver.add(z3_var != ctx->int_val(val)); + } + + solver.pop(); + solver.set("model", false); + + // Clear any side effects from visiting the variable + for (const auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + 
side_effect_exprs_.clear(); + + // Check minimum consecutive constraint if enabled + if (min_consecutive > 0 && count > 0) { + // Sort the values to check consecutive groups + std::sort(found_values.begin(), found_values.end()); + + // Check that all values form groups of at least min_consecutive consecutive numbers + int64_t consecutive_count = 1; + for (size_t i = 1; i < found_values.size(); i++) { + if (found_values[i] == found_values[i - 1] + 1) { + // Consecutive value + consecutive_count++; + } else { + // Gap found, check if the previous group meets the minimum + if (consecutive_count < min_consecutive) { + return -2; // Previous group too small + } + consecutive_count = 1; // Start new group + } + } + // Check the last group + if (consecutive_count < min_consecutive) { + return -2; // Last group too small + } + } + + return count; + } + + private: + using Z3BinOp = z3::expr (*)(const z3::expr&, const z3::expr&); + + std::vector side_effect_exprs_; + + z3::expr ConvertBool(const PrimExpr& e, bool is_assume = false) { + this->is_assume = is_assume; + auto res = VisitBool(e); + for (auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + side_effect_exprs_.clear(); + this->is_assume = false; + return res; + } + + z3::expr ConvertInt(const PrimExpr& e, bool is_assume = false) { + this->is_assume = is_assume; + auto res = VisitInt(e); + for (auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + side_effect_exprs_.clear(); + this->is_assume = false; + return res; + } + + /// @brief Visit expression with memoization + z3::expr VisitExpr(const PrimExpr& e) override { + if (memo_.count(e)) { + return memo_.at(e); + } + auto res = Base::VisitExpr(e); + auto side_effect = SideEffect(e); + if (side_effect <= CallEffectKind::kPure) { + memo_.emplace(e, res); + } else if (side_effect <= CallEffectKind::kReadState) { + memo_.emplace(e, res); + side_effect_exprs_.emplace_back(e); + } else { + if (is_assume) { + memo_.emplace(e, res); + } + 
side_effect_exprs_.emplace_back(e); + } + return res; + } + + /// @brief Check if the expression is a free node having no constraints + bool IsFreeNode(const PrimExpr& e) { + if (memo_.count(e)) { + return false; + } + return e->IsInstance() || e->IsInstance() || + e->IsInstance() || e->IsInstance() || + (e->IsInstance() && !IsValidDType(Downcast(e)->value->dtype)); + } + + /// @brief Check if the dtype is valid for z3 integer operations + static bool IsValidDType(const DataType& dtype) { + return (dtype.is_int() || dtype.is_uint() || dtype.is_bool()) && dtype.lanes() == 1; + } + + /// @brief Visit the expression and convert it into z3 integer expression + z3::expr VisitInt(const PrimExpr& expr) { + auto e = VisitExpr(expr); + if (e.is_bool()) { + return z3::ite(e, ctx->int_val(1), ctx->int_val(0)); + } else { + return e; + } + } + + /// @brief Visit the expression and convert it into z3 boolean expression + z3::expr VisitBool(const PrimExpr& e) { + auto expr = VisitExpr(e); + if (expr.is_bool()) { + return expr; + } else { + return expr != ctx->int_val(0); + } + } + + /// @brief Helper function to visit binary arithmetic operations + z3::expr VisitArith(Z3BinOp signed_op, const PrimExprNode* op, const PrimExpr& a, + const PrimExpr& b) { + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + return signed_op(VisitInt(a), VisitInt(b)); + } else { + return Create(op); + } + } + + z3::expr VisitExpr_(const LetNode* op) override { + if (IsValidDType(op->var->dtype)) { + memo_.emplace(op->var, VisitInt(op->value)); + } + return VisitExpr(op->body); + } + z3::expr VisitExpr_(const CastNode* op) override { + // if the inner dtype is valid, we just visit it + if (IsValidDType(op->value->dtype) && IsValidDType(op->dtype)) { + return VisitInt(op->value); + } else { + // otherwise, we create a new free z3 variable + return Create(op); + } + } + z3::expr VisitExpr_(const VarNode* op) override { return Create(op); } + z3::expr VisitExpr_(const BufferLoadNode* op) override 
{ return Create(op); } + z3::expr VisitExpr_(const ProducerLoadNode* op) override { return Create(op); } + z3::expr VisitExpr_(const ReduceNode* op) override { return Create(op); } + z3::expr VisitExpr_(const MinNode* op) override { + auto a = VisitInt(op->a); + auto b = VisitInt(op->b); + return z3::ite(a < b, a, b); + } + z3::expr VisitExpr_(const MaxNode* op) override { + auto a = VisitInt(op->a); + auto b = VisitInt(op->b); + return z3::ite(a > b, a, b); + } + static z3::expr floordiv(const z3::expr& a, const z3::expr& b) { + return z3::ite(b > 0, a / b, -((-a) / b)); + } + static z3::expr floormod(const z3::expr& a, const z3::expr& b) { + return z3::ite(b > 0, a % b, -((-a) % b)); + } + z3::expr VisitExpr_(const AddNode* op) override { + return VisitArith(z3::operator+, op, op->a, op->b); + } + z3::expr VisitExpr_(const SubNode* op) override { + return VisitArith(z3::operator-, op, op->a, op->b); + } + z3::expr VisitExpr_(const MulNode* op) override { + return VisitArith(z3::operator*, op, op->a, op->b); + } + z3::expr VisitExpr_(const DivNode* op) override { + return VisitArith(z3::operator/, op, op->a, op->b); + } + z3::expr VisitExpr_(const ModNode* op) override { + return VisitArith(z3::operator%, op, op->a, op->b); + } + z3::expr VisitExpr_(const FloorDivNode* op) override { + return VisitArith(floordiv, op, op->a, op->b); + } + z3::expr VisitExpr_(const FloorModNode* op) override { + return VisitArith(floormod, op, op->a, op->b); + } + z3::expr VisitExpr_(const EQNode* op) override { + return VisitArith(z3::operator==, op, op->a, op->b); + } + z3::expr VisitExpr_(const NENode* op) override { + return VisitArith(z3::operator!=, op, op->a, op->b); + } + z3::expr VisitExpr_(const LTNode* op) override { + return VisitArith(z3::operator<, op, op->a, op->b); + } + z3::expr VisitExpr_(const LENode* op) override { + return VisitArith(z3::operator<=, op, op->a, op->b); + } + z3::expr VisitExpr_(const GTNode* op) override { + return VisitArith(z3::operator>, op, 
op->a, op->b); + } + z3::expr VisitExpr_(const GENode* op) override { + return VisitArith(z3::operator>=, op, op->a, op->b); + } + z3::expr VisitExpr_(const AndNode* op) override { return VisitBool(op->a) && VisitBool(op->b); } + z3::expr VisitExpr_(const OrNode* op) override { return VisitBool(op->a) || VisitBool(op->b); } + z3::expr VisitExpr_(const NotNode* op) override { return !VisitBool(op->a); } + z3::expr VisitExpr_(const SelectNode* op) override { + return z3::ite(VisitBool(op->condition), VisitInt(op->true_value), VisitInt(op->false_value)); + } + z3::expr VisitExpr_(const IntImmNode* op) override { return ctx->int_val(op->value); } + + // Bitwise operations + z3::expr VisitExpr_(const CallNode* op) override { + // Check if this is a bitwise operation + if (op->op.same_as(tirx::builtin::bitwise_and())) { + return VisitBitwiseOp(z3::operator&, op); + } else if (op->op.same_as(tirx::builtin::bitwise_or())) { + return VisitBitwiseOp(z3::operator|, op); + } else if (op->op.same_as(tirx::builtin::bitwise_xor())) { + return VisitBitwiseOp(z3::operator^, op); + } else if (op->op.same_as(tirx::builtin::bitwise_not())) { + return VisitBitwiseNotOp(op); + } else if (op->op.same_as(tirx::builtin::shift_left())) { + return VisitShiftOp(z3::shl, op); + } else if (op->op.same_as(tirx::builtin::shift_right())) { + return VisitShiftOp(z3::ashr, op); + } else { + // For other call nodes, create a free variable + return Create(op); + } + } + + /// @brief Helper function to visit binary bitwise operations + z3::expr VisitBitwiseOp(z3::expr (*op_func)(const z3::expr&, const z3::expr&), + const CallNode* op) { + if (op->args.size() != 2) { + LOG(FATAL) << "Binary bitwise operation expects 2 arguments, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr& a = op->args[0]; + const PrimExpr& b = op->args[1]; + unsigned bit_width = std::max(op->args[0].dtype().bits(), op->args[1].dtype().bits()); + + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + 
return z3::bv2int( + op_func(z3::int2bv(bit_width, VisitInt(a)), z3::int2bv(bit_width, VisitInt(b))), true); + } else { + return Create(op); + } + } + + /// @brief Helper function to visit unary bitwise not operation + z3::expr VisitBitwiseNotOp(const CallNode* op) { + if (op->args.size() != 1) { + LOG(FATAL) << "Bitwise not operation expects 1 argument, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr& a = op->args[0]; + + if (IsValidDType(a->dtype)) { + // Cast integer to bit-vector, apply bitwise not, then cast back. + unsigned bit_width = a.dtype().bits(); + z3::expr a_int = VisitInt(a); + z3::expr a_bv = z3::int2bv(bit_width, a_int); + return z3::bv2int(~a_bv, true); + } else { + return Create(op); + } + } + + /// @brief Helper function to visit shift operations + z3::expr VisitShiftOp(z3::expr (*op_func)(const z3::expr&, const z3::expr&), const CallNode* op) { + if (op->args.size() != 2) { + LOG(FATAL) << "Shift operation expects 2 arguments, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr& a = op->args[0]; + const PrimExpr& b = op->args[1]; + + // Shift operations require integer types for both operands + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + // For shift operations, we need to ensure the shift amount is non-negative + // and within reasonable bounds + z3::expr a_expr = VisitInt(a); + z3::expr b_expr = VisitInt(b); + + // Add constraint that shift amount should be non-negative + // This is a common assumption in many programming languages + solver.add(b_expr >= 0); + + // Also limit shift amount to avoid unrealistic large shifts + // We'll limit to 64 bits (reasonable for most use cases) + solver.add(b_expr < 64); + + unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits()); + z3::expr a_bv = z3::int2bv(bit_width, a_expr); + z3::expr b_bv = z3::int2bv(bit_width, b_expr); + + // Perform the shift in bit-vector domain, then cast back to int. 
+ z3::expr result_bv = op_func(a_bv, b_bv); + return z3::bv2int(result_bv, true); + } else { + return Create(op); + } + } + + z3::expr VisitExprDefault_(const Object* op) override { + LOG(FATAL) << "Z3Prover only support integers, but got " << op->GetTypeKey() << "."; + TVM_FFI_UNREACHABLE(); + } +}; + +TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return impl_->CanProve(expr); } +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) { + return impl_->Bind(var, new_range, allow_override); +} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + return impl_->Bind(var, expr, allow_override); +} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { + return impl_->EnterConstraint(constraint, is_assume); +} +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + if (expr.has_value()) { + return impl_->GetSMTLIB2(expr.value()); + } else { + return impl_->GetSMTLIB2(); + } +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) { impl_->SetTimeoutMs(timeout_ms); } +void Z3Prover::SetRLimit(unsigned max_step) { impl_->SetRLimit(max_step); } +void Z3Prover::CopyFrom(const Z3Prover& other) { impl_->CopyFrom(*other.impl_); } +ffi::String Z3Prover::GetStats() { return impl_->GetStats(); } +ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return impl_->GetModel(expr); } +TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, + int64_t min_consecutive) { + return impl_->CountSatisfyingValues(var, max_count, min_consecutive); +} +Z3Prover::Z3Prover(AnalyzerObj* parent) : impl_(new Impl{parent}) {} +TVM_DLL Z3Prover::~Z3Prover() { delete impl_; } + +} // namespace tvm::arith diff --git a/tests/python/arith/test_arith_z3.py b/tests/python/arith/test_arith_z3.py new file mode 100644 index 000000000000..c638341f4cd1 --- /dev/null +++ b/tests/python/arith/test_arith_z3.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software 
Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +import tvm.testing +from tvm import tirx + + +def _require_z3(analyzer): + if "Z3 Prover is disabled" in analyzer.get_smtlib2(): + pytest.skip("Z3 prover is disabled in this build") + + +def test_z3_disabled_api_is_available(): + analyzer = tvm.arith.Analyzer() + assert isinstance(analyzer.get_smtlib2(), str) + assert isinstance(analyzer.get_z3_stats(), str) + + +def test_z3_proves_floor_division_identity(): + analyzer = tvm.arith.Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + + with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)): + expr = ((b - a) // c) * c + a <= b + assert analyzer.can_prove(expr) + + +def test_z3_bind_range(): + analyzer = tvm.arith.Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + + analyzer.bind(a, tvm.ir.Range(1, 100000)) + analyzer.bind(b, tvm.ir.Range(1, 100000)) + analyzer.bind(c, tvm.ir.Range(1, 100000)) + + expr = ((b - a) // c) * c + a <= b + assert analyzer.can_prove(expr) + + +def test_z3_smtlib2_roundtrip(): + z3 = pytest.importorskip("z3") + analyzer = tvm.arith.Analyzer() + 
_require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + expr = ((b - a) // c) * c + a <= b + + solver = z3.Solver() + with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)): + solver.from_string(analyzer.get_smtlib2(expr)) + assert solver.check() == z3.unsat + + +def test_z3_bitwise(): + analyzer = tvm.arith.Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + + assert analyzer.can_prove(tirx.bitwise_and(x, tirx.IntImm("int32", 7)) < 8) + + +if __name__ == "__main__": + tvm.testing.main() From 644c8c73b1c27d95a9e28b8c750d9ba85ff4448b Mon Sep 17 00:00:00 2001 From: Ubospica Date: Wed, 10 Jun 2026 16:48:29 -0400 Subject: [PATCH 06/21] update Signed-off-by: Ubospica --- 3rdparty/cnpy | 1 + 3rdparty/dmlc-core | 1 + 3rdparty/rang | 1 + 3rdparty/tvm-ffi | 2 +- 3rdparty/zlib | 1 + cmake/modules/contrib/Z3.cmake | 8 +- include/tvm/arith/analyzer.h | 12 +- python/tvm/_version.py | 24 + python/tvm/arith/analyzer.py | 51 +- src/arith/analyzer.cc | 15 +- src/arith/rewrite_simplify.cc | 10 +- src/arith/rewrite_simplify.h | 2 +- .../z3/z3_prover_on.cc => arith/z3_prover.cc} | 215 ++++-- src/target/z3/z3_prover_off.cc | 38 - tests/python/arith/test_arith_z3.py | 667 +++++++++++++++++- 15 files changed, 901 insertions(+), 147 deletions(-) create mode 160000 3rdparty/cnpy create mode 160000 3rdparty/dmlc-core create mode 160000 3rdparty/rang create mode 160000 3rdparty/zlib create mode 100644 python/tvm/_version.py rename src/{target/z3/z3_prover_on.cc => arith/z3_prover.cc} (78%) delete mode 100644 src/target/z3/z3_prover_off.cc diff --git a/3rdparty/cnpy b/3rdparty/cnpy new file mode 160000 index 000000000000..4e8810b1a863 --- /dev/null +++ b/3rdparty/cnpy @@ -0,0 +1 @@ +Subproject commit 4e8810b1a8637695171ed346ce68f6984e585ef4 diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core new file mode 160000 index 000000000000..3031e4a61a98 --- /dev/null +++ 
b/3rdparty/dmlc-core @@ -0,0 +1 @@ +Subproject commit 3031e4a61a98f49f07a42cfdec6242340fb2fd8c diff --git a/3rdparty/rang b/3rdparty/rang new file mode 160000 index 000000000000..cabe04d6d6b0 --- /dev/null +++ b/3rdparty/rang @@ -0,0 +1 @@ +Subproject commit cabe04d6d6b05356fa8f9741704924788f0dd762 diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index 59da4c0b82af..98d0029dd4e0 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit 59da4c0b82af0d499dae34bd89ef010f64d3ff45 +Subproject commit 98d0029dd4e002da1516d43f9b92e792f139e709 diff --git a/3rdparty/zlib b/3rdparty/zlib new file mode 160000 index 000000000000..ef24c4c75021 --- /dev/null +++ b/3rdparty/zlib @@ -0,0 +1 @@ +Subproject commit ef24c4c7502169f016dcd2a26923dbaf3216748c diff --git a/cmake/modules/contrib/Z3.cmake b/cmake/modules/contrib/Z3.cmake index eef62e4cfcd8..7d0b746aa36f 100644 --- a/cmake/modules/contrib/Z3.cmake +++ b/cmake/modules/contrib/Z3.cmake @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. +# src/arith/z3_prover.cc is always part of COMPILER_SRCS (picked up by the +# src/arith/*.cc glob). It compiles a conservative stub by default and switches +# to the real Z3 implementation only when the TVM_USE_Z3 macro is defined below. if(NOT USE_Z3) - list(APPEND COMPILER_SRCS src/target/z3/z3_prover_off.cc) return() endif() @@ -73,4 +75,6 @@ else() message(FATAL_ERROR "USE_Z3 is ON, but Z3 was not found. Install Z3 or PyPI z3-solver.") endif() -list(APPEND COMPILER_SRCS src/target/z3/z3_prover_on.cc) +# Enable the real Z3 implementation inside the single src/arith/z3_prover.cc file. 
+add_compile_definitions(TVM_USE_Z3) +message(STATUS "Build with Z3 SMT solver support") diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 7de9213f8114..cbe9051e3b73 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -300,7 +300,7 @@ class RewriteSimplifier { * * \return an exit function that must be called to cleanup the constraint can be nullptr. */ - TVM_DLL std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); + TVM_DLL std::function EnterConstraint(const PrimExpr& constraint); /*! \brief Flags to enable more computationally-intensive simplifications * @@ -609,6 +609,13 @@ class Z3Prover { */ TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false); + /*! + * \brief Whether the Z3 backend is compiled into this build (USE_Z3=ON). + * + * \return true if the real Z3 prover is available, false for the stub. + */ + TVM_DLL bool IsEnabled() const; + /*! * \brief Whether can we prove expr is always true. * @@ -621,10 +628,9 @@ class Z3Prover { * \brief Update the internal state to enter constraint. * * \param constraint A constraint expression. - * \param is_assume Whether the constraint comes from an assumption. * \return an exit function that must be called to cleanup the constraint can be nullptr. */ - std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); + std::function EnterConstraint(const PrimExpr& constraint); /*! * \brief Get the SMTLIB2 representation of the current context. 
diff --git a/python/tvm/_version.py b/python/tvm/_version.py new file mode 100644 index 000000000000..618c82e5a4ed --- /dev/null +++ b/python/tvm/_version.py @@ -0,0 +1,24 @@ +# file generated by vcs-versioning +# don't change, don't track in version control +from __future__ import annotations + +__all__ = [ + "__version__", + "__version_tuple__", + "version", + "version_tuple", + "__commit_id__", + "commit_id", +] + +version: str +__version__: str +__version_tuple__: tuple[int | str, ...] +version_tuple: tuple[int | str, ...] +commit_id: str | None +__commit_id__: str | None + +__version__ = version = '0.25.dev100' +__version_tuple__ = version_tuple = (0, 25, 'dev100') + +__commit_id__ = commit_id = 'g35152f312' diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 39fa30d64180..78e93395c382 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -128,44 +128,89 @@ class Analyzer(Object): def __init__(self): self.__init_handle_by_constructor__(_ffi_api.Analyzer) - def get_smtlib2(self, expr: tirx.PrimExpr = None) -> str: + @property + def is_z3_enabled(self) -> bool: + """Whether this build includes the Z3 backend (``USE_Z3=ON``). + + The Z3-specific methods (:py:meth:`get_smtlib2`, :py:meth:`get_z3_stats`, + :py:meth:`set_z3_timeout_ms`, :py:meth:`set_z3_rlimit`) only work when + this is ``True``. + """ + return bool(_ffi_api.AnalyzerIsZ3Enabled(self)) + + def _check_z3_enabled(self) -> None: + if not self.is_z3_enabled: + raise RuntimeError( + "The Z3 backend is not available in this build. " + "Rebuild TVM with USE_Z3=ON to use Z3-specific Analyzer APIs." + ) + + def get_smtlib2(self, expr: tirx.PrimExpr | None = None) -> str: """Get the current Z3 problem in SMT-LIB2 format. + Raises + ------ + RuntimeError + If TVM was built without Z3 (``USE_Z3=OFF``), since there is no + solver state to export. Use :py:attr:`is_z3_enabled` to check first. 
+ Parameters ---------- expr : Optional[PrimExpr] The expression to prove. If provided, its negation is added to the problem. """ + self._check_z3_enabled() return _ffi_api.AnalyzerGetSMTLIB2(self, expr) def set_z3_timeout_ms(self, timeout_ms: int) -> None: """Set Z3 timeout in milliseconds. + Raises + ------ + RuntimeError + If TVM was built without Z3 (``USE_Z3=OFF``). + Parameters ---------- timeout_ms : int The timeout in milliseconds. """ + self._check_z3_enabled() _ffi_api.AnalyzerSetZ3TimeoutMs(self, timeout_ms) def set_z3_rlimit(self, rlimit: int) -> None: """Set Z3 resource limit. + The resource limit gives deterministic solver budgeting (unlike a wall + clock timeout). A value of ``0`` disables the limit. + + Raises + ------ + RuntimeError + If TVM was built without Z3 (``USE_Z3=OFF``). + Parameters ---------- rlimit : int The resource limit. """ + self._check_z3_enabled() _ffi_api.AnalyzerSetZ3RLimit(self, rlimit) def get_z3_stats(self) -> str: """Get Z3 solver statistics. + Raises + ------ + RuntimeError + If TVM was built without Z3 (``USE_Z3=OFF``). + Returns ------- stats : str The Z3 statistics. """ + self._check_z3_enabled() return _ffi_api.AnalyzerGetZ3Stats(self) def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound: @@ -300,7 +345,9 @@ def can_prove( The expression. strength: ProofStrength - The proof strength + The proof strength. When TVM is built with Z3 (``USE_Z3=ON``), the + optional Z3 fallback is only consulted at ``SYMBOLIC_BOUND`` or + higher, after the native analyzers fail to prove the predicate. Returns ------- diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index cb209b0a9dfb..8d2807dfcbb9 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -131,11 +131,10 @@ void ConstraintContext::EnterWithScope() { // entering the scope. 
recovery_functions_.push_back(analyzer_->const_int_bound.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->modular_set.EnterConstraint(constraint_)); - recovery_functions_.push_back( - analyzer_->rewrite_simplify.EnterConstraint(constraint_, is_assume_)); + recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_)); - recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_, is_assume_)); + recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_)); } void ConstraintContext::ExitWithScope() { @@ -236,9 +235,11 @@ bool AnalyzerObj::CanProve(const PrimExpr& expr, ProofStrength strength) { } } - if (z3_prover.CanProve(simplified)) { - return true; - } + // Z3 is an expensive best-effort fallback. Gate it behind the higher + // kSymbolicBound strength so the common kDefault path (including deeply + // recursive internal CanProve calls) never pays the prover cost. 
+ if (strength >= ProofStrength::kSymbolicBound && z3_prover.CanProve(simplified)) { + return true; } return false; } @@ -343,6 +344,8 @@ TVM_FFI_STATIC_INIT_BLOCK() { return static_cast( analyzer->transitive_comparisons.TryCompare(lhs, rhs, propagate_inequalities)); }) + .def("arith.AnalyzerIsZ3Enabled", + [](Analyzer analyzer) { return analyzer->z3_prover.IsEnabled(); }) .def("arith.AnalyzerGetSMTLIB2", [](Analyzer analyzer, ffi::Optional expr) { return analyzer->z3_prover.GetSMTLIB2(expr); diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index d0cffbf7950b..2120aaa1a859 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -526,14 +526,13 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode* op) { return ret; } -std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint, - bool is_assume) { +std::function RewriteSimplifier::Impl::EnterConstraint(const PrimExpr& constraint) { size_t old_literal_size = literal_constraints_.size(); // we will compare the already simplified result with the constraint, // so simplify the constraint as well PrimExpr new_constraint = operator()(constraint); for (const PrimExpr& subconstraint : ExtractConstraints(new_constraint, false)) { - if (is_assume || SideEffect(subconstraint) <= CallEffectKind::kPure) { + if (SideEffect(subconstraint) <= CallEffectKind::kPure) { literal_constraints_.push_back(subconstraint); PrimExpr negation; if (subconstraint.dtype().is_bool()) { @@ -2441,9 +2440,8 @@ void RewriteSimplifier::Update(const Var& var, const PrimExpr& info, bool allow_ impl_->Update(var, info, allow_override); } -std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint, - bool is_assume) { - return impl_->EnterConstraint(constraint, is_assume); +std::function RewriteSimplifier::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); } void RewriteSimplifier::SetEnabledExtensions(Extension flags) 
{ diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h index 026deee72bec..b42b73336a27 100644 --- a/src/arith/rewrite_simplify.h +++ b/src/arith/rewrite_simplify.h @@ -117,7 +117,7 @@ class RewriteSimplifier::Impl : public IRMutatorWithAnalyzer { PrimExpr VisitExpr_(const CastNode* op) override; PrimExpr VisitExpr_(const LetNode* op) override; - std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false); + std::function EnterConstraint(const PrimExpr& constraint); /*! \brief Enable an optional extension or extensions * diff --git a/src/target/z3/z3_prover_on.cc b/src/arith/z3_prover.cc similarity index 78% rename from src/target/z3/z3_prover_on.cc rename to src/arith/z3_prover.cc index 08b1c359605a..1a368c83904b 100644 --- a/src/target/z3/z3_prover_on.cc +++ b/src/arith/z3_prover.cc @@ -1,3 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/arith/z3_prover.cc + * \brief Optional Z3 SMT solver backend for arith::Analyzer. + * + * The real implementation is compiled only when TVM_USE_Z3 is defined (set by + * the USE_Z3 CMake option). 
Otherwise a conservative stub is compiled so the + * C++ and Python APIs stay available without a Z3 dependency. + */ +#ifdef TVM_USE_Z3 + #include #include #include @@ -20,7 +49,6 @@ #include #include -#include "tvm/arith/analyzer.h" #include "tvm/ffi/cast.h" #include "tvm/ffi/object.h" #include "tvm/ffi/string.h" @@ -87,8 +115,6 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Memorize pure expressions std::unordered_map memo_; - bool is_assume = false; - /// @brief Namespace for variable naming Namespace ns; @@ -111,11 +137,8 @@ class Z3Prover::Impl : ExprFunctor { Impl(AnalyzerObj* parent) : analyzer(parent) { scope_stack_.push_back({}); solver = CreateSolver(*ctx); - // default timeout 5ms - // Z3's implementation of timeout, when setting timeout T ms, it will stop at T - 1 ms - // SetTimeoutMs(5); - // use rlimit, not timeout to ensure determinstic behavior - SetRLimit(1e4); + // use rlimit, not timeout to ensure deterministic behavior + SetRLimit(10000U); } /// @brief Create a Free z3 expression from PrimExprNode @@ -157,33 +180,21 @@ class Z3Prover::Impl : ExprFunctor { std::vector> scope_stack_; /// @brief Enter a constraint scope - std::function EnterConstraint(const PrimExpr& constraint, bool is_assume = false) { + std::function EnterConstraint(const PrimExpr& constraint) { scope_stack_.push_back({}); scope_stack_.back().push_back( Scope{Scope::Constraint, Var(), PrimExpr(), PrimExpr(), PrimExpr(), constraint}); solver.push(); - this->is_assume = is_assume; solver.add(VisitBool(constraint)); - this->is_assume = false; auto side_effect_exprs = std::move(side_effect_exprs_); side_effect_exprs_.clear(); - if (is_assume) { - return [this, side_effect_exprs]() { - solver.pop(); - for (const auto& expr : side_effect_exprs) { - memo_.erase(expr); - } - scope_stack_.pop_back(); - }; - } else { - for (const auto& expr : side_effect_exprs) { - memo_.erase(expr); - } - return [this]() { - solver.pop(); - scope_stack_.pop_back(); - }; + for (const auto& expr : 
side_effect_exprs) { + memo_.erase(expr); } + return [this]() { + solver.pop(); + scope_stack_.pop_back(); + }; } /// @brief Check trivil bad cases, return true if the expr is a bad case @@ -231,13 +242,19 @@ class Z3Prover::Impl : ExprFunctor { /// @brief Check if the expression can be proved bool CanProve(const PrimExpr& expr) { - if (CheckTrivilBadCases(expr)) return false; - if (!IsValidDType(expr->dtype)) return false; - z3::expr_vector constr(*ctx); - constr.push_back(!ConvertBool(expr)); - auto result = solver.check(constr); - constr.pop_back(); - return result == z3::unsat; + // Z3 is only a fallback. Any failure (including z3::exception thrown by the + // solver) must degrade to "cannot prove" instead of escaping to the caller. + try { + if (CheckTrivilBadCases(expr)) return false; + if (!IsValidDType(expr->dtype)) return false; + z3::expr_vector constr(*ctx); + constr.push_back(!ConvertBool(expr)); + auto result = solver.check(constr); + constr.pop_back(); + return result == z3::unsat; + } catch (const z3::exception&) { + return false; + } } /// @brief Binded @@ -265,9 +282,15 @@ class Z3Prover::Impl : ExprFunctor { // when min_expr >= max_expr, the range is empty, which is under undefined behavior // instead of adding an unsat constraint, we just skip the range constraint to leave it a // free var - if (tirx::is_const_int(range->min) && tirx::is_const_int(range->min + range->extent)) { + // + // NOTE: range->min + range->extent builds a fresh AddNode that is not folded, so we must + // test is_const_int on range->min and range->extent individually and add the two constants + // in C++. Otherwise this fast path is never taken and we always emit the more expensive + // symbolic constraint below. 
+ if (tirx::is_const_int(range->min) && tirx::is_const_int(range->extent)) { int64_t min_value = *tirx::as_const_int(range->min); - int64_t max_value = *tirx::as_const_int(range->min + range->extent); + int64_t extent_value = *tirx::as_const_int(range->extent); + int64_t max_value = min_value + extent_value; if (min_value < max_value) { solver.add(ctx->int_val(min_value) <= var_expr); solver.add(var_expr < ctx->int_val(max_value)); @@ -281,11 +304,12 @@ class Z3Prover::Impl : ExprFunctor { void CopyFrom(const Self& other_) { // 1. create a new solver // because this->solver depends on this->ctx - // we need to deconstruct the old solver, and create a new one depending on other_.ctx - solver = CreateSolver(*other_.ctx); - // 2. copy the context - // the context is a shared_ptr, we can just copy the pointer - ctx = other_.ctx; + // we need to deconstruct the old solver, and create a new one depending on this->ctx + solver = CreateSolver(*ctx); + // 2. ctx is a static thread_local pointer, so other_.ctx already refers to the same + // context on the current thread; there is nothing to copy here. Cross-thread copying + // of Z3Prover is not supported because Z3 expressions cannot be shared across different + // thread-local contexts without explicit translation. // 3. 
copy other objects ns = other_.ns; for (auto& item : other_.memo_) { @@ -482,25 +506,21 @@ class Z3Prover::Impl : ExprFunctor { std::vector side_effect_exprs_; - z3::expr ConvertBool(const PrimExpr& e, bool is_assume = false) { - this->is_assume = is_assume; + z3::expr ConvertBool(const PrimExpr& e) { auto res = VisitBool(e); for (auto& expr : side_effect_exprs_) { memo_.erase(expr); } side_effect_exprs_.clear(); - this->is_assume = false; return res; } - z3::expr ConvertInt(const PrimExpr& e, bool is_assume = false) { - this->is_assume = is_assume; + z3::expr ConvertInt(const PrimExpr& e) { auto res = VisitInt(e); for (auto& expr : side_effect_exprs_) { memo_.erase(expr); } side_effect_exprs_.clear(); - this->is_assume = false; return res; } @@ -517,9 +537,6 @@ class Z3Prover::Impl : ExprFunctor { memo_.emplace(e, res); side_effect_exprs_.emplace_back(e); } else { - if (is_assume) { - memo_.emplace(e, res); - } side_effect_exprs_.emplace_back(e); } return res; @@ -599,6 +616,20 @@ class Z3Prover::Impl : ExprFunctor { auto b = VisitInt(op->b); return z3::ite(a > b, a, b); } + // TVM Div/Mod are truncated (round toward zero), while Z3's native operator/ + // and operator% are Euclidean. Using the raw operators is unsound once the + // dividend can be negative, so we implement truncating helpers explicitly. + static z3::expr truncdiv(const z3::expr& a, const z3::expr& b) { + z3::expr abs_a = z3::ite(a >= 0, a, -a); + z3::expr abs_b = z3::ite(b >= 0, b, -b); + // |a| / |b| is exact (Euclidean == truncated for non-negative operands). + z3::expr q = abs_a / abs_b; + return z3::ite((a >= 0) == (b >= 0), q, -q); + } + static z3::expr truncmod(const z3::expr& a, const z3::expr& b) { + // TVM Mod follows the sign of the dividend: a - b * truncdiv(a, b). 
+ return a - b * truncdiv(a, b); + } static z3::expr floordiv(const z3::expr& a, const z3::expr& b) { return z3::ite(b > 0, a / b, -((-a) / b)); } @@ -614,12 +645,8 @@ class Z3Prover::Impl : ExprFunctor { z3::expr VisitExpr_(const MulNode* op) override { return VisitArith(z3::operator*, op, op->a, op->b); } - z3::expr VisitExpr_(const DivNode* op) override { - return VisitArith(z3::operator/, op, op->a, op->b); - } - z3::expr VisitExpr_(const ModNode* op) override { - return VisitArith(z3::operator%, op, op->a, op->b); - } + z3::expr VisitExpr_(const DivNode* op) override { return VisitArith(truncdiv, op, op->a, op->b); } + z3::expr VisitExpr_(const ModNode* op) override { return VisitArith(truncmod, op, op->a, op->b); } z3::expr VisitExpr_(const FloorDivNode* op) override { return VisitArith(floordiv, op, op->a, op->b); } @@ -667,6 +694,10 @@ class Z3Prover::Impl : ExprFunctor { return VisitShiftOp(z3::shl, op); } else if (op->op.same_as(tirx::builtin::shift_right())) { return VisitShiftOp(z3::ashr, op); + } else if (op->op.same_as(tirx::builtin::if_then_else()) && op->args.size() == 3 && + IsValidDType(op->args[1]->dtype) && IsValidDType(op->args[2]->dtype)) { + // tir.if_then_else(cond, a, b) is a select-like ternary. 
+ return z3::ite(VisitBool(op->args[0]), VisitInt(op->args[1]), VisitInt(op->args[2])); } else { // For other call nodes, create a free variable return Create(op); @@ -725,19 +756,14 @@ class Z3Prover::Impl : ExprFunctor { // Shift operations require integer types for both operands if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { - // For shift operations, we need to ensure the shift amount is non-negative - // and within reasonable bounds z3::expr a_expr = VisitInt(a); z3::expr b_expr = VisitInt(b); - // Add constraint that shift amount should be non-negative - // This is a common assumption in many programming languages - solver.add(b_expr >= 0); - - // Also limit shift amount to avoid unrealistic large shifts - // We'll limit to 64 bits (reasonable for most use cases) - solver.add(b_expr < 64); - + // Rely on Z3's native bit-vector shift behavior. We must NOT add hard + // assertions such as `b_expr >= 0` to the solver here: solver.add() has no + // matching push/pop in this path, so the assertion would permanently + // poison the shared solver and make all subsequent unrelated proofs about + // `b` unsound. unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits()); z3::expr a_bv = z3::int2bv(bit_width, a_expr); z3::expr b_bv = z3::int2bv(bit_width, b_expr); @@ -751,11 +777,15 @@ class Z3Prover::Impl : ExprFunctor { } z3::expr VisitExprDefault_(const Object* op) override { - LOG(FATAL) << "Z3Prover only support integers, but got " << op->GetTypeKey() << "."; - TVM_FFI_UNREACHABLE(); + // Z3 is a best-effort fallback that runs only after the native analyzers + // have already failed. An unsupported node must not crash the build, so we + // model it as a fresh unconstrained free variable, which keeps the proof + // sound (it can only make CanProve more conservative). 
+ return Create(static_cast(op)); } }; +TVM_DLL bool Z3Prover::IsEnabled() const { return true; } TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return impl_->CanProve(expr); } TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) { return impl_->Bind(var, new_range, allow_override); @@ -763,8 +793,8 @@ TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_o TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { return impl_->Bind(var, expr, allow_override); } -std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { - return impl_->EnterConstraint(constraint, is_assume); +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); } ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { if (expr.has_value()) { @@ -786,3 +816,48 @@ Z3Prover::Z3Prover(AnalyzerObj* parent) : impl_(new Impl{parent}) {} TVM_DLL Z3Prover::~Z3Prover() { delete impl_; } } // namespace tvm::arith + +#else // TVM_USE_Z3 + +#include +#include +#include + +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" + +namespace tvm::arith { + +using namespace tirx; +using namespace ffi; + +// Stub implementation used when Z3 support is not built. All proving queries +// conservatively report "cannot prove" while keeping the public API available. 
+class Z3Prover::Impl {}; + +TVM_DLL bool Z3Prover::IsEnabled() const { return false; } +TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return false; } +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) {} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint) { + return []() {}; +} +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + return "; Z3 Prover is disabled."; +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) {} +void Z3Prover::SetRLimit(unsigned rlimit) {} +ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return "; Z3 Prover is disabled."; } +TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, + int64_t min_consecutive) { + return -1; // Z3 disabled, return error +} + +void Z3Prover::CopyFrom(const Z3Prover& other) {} +ffi::String Z3Prover::GetStats() { return "; Z3 Prover is disabled."; } +Z3Prover::Z3Prover(AnalyzerObj*) : impl_(nullptr) {} +TVM_DLL Z3Prover::~Z3Prover() {} + +} // namespace tvm::arith + +#endif // TVM_USE_Z3 diff --git a/src/target/z3/z3_prover_off.cc b/src/target/z3/z3_prover_off.cc deleted file mode 100644 index 98278ffbf39d..000000000000 --- a/src/target/z3/z3_prover_off.cc +++ /dev/null @@ -1,38 +0,0 @@ -#include -#include -#include - -#include "tvm/arith/analyzer.h" -#include "tvm/ffi/string.h" -#include "tvm/ir/expr.h" - -namespace tvm::arith { - -using namespace tirx; -using namespace ffi; - -class Z3Prover::Impl {}; - -TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return false; } -TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) {} -TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {} -std::function Z3Prover::EnterConstraint(const PrimExpr& constraint, bool is_assume) { - return []() {}; -} -ffi::String Z3Prover::GetSMTLIB2(const 
ffi::Optional expr) { - return "; Z3 Prover is disabled."; -} -void Z3Prover::SetTimeoutMs(unsigned timeout_ms) {} -void Z3Prover::SetRLimit(unsigned rlimit) {} -ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return "; Z3 Prover is disabled."; } -TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, - int64_t min_consecutive) { - return -1; // Z3 disabled, return error -} - -void Z3Prover::CopyFrom(const Z3Prover& other) {} -ffi::String Z3Prover::GetStats() { return "; Z3 Prover is disabled."; } -Z3Prover::Z3Prover(AnalyzerObj*) : impl_(nullptr) {} -TVM_DLL Z3Prover::~Z3Prover() {} - -} // namespace tvm::arith diff --git a/tests/python/arith/test_arith_z3.py b/tests/python/arith/test_arith_z3.py index c638341f4cd1..f9dff4cac9a5 100644 --- a/tests/python/arith/test_arith_z3.py +++ b/tests/python/arith/test_arith_z3.py @@ -20,34 +20,73 @@ import tvm import tvm.testing from tvm import tirx +from tvm.arith import Analyzer, ProofStrength + +# The Z3 prover is only consulted at the kSymbolicBound strength so the common +# default path never pays the prover cost. +SB = ProofStrength.SYMBOLIC_BOUND def _require_z3(analyzer): - if "Z3 Prover is disabled" in analyzer.get_smtlib2(): + if not analyzer.is_z3_enabled: pytest.skip("Z3 prover is disabled in this build") -def test_z3_disabled_api_is_available(): - analyzer = tvm.arith.Analyzer() - assert isinstance(analyzer.get_smtlib2(), str) - assert isinstance(analyzer.get_z3_stats(), str) +def implies(x, y): + return tirx.Or(tirx.Not(x), y) + + +# --------------------------------------------------------------------------- +# API availability (works regardless of whether Z3 is built) +# --------------------------------------------------------------------------- + + +def test_z3_capability_query(): + # `is_z3_enabled` is the supported way to detect the build configuration. + # The Z3-specific debug/config methods work only when it is True, and raise + # a clear error otherwise. 
+ analyzer = Analyzer() + assert isinstance(analyzer.is_z3_enabled, bool) + + if analyzer.is_z3_enabled: + assert isinstance(analyzer.get_smtlib2(), str) + assert isinstance(analyzer.get_z3_stats(), str) + else: + with pytest.raises(RuntimeError): + analyzer.get_smtlib2() + with pytest.raises(RuntimeError): + analyzer.get_z3_stats() + with pytest.raises(RuntimeError): + analyzer.set_z3_timeout_ms(1000) + with pytest.raises(RuntimeError): + analyzer.set_z3_rlimit(0) + + +# --------------------------------------------------------------------------- +# Examples the native analyzer cannot prove but Z3 can. +# +# Each case asserts both that the native analyzers (kDefault, Z3 gated off) +# fail and that Z3 (kSymbolicBound) succeeds. This demonstrates the added value +# of the Z3 backend and that it is correctly gated behind kSymbolicBound. +# --------------------------------------------------------------------------- -def test_z3_proves_floor_division_identity(): - analyzer = tvm.arith.Analyzer() +def test_z3_floor_division_identity_constraint(): + analyzer = Analyzer() _require_z3(analyzer) a = tirx.Var("a", "int32") b = tirx.Var("b", "int32") c = tirx.Var("c", "int32") + expr = ((b - a) // c) * c + a <= b with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)): - expr = ((b - a) // c) * c + a <= b - assert analyzer.can_prove(expr) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) -def test_z3_bind_range(): - analyzer = tvm.arith.Analyzer() +def test_z3_floor_division_identity_via_bind_range(): + analyzer = Analyzer() _require_z3(analyzer) a = tirx.Var("a", "int32") @@ -59,12 +98,598 @@ def test_z3_bind_range(): analyzer.bind(c, tvm.ir.Range(1, 100000)) expr = ((b - a) // c) * c + a <= b - assert analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_multiplication_monotonicity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", 
"int32") + d = tirx.Var("d", "int32") + + expr = implies(tirx.all(a < b, b < c, a * d < b * d), b * d < c * d) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_nested_floor_division_collapse(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + expr = implies( + tirx.all(a >= 0, a < 128), + a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64, + ) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_deeply_nested_floor_division_identity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + expr = implies( + tirx.all(a >= 0, a < 128), + ( + a % 16 * 64 + + a // 64 * 32 + + a % 8 // 4 * 32 + + (a % 32 // 16 + a % 2) % 2 * 8 + + 16 + - (a // 64 + a % 8 // 4) // 2 * 64 + ) + // 512 + == ( + a % 16 * 64 + + a // 64 * 32 + + a % 8 // 4 * 32 + + (a % 32 // 16 + a % 2) % 2 * 8 + - (a // 64 + a % 8 // 4) // 2 * 64 + ) + // 512, + ) + assert analyzer.can_prove(expr, SB) + + +def test_z3_min_max_sum_identity(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + expr = tirx.max(x, y) + tirx.min(x, y) == x + y + assert analyzer.can_prove(expr, SB) + + +def test_z3_select_absolute_value_nonneg(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + expr = tirx.Select(x >= 0, x, -x) >= 0 + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_transitive_inequality(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + expr = implies(tirx.all(a <= b, b <= c), a <= c) + assert analyzer.can_prove(expr, SB) + + +def test_z3_square_expansion_nonneg(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = (a + b) * (a + b) >= a * a + b * b + with 
analyzer.constraint_scope(tirx.all(a >= 0, b >= 0)): + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_square_monotonicity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = implies(tirx.all(0 <= a, a <= b), a * a <= b * b) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_strict_multiplication(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + d = tirx.Var("d", "int32") + expr = implies(tirx.all(a < b, d > 0), a * d < b * d) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_floor_division_monotonicity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + expr = implies(tirx.all(a <= b, c > 0), tirx.floordiv(a, c) <= tirx.floordiv(b, c)) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_floor_division_lower_bound(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = implies(b > 0, tirx.floordiv(a, b) * b <= a) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_floor_modulo_range(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = implies(b > 0, tirx.all(0 <= tirx.floormod(a, b), tirx.floormod(a, b) < b)) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_flattened_index_bound(): + # Classic index-flattening bound used throughout TVM: for a row index i in + # [0, m) and a column index j in [0, n), the flattened index i * n + j stays + # within [0, m * n). 
+ analyzer = Analyzer() + _require_z3(analyzer) + + i = tirx.Var("i", "int32") + j = tirx.Var("j", "int32") + m = tirx.Var("m", "int32") + n = tirx.Var("n", "int32") + expr = tirx.all(0 <= i * n + j, i * n + j < m * n) + with analyzer.constraint_scope(tirx.all(0 <= i, i < m, 0 <= j, j < n, m > 0, n > 0)): + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_modular_combination(): + # Native modular_set tracks single-variable moduli, but combining two + # independent modular facts to reason about their sum is left to Z3. + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + expr = tirx.floormod(x + y, 2) == 0 + with analyzer.constraint_scope(tirx.all(tirx.floormod(x, 6) == 0, tirx.floormod(y, 6) == 0)): + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_square_non_negative(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + assert not analyzer.can_prove(a * a >= 0) + assert analyzer.can_prove(a * a >= 0, SB) + + +def test_z3_min_max_average_bounds(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + assert not analyzer.can_prove(tirx.max(a, b) * 2 >= a + b) + assert analyzer.can_prove(tirx.max(a, b) * 2 >= a + b, SB) + assert analyzer.can_prove(tirx.min(a, b) * 2 <= a + b, SB) + + +def test_z3_symbolic_bind_range_with_constraint(): + # Combine a symbolic range binding (x in [0, n)) with a constraint on the + # extent to derive a concrete bound on x. 
+ analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + n = tirx.Var("n", "int32") + analyzer.bind(x, tvm.ir.Range(0, n)) + with analyzer.constraint_scope(n <= 8): + assert not analyzer.can_prove(x < 8) + assert analyzer.can_prove(x < 8, SB) + + +def test_z3_equality_congruence(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = implies(a == b, a * a == b * b) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_integer_strict_transitivity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + # Over the integers, a < b and b < c implies a + 1 < c. + expr = implies(tirx.all(a < b, b < c), a + 1 < c) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_if_then_else_absolute_value(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + expr = tirx.if_then_else(x >= 0, x, -x) >= 0 + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_unsigned_non_negative(): + analyzer = Analyzer() + _require_z3(analyzer) + + u = tirx.Var("u", "uint32") + assert not analyzer.can_prove(u >= 0) + assert analyzer.can_prove(u >= 0, SB) + + +def test_z3_unsigned64_non_negative(): + # Exercises the special-cased uint64 range handling (UINT64_MAX bound). 
+ analyzer = Analyzer() + _require_z3(analyzer) + + u = tirx.Var("u", "uint64") + assert not analyzer.can_prove(u >= 0) + assert analyzer.can_prove(u >= 0, SB) + + +def test_z3_int64_square_expansion(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + expr = (a + b) * (a + b) >= a * a + b * b + with analyzer.constraint_scope(tirx.all(a >= 0, b >= 0)): + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_boolean_variable_reasoning(): + analyzer = Analyzer() + _require_z3(analyzer) + + p = tirx.Var("p", "bool") + q = tirx.Var("q", "bool") + expr = implies(tirx.And(p, q), tirx.Or(p, q)) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_not_equal_from_strict_less(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + expr = implies(x < y, tirx.NE(x, y)) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_let_expression(): + analyzer = Analyzer() + _require_z3(analyzer) + + y = tirx.Var("y", "int32") + t = tirx.Var("t", "int32") + let = tirx.Let(t, y * 2, t) + assert not analyzer.can_prove(let == y * 2) + assert analyzer.can_prove(let == y * 2, SB) + + +def test_z3_cast_preserves_bounds(): + analyzer = Analyzer() + _require_z3(analyzer) + + s = tirx.Var("s", "int16") + widened = tirx.Cast("int32", s) + assert analyzer.can_prove(widened <= 32767, SB) + assert analyzer.can_prove(widened >= -32768, SB) + + +def test_z3_bitwise_and_mask_bound(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + assert analyzer.can_prove(tirx.bitwise_and(x, tirx.IntImm("int32", 7)) < 8, SB) + + +def test_z3_bitwise_and_le_operand(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + 
analyzer.bind(y, tvm.ir.Range(0, 256)) + # Bit-vector reasoning over two variables exceeds the default deterministic + # rlimit; lift it (0 == unlimited, still deterministic) for this proof. + analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(tirx.bitwise_and(x, y) <= x, SB) + + +def test_z3_bitwise_or_ge_operand(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + analyzer.bind(y, tvm.ir.Range(0, 256)) + analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(tirx.bitwise_or(x, y) >= x, SB) + + +def test_z3_bitwise_xor_bound(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + analyzer.bind(y, tvm.ir.Range(0, 256)) + analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(tirx.bitwise_xor(x, y) < 256, SB) + + +def test_z3_bitwise_not_identity(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + analyzer.set_z3_rlimit(0) + # Two's complement: ~x == -x - 1. + assert analyzer.can_prove(tirx.bitwise_not(x) == -x - 1, SB) + + +def test_z3_shift_right_halves(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + analyzer.set_z3_rlimit(0) + # For non-negative x, (x >> 1) * 2 <= x. + assert analyzer.can_prove(tirx.shift_right(x, tirx.IntImm("int32", 1)) * 2 <= x, SB) + + +def test_z3_shift_left_lower_bound(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + n = tirx.Var("n", "int32") + # Keep operands small so the 32-bit left shift cannot overflow; then + # x << n == x * 2 ** n >= x for x >= 1. + analyzer.bind(x, tvm.ir.Range(1, 16)) + analyzer.bind(n, tvm.ir.Range(0, 4)) + # Bit-vector shift reasoning exceeds the default deterministic rlimit. 
+ analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(tirx.shift_left(x, n) >= x, SB) + + +# --------------------------------------------------------------------------- +# Soundness / negative tests (Z3 must NOT prove false predicates) +# --------------------------------------------------------------------------- + + +def test_z3_negative_unprovable_inequality(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + # a < b does not hold for arbitrary a, b. + assert not analyzer.can_prove(a < b, SB) + # a * a > a is false (e.g. a == 0). + assert not analyzer.can_prove(a * a > a, SB) + + +def test_z3_truncmod_can_be_negative(): + # Regression test for truncated div/mod semantics: TVM Div/Mod round toward + # zero, so truncmod(a, 4) can be negative. A solver that modeled them as + # Euclidean would unsoundly "prove" truncmod(a, 4) >= 0. + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + assert not analyzer.can_prove(tirx.truncmod(a, 4) >= 0, SB) + + +def test_z3_truncdiv_truncmod_identity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = tirx.truncdiv(a, b) * b + tirx.truncmod(a, b) == a + with analyzer.constraint_scope(b != 0): + assert analyzer.can_prove(expr, SB) + + +def test_z3_floormod_nested_identities(): + # Ported from TileLang's test_divmod. Here `%` is floormod: nested floormod + # by opposite-sign divisors collapses to the single-divisor result, while + # the mixed case does not. + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + assert not analyzer.can_prove(a % 2 % -2 - a % 2 == 0, SB) + assert analyzer.can_prove(a % -2 % 2 - a % 2 == 0, SB) + + +def test_z3_floormod_nonnegative(): + # In contrast to truncmod, floormod with a positive divisor is always in + # [0, divisor), which Z3 should be able to prove. 
+ analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + assert analyzer.can_prove(tirx.floormod(a, 4) >= 0, SB) + assert analyzer.can_prove(tirx.floormod(a, 4) < 4, SB) + + +def test_z3_shift_does_not_poison_solver(): + # Regression test: evaluating a shift expression must not add permanent + # assertions (such as `b >= 0` / `b < 64`) to the shared solver. Otherwise + # an unrelated, unbounded `b` would be wrongly provable to be < 100. + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + + # Touch a shift expression so the prover visits the shift amount `b`. + analyzer.can_prove(tirx.shift_left(a, b) >= 0, SB) + + # `b` is otherwise unconstrained, so this must remain unprovable. + assert not analyzer.can_prove(b < 100, SB) + assert not analyzer.can_prove(b >= 0, SB) + + +def test_z3_constraint_scope_is_popped(): + # Constraints entered through a scope must be removed once the scope exits, + # i.e. EnterConstraint's solver.push()/pop() must be balanced. + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + with analyzer.constraint_scope(x > 5): + assert analyzer.can_prove(x > 0, SB) + # The constraint is gone; x is unconstrained again. + assert not analyzer.can_prove(x > 0, SB) + + +def test_z3_opaque_call_is_safe(): + # An opaque/unsupported sub-expression is modeled as a fresh free variable. + # It must neither crash nor be provable on its own, yet still be usable as a + # constraint. 
+ analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + call = tirx.call_extern("int32", "foo", x) + assert not analyzer.can_prove(call > 0, SB) + with analyzer.constraint_scope(call > 0): + assert analyzer.can_prove(call > 0, SB) + assert not analyzer.can_prove(call > 0, SB) + + +def test_z3_shift_overflow_is_not_proven(): + # Z3 models fixed-width shifts via bit-vectors, so it correctly refuses to + # prove `x << n >= x` for an unbounded `x` (a large `x` overflows int32 and + # wraps to a negative value). This guards against unsound shift modeling. + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + n = tirx.Var("n", "int32") + analyzer.set_z3_rlimit(0) + expr = implies(tirx.all(x >= 1, n >= 0, n < 8), tirx.shift_left(x, n) >= x) + assert not analyzer.can_prove(expr, SB) + + +def test_z3_analyzers_are_isolated(): + # Analyzers share a thread-local Z3 context but own separate solvers, so + # constraints and bindings in one must never leak into another. + analyzer_a = Analyzer() + analyzer_b = Analyzer() + _require_z3(analyzer_a) + + x = tirx.Var("x", "int32") + with analyzer_a.constraint_scope(x > 100): + assert analyzer_a.can_prove(x > 50, SB) + assert not analyzer_b.can_prove(x > 50, SB) + + analyzer_c = Analyzer() + analyzer_d = Analyzer() + analyzer_c.bind(x, tvm.ir.Range(0, 10)) + assert analyzer_c.can_prove(x < 10, SB) + assert not analyzer_d.can_prove(x < 10, SB) + + +def test_z3_repeated_can_prove_is_consistent(): + # Repeated queries must be stateless: a CanProve call must not pollute the + # solver and change the result of a subsequent call. 
+ analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + assert analyzer.can_prove(x > 0, SB) == analyzer.can_prove(x > 0, SB) + + analyzer.bind(x, tvm.ir.Range(5, 10)) + assert analyzer.can_prove(x >= 5, SB) + assert analyzer.can_prove(x >= 5, SB) + + +def test_z3_is_gated_behind_symbolic_bound(): + # The Z3 fallback must not run at the default strength. + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + expr = ((b - a) // c) * c + a <= b + with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)): + assert not analyzer.can_prove(expr, ProofStrength.DEFAULT) + assert analyzer.can_prove(expr, SB) + + +# --------------------------------------------------------------------------- +# SMT-LIB2 export +# --------------------------------------------------------------------------- def test_z3_smtlib2_roundtrip(): z3 = pytest.importorskip("z3") - analyzer = tvm.arith.Analyzer() + analyzer = Analyzer() _require_z3(analyzer) a = tirx.Var("a", "int32") @@ -78,14 +703,20 @@ def test_z3_smtlib2_roundtrip(): assert solver.check() == z3.unsat -def test_z3_bitwise(): - analyzer = tvm.arith.Analyzer() +def test_z3_smtlib2_roundtrip_with_timeout(): + z3 = pytest.importorskip("z3") + analyzer = Analyzer() _require_z3(analyzer) - x = tirx.Var("x", "int32") - analyzer.bind(x, tvm.ir.Range(0, 256)) + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + analyzer.set_z3_timeout_ms(1000) - assert analyzer.can_prove(tirx.bitwise_and(x, tirx.IntImm("int32", 7)) < 8) + expr = implies(tirx.all(a > 0, b > 0, c > 0), ((b - a) // c) * c + a <= b) + solver = z3.Solver() + solver.from_string(analyzer.get_smtlib2(expr)) + assert solver.check() == z3.unsat if __name__ == "__main__": From d7f0470ce119f555591d991cc2d69ebacaa65586 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Wed, 10 Jun 2026 16:53:43 -0400 Subject: [PATCH 07/21] Drop accidentally 
committed build artifacts - Remove python/tvm/_version.py: it is auto-generated by the build's VCS versioning ("don't track in version control") and is not imported at runtime; the package __version__ comes from python/tvm/libinfo.py. Also add it to .gitignore so it is not re-committed. - Remove four stray 3rdparty gitlinks (cnpy, dmlc-core, rang, zlib) that are not real submodules (absent from .gitmodules). --- .gitignore | 3 +++ 3rdparty/cnpy | 1 - 3rdparty/dmlc-core | 1 - 3rdparty/rang | 1 - 3rdparty/zlib | 1 - python/tvm/_version.py | 24 ------------------------ 6 files changed, 3 insertions(+), 28 deletions(-) delete mode 160000 3rdparty/cnpy delete mode 160000 3rdparty/dmlc-core delete mode 160000 3rdparty/rang delete mode 160000 3rdparty/zlib delete mode 100644 python/tvm/_version.py diff --git a/.gitignore b/.gitignore index 180eea6c4ead..57f66ce6d6f6 100644 --- a/.gitignore +++ b/.gitignore @@ -299,3 +299,6 @@ pytest-of-bohanhou/ # tir-bench run artifacts (regenerable; see .claude/commands/tir-bench.md) .tir-bench/ .tir-bench-*/ + +# Auto-generated by the build's VCS versioning; must not be tracked. 
+/python/tvm/_version.py diff --git a/3rdparty/cnpy b/3rdparty/cnpy deleted file mode 160000 index 4e8810b1a863..000000000000 --- a/3rdparty/cnpy +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 4e8810b1a8637695171ed346ce68f6984e585ef4 diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core deleted file mode 160000 index 3031e4a61a98..000000000000 --- a/3rdparty/dmlc-core +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 3031e4a61a98f49f07a42cfdec6242340fb2fd8c diff --git a/3rdparty/rang b/3rdparty/rang deleted file mode 160000 index cabe04d6d6b0..000000000000 --- a/3rdparty/rang +++ /dev/null @@ -1 +0,0 @@ -Subproject commit cabe04d6d6b05356fa8f9741704924788f0dd762 diff --git a/3rdparty/zlib b/3rdparty/zlib deleted file mode 160000 index ef24c4c75021..000000000000 --- a/3rdparty/zlib +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ef24c4c7502169f016dcd2a26923dbaf3216748c diff --git a/python/tvm/_version.py b/python/tvm/_version.py deleted file mode 100644 index 618c82e5a4ed..000000000000 --- a/python/tvm/_version.py +++ /dev/null @@ -1,24 +0,0 @@ -# file generated by vcs-versioning -# don't change, don't track in version control -from __future__ import annotations - -__all__ = [ - "__version__", - "__version_tuple__", - "version", - "version_tuple", - "__commit_id__", - "commit_id", -] - -version: str -__version__: str -__version_tuple__: tuple[int | str, ...] -version_tuple: tuple[int | str, ...] -commit_id: str | None -__commit_id__: str | None - -__version__ = version = '0.25.dev100' -__version_tuple__ = version_tuple = (0, 25, 'dev100') - -__commit_id__ = commit_id = 'g35152f312' From 8eb9412b120c929cfad5aadfb50512c7f7d7a582 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Wed, 10 Jun 2026 16:55:07 -0400 Subject: [PATCH 08/21] [CMAKE] Link Z3 statically from the z3-staticlib package by default USE_Z3=ON now resolves Z3 through the PyPI z3-staticlib package (PIC static libz3 + headers + CMake package files) so libtvm carries no runtime libz3 dependency. 
The z3-solver shared-library fallback is removed; a custom Z3 installation can still be selected via Z3_DIR/CMAKE_PREFIX_PATH. Wheel builds enable USE_Z3 by default and pull z3-staticlib as a build requirement. --- cmake/modules/contrib/Z3.cmake | 40 ++++++++++++++++++---------------- pyproject.toml | 5 ++++- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/cmake/modules/contrib/Z3.cmake b/cmake/modules/contrib/Z3.cmake index 7d0b746aa36f..441f7bf6c85d 100644 --- a/cmake/modules/contrib/Z3.cmake +++ b/cmake/modules/contrib/Z3.cmake @@ -22,32 +22,31 @@ if(NOT USE_Z3) return() endif() -find_package(Z3 QUIET) -set(Z3_PYTHON_RESULT 1) - -if(NOT Z3_FOUND) +# Default lookup: the PIC static Z3 library shipped by the PyPI `z3-static` +# package (headers + libz3.a + Z3 CMake package files). Linking it statically +# keeps libtvm free of a runtime libz3 dependency. Users can override the +# lookup by setting Z3_DIR/CMAKE_PREFIX_PATH to any Z3 installation (e.g. a +# shared system Z3). 
+if(NOT Z3_DIR) find_package(Python3 COMPONENTS Interpreter QUIET) if(Python3_EXECUTABLE) execute_process( - COMMAND "${Python3_EXECUTABLE}" -c "import z3; print(z3.__path__[0])" - OUTPUT_VARIABLE Z3_PYTHON_PACKAGE_DIR + COMMAND + "${Python3_EXECUTABLE}" -c + "import os, z3_static as m; f = getattr(m, 'get_cmake_dir', None); print(f() if f else os.path.join(os.path.dirname(m.__file__), 'static', 'lib', 'cmake', 'z3'))" + OUTPUT_VARIABLE Z3_STATIC_CMAKE_DIR OUTPUT_STRIP_TRAILING_WHITESPACE - RESULT_VARIABLE Z3_PYTHON_RESULT - ) - endif() - - if(Z3_PYTHON_RESULT EQUAL 0 AND NOT Z3_PYTHON_PACKAGE_DIR STREQUAL "") - find_path(Z3_INCLUDE_DIR NO_DEFAULT_PATH NAMES z3++.h PATHS "${Z3_PYTHON_PACKAGE_DIR}/include") - find_library( - Z3_LIBRARY - NO_DEFAULT_PATH - NAMES z3 libz3 - PATHS "${Z3_PYTHON_PACKAGE_DIR}/bin" "${Z3_PYTHON_PACKAGE_DIR}/lib" - "${Z3_PYTHON_PACKAGE_DIR}/lib64" + ERROR_QUIET + RESULT_VARIABLE Z3_STATIC_RESULT ) + if(Z3_STATIC_RESULT EQUAL 0 AND EXISTS "${Z3_STATIC_CMAKE_DIR}") + set(Z3_DIR "${Z3_STATIC_CMAKE_DIR}") + endif() endif() endif() +find_package(Z3 QUIET) + if(TARGET z3::libz3 OR TARGET Z3::libz3) if(TARGET z3::libz3) set(Z3_TARGET z3::libz3) @@ -72,7 +71,10 @@ elseif(Z3_FOUND OR (Z3_INCLUDE_DIR AND Z3_LIBRARY)) include_directories(SYSTEM ${Z3_INCLUDE_DIR}) list(APPEND TVM_LINKER_LIBS ${Z3_LIBRARY}) else() - message(FATAL_ERROR "USE_Z3 is ON, but Z3 was not found. Install Z3 or PyPI z3-solver.") + message(FATAL_ERROR + "USE_Z3 is ON, but Z3 was not found. Install the static Z3 development " + "package with `pip install z3-static`, or point Z3_DIR/CMAKE_PREFIX_PATH " + "at a Z3 installation.") endif() # Enable the real Z3 implementation inside the single src/arith/z3_prover.cc file. diff --git a/pyproject.toml b/pyproject.toml index 3c61fa389fc3..e0f600fdad95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,8 @@ # under the License. 
[build-system] -requires = ["scikit-build-core>=0.11", "setuptools-scm>=8"] +# z3-static ships the PIC static libz3 + headers consumed by USE_Z3=ON. +requires = ["scikit-build-core>=0.11", "setuptools-scm>=8", "z3-static"] build-backend = "scikit_build_core.build" [project] @@ -141,6 +142,8 @@ logging.level = "INFO" [tool.scikit-build.cmake.define] TVM_BUILD_PYTHON_MODULE = "ON" USE_CUDA = "OFF" +# Statically link Z3 from the z3-static build dependency by default. +USE_Z3 = "ON" BUILD_TESTING = "OFF" [tool.setuptools_scm] From 59054a652dae6cea0049c310aef9144f96b66f62 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Wed, 10 Jun 2026 16:55:26 -0400 Subject: [PATCH 09/21] Stop ignoring python/tvm/_version.py --- .gitignore | 3 --- 1 file changed, 3 deletions(-) diff --git a/.gitignore b/.gitignore index 57f66ce6d6f6..180eea6c4ead 100644 --- a/.gitignore +++ b/.gitignore @@ -299,6 +299,3 @@ pytest-of-bohanhou/ # tir-bench run artifacts (regenerable; see .claude/commands/tir-bench.md) .tir-bench/ .tir-bench-*/ - -# Auto-generated by the build's VCS versioning; must not be tracked. -/python/tvm/_version.py From 1a227a488d398ffdf837300a64fef67c351a37d7 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 06:57:54 -0400 Subject: [PATCH 10/21] [CMAKE] Enable static z3 package by default Use the z3-static CMake package path first and make the rebased Z3 analyzer build cleanly with LLVM and CUDA enabled. 
--- CMakeLists.txt | 5 ++++- cmake/modules/LLVM.cmake | 8 ++++++++ cmake/modules/contrib/Z3.cmake | 5 ++++- include/tvm/arith/analyzer.h | 1 + 4 files changed, 17 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index b57ef919feb3..84ffce1f027e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,7 +90,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF) tvm_option(USE_BLAS "The blas library to be linked" none) tvm_option(USE_AMX "Enable Intel AMX" OFF) -tvm_option(USE_Z3 "Build with Z3 SMT solver support" OFF) +tvm_option(USE_Z3 "Build with Z3 SMT solver support" ON) tvm_option(USE_MKL "MKL root path when use MKL blas" OFF) tvm_option(USE_DNNL "Enable DNNL codegen" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF) @@ -548,6 +548,9 @@ add_library(tvm_objs OBJECT ${COMPILER_SRCS}) add_library(tvm_runtime_objs OBJECT ${RUNTIME_SRCS}) target_link_libraries(tvm_objs PUBLIC tvm_ffi_header) target_link_libraries(tvm_runtime_objs PUBLIC tvm_ffi_header) +if(TARGET tvm_llvm_header) + target_link_libraries(tvm_objs PUBLIC tvm_llvm_header) +endif() include(GNUInstallDirs) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index f944b4130415..eecbc01c2999 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -34,6 +34,14 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) endif() include_directories(SYSTEM ${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) + add_library(tvm_llvm_header INTERFACE) + target_include_directories(tvm_llvm_header SYSTEM INTERFACE ${LLVM_INCLUDE_DIRS}) + set(TVM_LLVM_INCLUDE_FLAGS "") + foreach(__llvm_include_dir IN LISTS LLVM_INCLUDE_DIRS) + string(STRIP "${__llvm_include_dir}" __llvm_include_dir) + list(APPEND TVM_LLVM_INCLUDE_FLAGS "-isystem" "${__llvm_include_dir}") + endforeach() + target_compile_options(tvm_llvm_header INTERFACE ${TVM_LLVM_INCLUDE_FLAGS} 
${LLVM_DEFINITIONS}) message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION}) message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION}) # Set flags that are only needed for LLVM target diff --git a/cmake/modules/contrib/Z3.cmake b/cmake/modules/contrib/Z3.cmake index 441f7bf6c85d..4addfd4ccb78 100644 --- a/cmake/modules/contrib/Z3.cmake +++ b/cmake/modules/contrib/Z3.cmake @@ -45,7 +45,10 @@ if(NOT Z3_DIR) endif() endif() -find_package(Z3 QUIET) +find_package(Z3 CONFIG QUIET) +if(NOT Z3_FOUND AND NOT TARGET z3::libz3 AND NOT TARGET Z3::libz3) + find_package(Z3 QUIET) +endif() if(TARGET z3::libz3 OR TARGET Z3::libz3) if(TARGET z3::libz3) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index cbe9051e3b73..e635315e6714 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -684,6 +684,7 @@ class Z3Prover { int64_t min_consecutive = 1); private: + friend class AnalyzerObj; friend class Analyzer; explicit Z3Prover(AnalyzerObj* parent); TVM_DLL ~Z3Prover(); From 604801ceb33584225c4806d8e31fd8e77699dff3 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 07:04:51 -0400 Subject: [PATCH 11/21] [CMAKE] Clean up LLVM include propagation Keep the LLVM header interface target on explicit compile options so object targets inherit llvm-config include paths reliably. 
--- cmake/modules/LLVM.cmake | 1 - 1 file changed, 1 deletion(-) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index eecbc01c2999..627b71124617 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -35,7 +35,6 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) include_directories(SYSTEM ${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) add_library(tvm_llvm_header INTERFACE) - target_include_directories(tvm_llvm_header SYSTEM INTERFACE ${LLVM_INCLUDE_DIRS}) set(TVM_LLVM_INCLUDE_FLAGS "") foreach(__llvm_include_dir IN LISTS LLVM_INCLUDE_DIRS) string(STRIP "${__llvm_include_dir}" __llvm_include_dir) From ff71350e85c187879a90c1371703d97a7e488f9c Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 07:12:14 -0400 Subject: [PATCH 12/21] [CMAKE] Auto-enable Z3 when available Default USE_Z3 to AUTO so source builds use z3-static when installed while CI environments without the package still build the conservative stub. --- CMakeLists.txt | 2 +- cmake/modules/contrib/Z3.cmake | 19 ++++++++++++++----- 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 84ffce1f027e..65e645c596d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,7 +90,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF) tvm_option(USE_BLAS "The blas library to be linked" none) tvm_option(USE_AMX "Enable Intel AMX" OFF) -tvm_option(USE_Z3 "Build with Z3 SMT solver support" ON) +tvm_option(USE_Z3 "Build with Z3 SMT solver support" AUTO) tvm_option(USE_MKL "MKL root path when use MKL blas" OFF) tvm_option(USE_DNNL "Enable DNNL codegen" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF) diff --git a/cmake/modules/contrib/Z3.cmake b/cmake/modules/contrib/Z3.cmake index 4addfd4ccb78..4af6c6bf4571 100644 --- a/cmake/modules/contrib/Z3.cmake +++ b/cmake/modules/contrib/Z3.cmake @@ -18,10 +18,15 
@@ # src/arith/z3_prover.cc is always part of COMPILER_SRCS (picked up by the # src/arith/*.cc glob). It compiles a conservative stub by default and switches # to the real Z3 implementation only when the TVM_USE_Z3 macro is defined below. -if(NOT USE_Z3) +if(${USE_Z3} MATCHES ${IS_FALSE_PATTERN}) return() endif() +set(TVM_Z3_REQUIRED TRUE) +if("${USE_Z3}" MATCHES "^[Aa][Uu][Tt][Oo]$") + set(TVM_Z3_REQUIRED FALSE) +endif() + # Default lookup: the PIC static Z3 library shipped by the PyPI `z3-static` # package (headers + libz3.a + Z3 CMake package files). Linking it statically # keeps libtvm free of a runtime libz3 dependency. Users can override the @@ -74,10 +79,14 @@ elseif(Z3_FOUND OR (Z3_INCLUDE_DIR AND Z3_LIBRARY)) include_directories(SYSTEM ${Z3_INCLUDE_DIR}) list(APPEND TVM_LINKER_LIBS ${Z3_LIBRARY}) else() - message(FATAL_ERROR - "USE_Z3 is ON, but Z3 was not found. Install the static Z3 development " - "package with `pip install z3-static`, or point Z3_DIR/CMAKE_PREFIX_PATH " - "at a Z3 installation.") + if(TVM_Z3_REQUIRED) + message(FATAL_ERROR + "USE_Z3 is ON, but Z3 was not found. Install the static Z3 development " + "package with `pip install z3-static`, or point Z3_DIR/CMAKE_PREFIX_PATH " + "at a Z3 installation.") + endif() + message(STATUS "Build without Z3 SMT solver support") + return() endif() # Enable the real Z3 implementation inside the single src/arith/z3_prover.cc file. From 4d9d57d749575b7cba4d6f6adc18f1d4cf32aec2 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 08:15:19 -0400 Subject: [PATCH 13/21] [CMAKE] Use MSVC include dirs for LLVM headers Avoid passing GCC-style -isystem options to cl while keeping explicit LLVM include propagation for non-MSVC builds. 
--- cmake/modules/LLVM.cmake | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 627b71124617..0022dc59af3b 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -35,12 +35,17 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) include_directories(SYSTEM ${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) add_library(tvm_llvm_header INTERFACE) - set(TVM_LLVM_INCLUDE_FLAGS "") - foreach(__llvm_include_dir IN LISTS LLVM_INCLUDE_DIRS) - string(STRIP "${__llvm_include_dir}" __llvm_include_dir) - list(APPEND TVM_LLVM_INCLUDE_FLAGS "-isystem" "${__llvm_include_dir}") - endforeach() - target_compile_options(tvm_llvm_header INTERFACE ${TVM_LLVM_INCLUDE_FLAGS} ${LLVM_DEFINITIONS}) + if(MSVC) + target_include_directories(tvm_llvm_header SYSTEM INTERFACE ${LLVM_INCLUDE_DIRS}) + target_compile_options(tvm_llvm_header INTERFACE ${LLVM_DEFINITIONS}) + else() + set(TVM_LLVM_INCLUDE_FLAGS "") + foreach(__llvm_include_dir IN LISTS LLVM_INCLUDE_DIRS) + string(STRIP "${__llvm_include_dir}" __llvm_include_dir) + list(APPEND TVM_LLVM_INCLUDE_FLAGS "-isystem" "${__llvm_include_dir}") + endforeach() + target_compile_options(tvm_llvm_header INTERFACE ${TVM_LLVM_INCLUDE_FLAGS} ${LLVM_DEFINITIONS}) + endif() message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION}) message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION}) # Set flags that are only needed for LLVM target From 8d04230a2641670e41b88aede86d145ebaaa696e Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 09:55:53 -0400 Subject: [PATCH 14/21] [CMAKE] Document MSVC LLVM include handling Clarify why the LLVM header target uses include directories instead of GCC-style -isystem options on MSVC. 
--- cmake/modules/LLVM.cmake | 1 + 1 file changed, 1 insertion(+) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 0022dc59af3b..8df8bacbacee 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -36,6 +36,7 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) add_definitions(${LLVM_DEFINITIONS}) add_library(tvm_llvm_header INTERFACE) if(MSVC) + # MSVC treats the operand after -isystem as a source file. target_include_directories(tvm_llvm_header SYSTEM INTERFACE ${LLVM_INCLUDE_DIRS}) target_compile_options(tvm_llvm_header INTERFACE ${LLVM_DEFINITIONS}) else() From a220ca8fd7089448ba0043c177376132eb8700a7 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 11:18:06 -0400 Subject: [PATCH 15/21] [CMAKE] Clarify MSVC LLVM include handling --- cmake/modules/LLVM.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 8df8bacbacee..098d57033b4e 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -36,7 +36,7 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) add_definitions(${LLVM_DEFINITIONS}) add_library(tvm_llvm_header INTERFACE) if(MSVC) - # MSVC treats the operand after -isystem as a source file. + # MSVC treats GCC-style -isystem operands as source files. 
target_include_directories(tvm_llvm_header SYSTEM INTERFACE ${LLVM_INCLUDE_DIRS}) target_compile_options(tvm_llvm_header INTERFACE ${LLVM_DEFINITIONS}) else() From 75cd4da277655f40982187291d7236a605419c9e Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 12:40:22 -0400 Subject: [PATCH 16/21] [CI] Retry GPU infrastructure check From 1c2952d2967a0707ec2315b897bb56cbea0332bf Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 15:11:20 -0400 Subject: [PATCH 17/21] [TEST] Match reflected structural equal access paths --- tests/python/relax/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 8eb3961ef8dd..5d30f89c07a6 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -143,7 +143,7 @@ def func_2(A: R.Tensor([16, 16], "float32")): with pytest.raises( ValueError, - match=re.escape(".body.blocks[0].bindings[0].value.op"), + match=r'key="body".*key="blocks".*key="bindings".*key="value".*key="op"', ): assert_structural_equal(func_1, func_2) From bfe04dbd86575540c4a77ad4f5f71dfe89018987 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 18:52:35 -0400 Subject: [PATCH 18/21] [TEST] Lift rlimit for Z3 monotonicity proof --- tests/python/arith/test_arith_z3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/arith/test_arith_z3.py b/tests/python/arith/test_arith_z3.py index f9dff4cac9a5..a7e947aa1865 100644 --- a/tests/python/arith/test_arith_z3.py +++ b/tests/python/arith/test_arith_z3.py @@ -231,6 +231,7 @@ def test_z3_floor_division_monotonicity(): c = tirx.Var("c", "int32") expr = implies(tirx.all(a <= b, c > 0), tirx.floordiv(a, c) <= tirx.floordiv(b, c)) assert not analyzer.can_prove(expr) + analyzer.set_z3_rlimit(0) assert analyzer.can_prove(expr, SB) From 55d503c2346f3f7aaee46c80a30f886cab622fda Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 19:13:37 -0400 Subject: 
[PATCH 19/21] [TEST] Lift rlimit for Z3 index bound proof --- tests/python/arith/test_arith_z3.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/arith/test_arith_z3.py b/tests/python/arith/test_arith_z3.py index a7e947aa1865..97b7a0e065b6 100644 --- a/tests/python/arith/test_arith_z3.py +++ b/tests/python/arith/test_arith_z3.py @@ -271,6 +271,7 @@ def test_z3_flattened_index_bound(): expr = tirx.all(0 <= i * n + j, i * n + j < m * n) with analyzer.constraint_scope(tirx.all(0 <= i, i < m, 0 <= j, j < n, m > 0, n > 0)): assert not analyzer.can_prove(expr) + analyzer.set_z3_rlimit(0) assert analyzer.can_prove(expr, SB) From ea33484dd2979fd39c353e7aa4b622827942325d Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 21:01:09 -0400 Subject: [PATCH 20/21] [ARITH] Pin Z3 context lifetime in prover --- src/arith/z3_prover.cc | 19 +++++++++--------- tests/python/arith/test_arith_z3.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 9 deletions(-) diff --git a/src/arith/z3_prover.cc b/src/arith/z3_prover.cc index 1a368c83904b..aab4b485dd70 100644 --- a/src/arith/z3_prover.cc +++ b/src/arith/z3_prover.cc @@ -103,11 +103,14 @@ class Z3Prover::Impl : ExprFunctor { using Self = Z3Prover::Impl; AnalyzerObj* analyzer; - /// @brief Z3 context, a shared ptr, because tilelang want to copy the Analyzer - // We use a thread_local static Z3 context so all analyzers within the same thread - // can share a common context, because Z3 initialization is slow on some CPUs - // (e.g., AMD EPYC 7502 32-Core). Using thread_local ensures thread safety. - inline static thread_local std::shared_ptr ctx{new z3::context()}; + // Keep a reference to the thread-local context for the whole lifetime of this + // prover. Schedules created on worker threads may be destroyed after the + // worker exits, so storing only a raw reference in z3::solver is not enough. 
+ static std::shared_ptr GetThreadLocalContext() { + static thread_local std::shared_ptr local_ctx = std::make_shared(); + return local_ctx; + } + std::shared_ptr ctx{GetThreadLocalContext()}; /// @brief Z3 solver instance z3::solver solver{*ctx}; @@ -306,10 +309,8 @@ class Z3Prover::Impl : ExprFunctor { // because this->solver depends on this->ctx // we need to deconstruct the old solver, and create a new one depending on this->ctx solver = CreateSolver(*ctx); - // 2. ctx is a static thread_local pointer, so other_.ctx already refers to the same - // context on the current thread; there is nothing to copy here. Cross-thread copying - // of Z3Prover is not supported because Z3 expressions cannot be shared across different - // thread-local contexts without explicit translation. + // 2. ctx is owned by this Impl and pins the underlying thread-local context for the lifetime + // of solver and memoized expressions. // 3. copy other objects ns = other_.ns; for (auto& item : other_.memo_) { diff --git a/tests/python/arith/test_arith_z3.py b/tests/python/arith/test_arith_z3.py index 97b7a0e065b6..a64afd76c5b1 100644 --- a/tests/python/arith/test_arith_z3.py +++ b/tests/python/arith/test_arith_z3.py @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. 
+import gc +import queue +import threading + import pytest import tvm @@ -62,6 +66,33 @@ def test_z3_capability_query(): analyzer.set_z3_rlimit(0) +def test_z3_context_lifetime_outlives_worker_thread(): + _require_z3(Analyzer()) + + result_queue = queue.Queue() + + def worker(): + try: + analyzer = Analyzer() + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 16)) + assert analyzer.can_prove(x >= 0, SB) + result_queue.put(("analyzer", analyzer)) + except BaseException as err: # pylint: disable=broad-exception-caught + result_queue.put(("error", err)) + + thread = threading.Thread(target=worker) + thread.start() + thread.join() + + kind, payload = result_queue.get_nowait() + if kind == "error": + raise payload + + del payload + gc.collect() + + # --------------------------------------------------------------------------- # Examples the native analyzer cannot prove but Z3 can. # From aa10773c14959cccc3e397c9e64413d49f8d3771 Mon Sep 17 00:00:00 2001 From: Ubospica Date: Sun, 14 Jun 2026 22:12:32 -0400 Subject: [PATCH 21/21] [TEST] Tolerate reflected access path formatting --- tests/python/relax/test_utils.py | 30 ++++++++---------------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 5d30f89c07a6..2013a05210be 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -16,8 +16,6 @@ # under the License. # ruff: noqa: F841 -import re - import pytest import tvm @@ -143,7 +141,7 @@ def func_2(A: R.Tensor([16, 16], "float32")): with pytest.raises( ValueError, - match=r'key="body".*key="blocks".*key="bindings".*key="value".*key="op"', + match=r"body.*blocks.*bindings.*value.*op", ): assert_structural_equal(func_1, func_2) @@ -251,25 +249,13 @@ def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"): return recursive_lambda(n) - # The path to the first mismatch, which should appear within the - # error message. 
- mismatch_path = [ - "", - "body", - "blocks[0]", - "bindings[0]", - "value", - "body", - "blocks[0]", - "bindings[0]", - "value", - "true_branch", - "body", - "value", - "value", - ] - - with pytest.raises(ValueError, match=re.escape(".".join(mismatch_path))): + mismatch_path = ( + r"body.*blocks.*bindings.*value" + r".*body.*blocks.*bindings.*value" + r".*true_branch.*body.*value.*value" + ) + + with pytest.raises(ValueError, match=mismatch_path): tvm.ir.assert_structural_equal(func_a, func_b)