diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index 59da4c0b82af..98d0029dd4e0 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit 59da4c0b82af0d499dae34bd89ef010f64d3ff45 +Subproject commit 98d0029dd4e002da1516d43f9b92e792f139e709 diff --git a/CMakeLists.txt b/CMakeLists.txt index 6e25f10e7f13..65e645c596d2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -90,6 +90,7 @@ tvm_option(COMPILER_RT_PATH "Path to COMPILER-RT" "3rdparty/compiler-rt") tvm_option(USE_BYODT_POSIT "Build with BYODT software emulated posit custom datatype" OFF) tvm_option(USE_BLAS "The blas library to be linked" none) tvm_option(USE_AMX "Enable Intel AMX" OFF) +tvm_option(USE_Z3 "Build with Z3 SMT solver support" AUTO) tvm_option(USE_MKL "MKL root path when use MKL blas" OFF) tvm_option(USE_DNNL "Enable DNNL codegen" OFF) tvm_option(USE_CUDNN "Build with cuDNN" OFF) @@ -439,9 +440,9 @@ include(cmake/utils/CCache.cmake) include(CheckCXXCompilerFlag) set(CMAKE_CXX_STANDARD_REQUIRED ON) -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CUDA_STANDARD_REQUIRED ON) -set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD 20) # Module rules include(cmake/modules/CUDA.cmake) @@ -460,6 +461,7 @@ include(cmake/modules/contrib/CUTLASS.cmake) include(cmake/modules/contrib/Random.cmake) include(cmake/modules/contrib/Posit.cmake) include(cmake/modules/contrib/Sort.cmake) +include(cmake/modules/contrib/Z3.cmake) include(cmake/modules/contrib/CoreML.cmake) include(cmake/modules/contrib/TensorRT.cmake) include(cmake/modules/contrib/NNAPI.cmake) @@ -546,6 +548,9 @@ add_library(tvm_objs OBJECT ${COMPILER_SRCS}) add_library(tvm_runtime_objs OBJECT ${RUNTIME_SRCS}) target_link_libraries(tvm_objs PUBLIC tvm_ffi_header) target_link_libraries(tvm_runtime_objs PUBLIC tvm_ffi_header) +if(TARGET tvm_llvm_header) + target_link_libraries(tvm_objs PUBLIC tvm_llvm_header) +endif() include(GNUInstallDirs) diff --git a/cmake/modules/LLVM.cmake b/cmake/modules/LLVM.cmake index 
f944b4130415..098d57033b4e 100644 --- a/cmake/modules/LLVM.cmake +++ b/cmake/modules/LLVM.cmake @@ -34,6 +34,19 @@ if(NOT ${USE_LLVM} MATCHES ${IS_FALSE_PATTERN}) endif() include_directories(SYSTEM ${LLVM_INCLUDE_DIRS}) add_definitions(${LLVM_DEFINITIONS}) + add_library(tvm_llvm_header INTERFACE) + if(MSVC) + # MSVC treats GCC-style -isystem operands as source files. + target_include_directories(tvm_llvm_header SYSTEM INTERFACE ${LLVM_INCLUDE_DIRS}) + target_compile_options(tvm_llvm_header INTERFACE ${LLVM_DEFINITIONS}) + else() + set(TVM_LLVM_INCLUDE_FLAGS "") + foreach(__llvm_include_dir IN LISTS LLVM_INCLUDE_DIRS) + string(STRIP "${__llvm_include_dir}" __llvm_include_dir) + list(APPEND TVM_LLVM_INCLUDE_FLAGS "-isystem" "${__llvm_include_dir}") + endforeach() + target_compile_options(tvm_llvm_header INTERFACE ${TVM_LLVM_INCLUDE_FLAGS} ${LLVM_DEFINITIONS}) + endif() message(STATUS "Build with LLVM " ${LLVM_PACKAGE_VERSION}) message(STATUS "Set TVM_LLVM_VERSION=" ${TVM_LLVM_VERSION}) # Set flags that are only needed for LLVM target diff --git a/cmake/modules/contrib/Z3.cmake b/cmake/modules/contrib/Z3.cmake new file mode 100644 index 000000000000..4af6c6bf4571 --- /dev/null +++ b/cmake/modules/contrib/Z3.cmake @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +# src/arith/z3_prover.cc is always part of COMPILER_SRCS (picked up by the +# src/arith/*.cc glob). It compiles a conservative stub by default and switches +# to the real Z3 implementation only when the TVM_USE_Z3 macro is defined below. +# Quote both operands so an empty or list-valued USE_Z3 cannot break the if() syntax. +if("${USE_Z3}" MATCHES "${IS_FALSE_PATTERN}") + return() +endif() + +set(TVM_Z3_REQUIRED TRUE) +if("${USE_Z3}" MATCHES "^[Aa][Uu][Tt][Oo]$") + set(TVM_Z3_REQUIRED FALSE) +endif() + +# Default lookup: the PIC static Z3 library shipped by the PyPI `z3-static` +# package (headers + libz3.a + Z3 CMake package files). Linking it statically +# keeps libtvm free of a runtime libz3 dependency. Users can override the +# lookup by setting Z3_DIR/CMAKE_PREFIX_PATH to any Z3 installation (e.g. a +# shared system Z3). +if(NOT Z3_DIR) + find_package(Python3 COMPONENTS Interpreter QUIET) + if(Python3_EXECUTABLE) + execute_process( + COMMAND + "${Python3_EXECUTABLE}" -c + "import os, z3_static as m; f = getattr(m, 'get_cmake_dir', None); print(f() if f else os.path.join(os.path.dirname(m.__file__), 'static', 'lib', 'cmake', 'z3'))" + OUTPUT_VARIABLE Z3_STATIC_CMAKE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + RESULT_VARIABLE Z3_STATIC_RESULT + ) + if(Z3_STATIC_RESULT EQUAL 0 AND EXISTS "${Z3_STATIC_CMAKE_DIR}") + set(Z3_DIR "${Z3_STATIC_CMAKE_DIR}") + endif() + endif() +endif() + +find_package(Z3 CONFIG QUIET) +if(NOT Z3_FOUND AND NOT TARGET z3::libz3 AND NOT TARGET Z3::libz3) + find_package(Z3 QUIET) +endif() + +if(TARGET z3::libz3 OR TARGET Z3::libz3) + if(TARGET z3::libz3) + set(Z3_TARGET z3::libz3) + else() + set(Z3_TARGET Z3::libz3) + endif() + get_target_property(Z3_TARGET_INCLUDE_DIRS ${Z3_TARGET} INTERFACE_INCLUDE_DIRECTORIES) + if(Z3_TARGET_INCLUDE_DIRS) + include_directories(SYSTEM ${Z3_TARGET_INCLUDE_DIRS}) + endif() + list(APPEND TVM_LINKER_LIBS ${Z3_TARGET}) +elseif(Z3_FOUND OR (Z3_INCLUDE_DIR AND Z3_LIBRARY)) + if(NOT 
Z3_INCLUDE_DIR AND Z3_CXX_INCLUDE_DIRS) + set(Z3_INCLUDE_DIR ${Z3_CXX_INCLUDE_DIRS}) + endif() + if(NOT Z3_LIBRARY AND Z3_LIBRARIES) + set(Z3_LIBRARY ${Z3_LIBRARIES}) + endif() + if(NOT Z3_INCLUDE_DIR OR NOT Z3_LIBRARY) + # Only an explicit USE_Z3 request is fatal; USE_Z3=AUTO falls back to the stub build, + # matching the not-found branch below. + if(TVM_Z3_REQUIRED) + message(FATAL_ERROR "USE_Z3 is ON, but Z3 include directory or library was not found.") + endif() + message(STATUS "Build without Z3 SMT solver support") + return() + endif() + include_directories(SYSTEM ${Z3_INCLUDE_DIR}) + list(APPEND TVM_LINKER_LIBS ${Z3_LIBRARY}) +else() + if(TVM_Z3_REQUIRED) + message(FATAL_ERROR + "USE_Z3 is ON, but Z3 was not found. Install the static Z3 development " + "package with `pip install z3-static`, or point Z3_DIR/CMAKE_PREFIX_PATH " + "at a Z3 installation.") + endif() + message(STATUS "Build without Z3 SMT solver support") + return() +endif() + +# Enable the real Z3 implementation inside the single src/arith/z3_prover.cc file. +add_compile_definitions(TVM_USE_Z3) +message(STATUS "Build with Z3 SMT solver support") diff --git a/cmake/utils/FindLLVM.cmake b/cmake/utils/FindLLVM.cmake index 2bf229eca756..1f54ded1d5e9 100644 --- a/cmake/utils/FindLLVM.cmake +++ b/cmake/utils/FindLLVM.cmake @@ -254,9 +254,9 @@ macro(find_llvm use_llvm) # compiler-appropriate form so the probe works under MSVC as well. 
if(NOT CMAKE_CXX_STANDARD) if(MSVC) - set(CMAKE_REQUIRED_FLAGS "/std:c++17") + set(CMAKE_REQUIRED_FLAGS "/std:c++20") else() - set(CMAKE_REQUIRED_FLAGS "-std=c++17") + set(CMAKE_REQUIRED_FLAGS "-std=c++20") endif() endif() check_cxx_source_compiles(" diff --git a/docs/install/from_source.rst b/docs/install/from_source.rst index a970bf5c1e9e..392fdc5cfc5e 100644 --- a/docs/install/from_source.rst +++ b/docs/install/from_source.rst @@ -35,11 +35,15 @@ Apache TVM requires the following dependencies: - CMake (>= 3.24.0) - LLVM (recommended >= 15) - Git -- A recent C++ compiler supporting C++ 17, at the minimum - - GCC 7.1 - - Clang 5.0 - - Apple Clang 9.3 - - Visual Studio 2019 (v16.7) +- A recent C++ compiler supporting C++ 20, at the minimum + - GCC 10 + - Clang 10 + - Apple Clang 14 + - Visual Studio 2022 + + Optional dependencies that use newer C++20 standard library facilities, such + as ``std::format``, may require a newer standard library (for example GCC 13 + or newer on Linux). - Python (>= 3.10) - (Optional) Conda (Strongly Recommended) diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h index 924cc299270a..e635315e6714 100644 --- a/include/tvm/arith/analyzer.h +++ b/include/tvm/arith/analyzer.h @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -588,6 +589,110 @@ class IntSetAnalyzer { Impl* impl_; }; +class Z3Prover { + public: + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param new_range The range of allowed values for this var. + * \param allow_override whether we allow override of existing information. + */ + TVM_DLL void Bind(const Var& var, const Range& new_range, bool allow_override = false); + + /*! + * \brief Update binding of var to a new expression. + * + * \param var The variable of interest. + * \param expr The bound expression. + * \param allow_override whether we allow override of existing information. 
+ */ + TVM_DLL void Bind(const Var& var, const PrimExpr& expr, bool allow_override = false); + + /*! + * \brief Whether the Z3 backend is compiled into this build (USE_Z3=ON). + * + * \return true if the real Z3 prover is available, false for the stub. + */ + TVM_DLL bool IsEnabled() const; + + /*! + * \brief Whether we can prove that expr is always true. + * + * \param expr The expression. + * \return Whether we can prove it. + */ + TVM_DLL bool CanProve(const PrimExpr& expr); + + /*! + * \brief Update the internal state to enter a constraint. + * + * \param constraint A constraint expression. + * \return An exit function that must be called to clean up the constraint; can be nullptr. + */ + std::function EnterConstraint(const PrimExpr& constraint); + + /*! + * \brief Get the SMTLIB2 representation of the current context. + * + * \param expr The optional expression to check. + * \return The SMTLIB2 string. + */ + ffi::String GetSMTLIB2(const ffi::Optional expr); + + /*! + * \brief Get statistics about the Z3 prover. + * + * \return The statistics string. + */ + ffi::String GetStats(); + + /*! + * \brief Set the timeout in milliseconds for the Z3 prover. + * + * \param timeout_ms The timeout in milliseconds. + */ + void SetTimeoutMs(unsigned timeout_ms); + + /*! + * \brief Set the resource limit for the Z3 prover. + * + * \param rlimit The resource limit. + */ + void SetRLimit(unsigned rlimit); + + /*! + * \brief Get the Z3 model for the given expression if satisfiable. + * + * \param expr The expression to get the model for. + * \return The model as a string. + */ + ffi::String GetModel(const PrimExpr& expr); + + /*! + * \brief Count the number of integer values that satisfy the current constraints. + * + * This method uses Z3's model enumeration to count how many distinct values of + * the given variable satisfy all current constraints. + * + * \param var The variable to count satisfying values for. + * \param max_count Maximum number of solutions to enumerate. 
+ * \param min_consecutive Minimum consecutive count requirement. + * \return The number of distinct values that satisfy the constraints, or a negative error code. + */ + TVM_DLL int64_t CountSatisfyingValues(const Var& var, int64_t max_count = 2048, + int64_t min_consecutive = 1); + + private: + friend class AnalyzerObj; + friend class Analyzer; + explicit Z3Prover(AnalyzerObj* parent); + TVM_DLL ~Z3Prover(); + void CopyFrom(const Z3Prover& other); + class Impl; + Impl* impl_; +}; + /*! * \brief Analyzer that contains bunch of sub-analyzers. * @@ -612,6 +717,8 @@ class TVM_DLL AnalyzerObj : public ffi::Object { IntSetAnalyzer int_set; /*! \brief sub-analyzer transitive comparisons */ TransitiveComparisonAnalyzer transitive_comparisons; + /*! \brief sub-analyzer using Z3 */ + Z3Prover z3_prover; /*! \brief constructor */ AnalyzerObj(); /*! @@ -810,7 +917,16 @@ class ConstraintContext { * \param constraint The constraint to be applied. */ ConstraintContext(const Analyzer& analyzer, PrimExpr constraint) - : analyzer_(analyzer), constraint_(constraint) {} + : ConstraintContext(analyzer, std::move(constraint), false) {} + /*! + * \brief Construct a constraint context. + * \param analyzer The analyzer whose context is updated. The context + * keeps a reference to the analyzer while the scope is active. + * \param constraint The constraint to be applied. + * \param is_assume Whether the constraint comes from an assumption. + */ + ConstraintContext(const Analyzer& analyzer, PrimExpr constraint, bool is_assume) + : analyzer_(analyzer), constraint_(std::move(constraint)), is_assume_(is_assume) {} /*! * \brief Construct a constraint context from a borrowed analyzer object. * \param analyzer The borrowed analyzer object. @@ -819,7 +935,15 @@ class ConstraintContext { * This overload is for internal callers that already operate on AnalyzerObj*. 
*/ ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint) - : ConstraintContext(ffi::GetRef(analyzer), std::move(constraint)) {} + : ConstraintContext(ffi::GetRef(analyzer), std::move(constraint), false) {} + /*! + * \brief Construct a constraint context from a borrowed analyzer object. + * \param analyzer The borrowed analyzer object. + * \param constraint The constraint to be applied. + * \param is_assume Whether the constraint comes from an assumption. + */ + ConstraintContext(AnalyzerObj* analyzer, PrimExpr constraint, bool is_assume) + : ConstraintContext(ffi::GetRef(analyzer), std::move(constraint), is_assume) {} // enter the scope. void EnterWithScope(); // exit the scope. @@ -830,6 +954,8 @@ class ConstraintContext { PrimExpr constraint_; /*! \brief functions to be called in recovery */ std::vector> recovery_functions_; + /*! \brief Whether the constraint comes from an assumption. */ + bool is_assume_; }; } // namespace arith diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h index 60b292bbb265..231b129b94be 100644 --- a/include/tvm/tirx/op.h +++ b/include/tvm/tirx/op.h @@ -825,7 +825,8 @@ inline bool IsPointerType(const Type& type, const DataType& element_type) { * \param span The location of this operation in the source. */ template ::value>::type> + typename = typename std::enable_if::value && + std::is_trivial::value>::type> inline PrimExpr make_const(DataType t, ValueType value, Span span = Span()); /*! * \brief Make a const zero expr. diff --git a/pyproject.toml b/pyproject.toml index 3c61fa389fc3..e0f600fdad95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,8 @@ # under the License. [build-system] -requires = ["scikit-build-core>=0.11", "setuptools-scm>=8"] +# z3-static ships the PIC static libz3 + headers consumed by USE_Z3=ON. 
+requires = ["scikit-build-core>=0.11", "setuptools-scm>=8", "z3-static"] build-backend = "scikit_build_core.build" [project] @@ -141,6 +142,8 @@ logging.level = "INFO" [tool.scikit-build.cmake.define] TVM_BUILD_PYTHON_MODULE = "ON" USE_CUDA = "OFF" +# Statically link Z3 from the z3-static build dependency by default. +USE_Z3 = "ON" BUILD_TESTING = "OFF" [tool.setuptools_scm] diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py index 0aa6a75eba4a..78e93395c382 100644 --- a/python/tvm/arith/analyzer.py +++ b/python/tvm/arith/analyzer.py @@ -128,6 +128,91 @@ class Analyzer(Object): def __init__(self): self.__init_handle_by_constructor__(_ffi_api.Analyzer) + @property + def is_z3_enabled(self) -> bool: + """Whether this build includes the Z3 backend (``USE_Z3=ON``). + + The Z3-specific methods (:py:meth:`get_smtlib2`, :py:meth:`get_z3_stats`, + :py:meth:`set_z3_timeout_ms`, :py:meth:`set_z3_rlimit`) only work when + this is ``True``. + """ + return bool(_ffi_api.AnalyzerIsZ3Enabled(self)) + + def _check_z3_enabled(self) -> None: + if not self.is_z3_enabled: + raise RuntimeError( + "The Z3 backend is not available in this build. " + "Rebuild TVM with USE_Z3=ON to use Z3-specific Analyzer APIs." + ) + + def get_smtlib2(self, expr: tirx.PrimExpr | None = None) -> str: + """Get the current Z3 problem in SMT-LIB2 format. + + Raises + ------ + RuntimeError + If TVM was built without Z3 (``USE_Z3=OFF``), since there is no + solver state to export. Use :py:attr:`is_z3_enabled` to check first. + + Parameters + ---------- + expr : Optional[PrimExpr] + The expression to prove. If provided, its negation is added to the problem. + """ + self._check_z3_enabled() + return _ffi_api.AnalyzerGetSMTLIB2(self, expr) + + def set_z3_timeout_ms(self, timeout_ms: int) -> None: + """Set Z3 timeout in milliseconds. + + Raises + ------ + RuntimeError + If TVM was built without Z3 (``USE_Z3=OFF``). + + Parameters + ---------- + timeout_ms : int + The timeout in milliseconds. 
+ """ + self._check_z3_enabled() + _ffi_api.AnalyzerSetZ3TimeoutMs(self, timeout_ms) + + def set_z3_rlimit(self, rlimit: int) -> None: + """Set Z3 resource limit. + + The resource limit gives deterministic solver budgeting (unlike a wall + clock timeout). A value of ``0`` disables the limit. + + Raises + ------ + RuntimeError + If TVM was built without Z3 (``USE_Z3=OFF``). + + Parameters + ---------- + rlimit : int + The resource limit. + """ + self._check_z3_enabled() + _ffi_api.AnalyzerSetZ3RLimit(self, rlimit) + + def get_z3_stats(self) -> str: + """Get Z3 solver statistics. + + Raises + ------ + RuntimeError + If TVM was built without Z3 (``USE_Z3=OFF``). + + Returns + ------- + stats : str + The Z3 statistics. + """ + self._check_z3_enabled() + return _ffi_api.AnalyzerGetZ3Stats(self) + def const_int_bound(self, expr: tirx.PrimExpr) -> ConstIntBound: """Find constant integer bound for expr. @@ -260,7 +345,9 @@ def can_prove( The expression. strength: ProofStrength - The proof strength + The proof strength. When TVM is built with Z3 (``USE_Z3=ON``), the + optional Z3 fallback is only consulted at ``SYMBOLIC_BOUND`` or + higher, after the native analyzers fail to prove the predicate. 
Returns ------- diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc index cc3c73bb6207..8d2807dfcbb9 100644 --- a/src/arith/analyzer.cc +++ b/src/arith/analyzer.cc @@ -39,7 +39,8 @@ AnalyzerObj::AnalyzerObj() modular_set(this), rewrite_simplify(this), canonical_simplify(this), - int_set(this) {} + int_set(this), + z3_prover(this) {} void AnalyzerObj::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { PrimExpr new_expr = expr; @@ -52,6 +53,7 @@ void AnalyzerObj::Bind(const Var& var, const PrimExpr& expr, bool allow_override this->canonical_simplify.Update(var, new_expr, allow_override); this->int_set.Update(var, this->int_set(new_expr), allow_override); this->transitive_comparisons.Bind(var, expr, allow_override); + this->z3_prover.Bind(var, expr, allow_override); } void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override) { @@ -62,6 +64,7 @@ void AnalyzerObj::Bind(const Var& var, const Range& range, bool allow_override) this->const_int_bound.Bind(var, range, allow_override); this->int_set.Bind(var, range, allow_override); this->transitive_comparisons.Bind(var, range, allow_override); + this->z3_prover.Bind(var, range, allow_override); } // skip modular_set // skip rewrite simplify @@ -131,6 +134,7 @@ void ConstraintContext::EnterWithScope() { recovery_functions_.push_back(analyzer_->rewrite_simplify.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->int_set.EnterConstraint(constraint_)); recovery_functions_.push_back(analyzer_->transitive_comparisons.EnterConstraint(constraint_)); + recovery_functions_.push_back(analyzer_->z3_prover.EnterConstraint(constraint_)); } void ConstraintContext::ExitWithScope() { @@ -231,6 +235,12 @@ bool AnalyzerObj::CanProve(const PrimExpr& expr, ProofStrength strength) { } } + // Z3 is an expensive best-effort fallback. 
Gate it behind the higher + // kSymbolicBound strength so the common kDefault path (including deeply + // recursive internal CanProve calls) never pays the prover cost. + if (strength >= ProofStrength::kSymbolicBound && z3_prover.CanProve(simplified)) { + return true; + } return false; } @@ -334,6 +344,20 @@ TVM_FFI_STATIC_INIT_BLOCK() { return static_cast( analyzer->transitive_comparisons.TryCompare(lhs, rhs, propagate_inequalities)); }) + .def("arith.AnalyzerIsZ3Enabled", + [](Analyzer analyzer) { return analyzer->z3_prover.IsEnabled(); }) + .def("arith.AnalyzerGetSMTLIB2", + [](Analyzer analyzer, ffi::Optional expr) { + return analyzer->z3_prover.GetSMTLIB2(expr); + }) + .def("arith.AnalyzerSetZ3TimeoutMs", [](Analyzer analyzer, int64_t timeout_ms) { + analyzer->z3_prover.SetTimeoutMs(static_cast(timeout_ms)); + }) + .def("arith.AnalyzerSetZ3RLimit", [](Analyzer analyzer, int64_t rlimit) { + analyzer->z3_prover.SetRLimit(static_cast(rlimit)); + }) + .def("arith.AnalyzerGetZ3Stats", + [](Analyzer analyzer) { return analyzer->z3_prover.GetStats(); }) .def("arith.AnalyzerGetEnabledExtensions", [](Analyzer analyzer) { return static_cast(analyzer->rewrite_simplify.GetEnabledExtensions()); diff --git a/src/arith/z3_prover.cc b/src/arith/z3_prover.cc new file mode 100644 index 000000000000..aab4b485dd70 --- /dev/null +++ b/src/arith/z3_prover.cc @@ -0,0 +1,864 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/arith/z3_prover.cc + * \brief Optional Z3 SMT solver backend for arith::Analyzer. + * + * The real implementation is compiled only when TVM_USE_Z3 is defined (set by + * the USE_Z3 CMake option). Otherwise a conservative stub is compiled so the + * C++ and Python APIs stay available without a Z3 dependency. + */ +#ifdef TVM_USE_Z3 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tvm/ffi/cast.h" +#include "tvm/ffi/object.h" +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" +#include "tvm/runtime/data_type.h" +#include "z3++.h" + +namespace tvm::arith { + +using namespace tirx; +using namespace ffi; + +namespace { + +struct Namespace { + std::unordered_set used_names; + /// @brief Get a new name that is not used before + /// This function is used to generate z3 variable names + /// + /// Z3 may deduplicate variables with the same name, which + /// causes issues when different TVM variables are mapped to + /// the same z3 variable. + /// + /// This function generates unique names by appending + /// suffixes to the original expression string representation. + /// + /// such as : "x", "x$1", "x$2", ... 
+ std::string GetNewName(const PrimExpr& expr) { + std::stringstream ss; + ss << expr; + auto name = ss.str(); + if (used_names.count(name) == 0) { + used_names.insert(name); + return name; + } + int idx = 1; + std::string check_name = name + "$" + std::to_string(idx); + while (used_names.count(check_name)) { + idx++; + check_name = name + "$" + std::to_string(idx); + } + used_names.insert(check_name); + return check_name; + } +}; + +} // namespace + +class Z3Prover::Impl : ExprFunctor { + public: + using Base = ExprFunctor; + using Self = Z3Prover::Impl; + + AnalyzerObj* analyzer; + // Keep a reference to the thread-local context for the whole lifetime of this + // prover. Schedules created on worker threads may be destroyed after the + // worker exits, so storing only a raw reference in z3::solver is not enough. + static std::shared_ptr GetThreadLocalContext() { + static thread_local std::shared_ptr local_ctx = std::make_shared(); + return local_ctx; + } + std::shared_ptr ctx{GetThreadLocalContext()}; + + /// @brief Z3 solver instance + z3::solver solver{*ctx}; + + /// @brief Memorize pure expressions + std::unordered_map memo_; + + /// @brief Namespace for variable naming + Namespace ns; + + /// @brief Timeout in milliseconds + unsigned timeout_ms{UINT_MAX}; + + /// @brief Max steps + unsigned rlimit{UINT_MAX}; + + /// @brief Create a z3 solver with custom options + static z3::solver CreateSolver(z3::context& ctx) { + z3::solver solver(ctx); + // here we disable model generation to speed up the solving process + solver.set("model", false); + // ensure determinstic behavior + solver.set("random_seed", (unsigned)42); + return solver; + } + + Impl(AnalyzerObj* parent) : analyzer(parent) { + scope_stack_.push_back({}); + solver = CreateSolver(*ctx); + // use rlimit, not timeout to ensure deterministic behavior + SetRLimit(10000U); + } + + /// @brief Create a Free z3 expression from PrimExprNode + z3::expr Create(const PrimExprNode* op) { + auto ref = 
ffi::GetRef(op); + auto dtype = op->dtype; + std::string name = ns.GetNewName(ref); + /// TVM max_val can't handle uint64 max correctly, so we special case it here + if (dtype.is_bool()) { + return ctx->bool_const(name.c_str()); + } else { + z3::expr e = ctx->int_const(name.c_str()); + if (dtype.is_uint() && dtype.bits() == 64) { + solver.add(ctx->int_val(0) <= e && e <= ctx->int_val((uint64_t)UINT64_MAX)); + } else { + auto min_val = Downcast(min_value(dtype))->value; + auto max_val = Downcast(max_value(dtype))->value; + solver.add(ctx->int_val(min_val) <= e && e <= ctx->int_val(max_val)); + } + return e; + } + } + + struct Scope { + enum Kind { + BindValue, + BindRange, + Constraint, + } kind; + Var var; + PrimExpr value; + PrimExpr min; + PrimExpr extent; + PrimExpr constraint; + }; + + /// @brief scope_stack memorizes existing constraint and bindings + /// to generate SMTLIB2 representation with comments + std::vector> scope_stack_; + + /// @brief Enter a constraint scope + std::function EnterConstraint(const PrimExpr& constraint) { + scope_stack_.push_back({}); + scope_stack_.back().push_back( + Scope{Scope::Constraint, Var(), PrimExpr(), PrimExpr(), PrimExpr(), constraint}); + solver.push(); + solver.add(VisitBool(constraint)); + auto side_effect_exprs = std::move(side_effect_exprs_); + side_effect_exprs_.clear(); + for (const auto& expr : side_effect_exprs) { + memo_.erase(expr); + } + return [this]() { + solver.pop(); + scope_stack_.pop_back(); + }; + } + + /// @brief Check trivil bad cases, return true if the expr is a bad case + /// Z3 prover may take a long time to initialize (at least 200us), + /// This optimization can speedup 30% of the test cases in our unit tests + bool CheckTrivilBadCases(const PrimExpr& expr) { + if (IsFreeNode(expr)) { + return true; + } + auto checkTrivilCmp = [this](const PrimExpr& lhs, const PrimExpr& rhs) { + if (IsFreeNode(lhs) && rhs->IsInstance()) { + return true; + } + if (IsFreeNode(rhs) && lhs->IsInstance()) { + return 
true; + } + if (IsFreeNode(lhs) && IsFreeNode(rhs)) { + return true; + } + // cast('xxx', free_var) == constant + if (auto cast = lhs.as()) { + if (IsFreeNode(cast->value) && rhs->IsInstance()) { + return true; + } + } + // constant == cast('xxx', free_var) + if (auto cast = rhs.as()) { + if (IsFreeNode(cast->value) && lhs->IsInstance()) { + return true; + } + } + return false; + }; + if (auto eq = expr.as()) { + auto lhs = eq->a; + auto rhs = eq->b; + return checkTrivilCmp(lhs, rhs); + } else if (auto ne = expr.as()) { + auto lhs = ne->a; + auto rhs = ne->b; + return checkTrivilCmp(lhs, rhs); + } + return false; + } + + /// @brief Check if the expression can be proved + bool CanProve(const PrimExpr& expr) { + // Z3 is only a fallback. Any failure (including z3::exception thrown by the + // solver) must degrade to "cannot prove" instead of escaping to the caller. + try { + if (CheckTrivilBadCases(expr)) return false; + if (!IsValidDType(expr->dtype)) return false; + z3::expr_vector constr(*ctx); + constr.push_back(!ConvertBool(expr)); + auto result = solver.check(constr); + constr.pop_back(); + return result == z3::unsat; + } catch (const z3::exception&) { + return false; + } + } + + /// @brief Bind a variable to a value expression + /// @param var The variable to bind; nothing is recorded when IsValidDType rejects its dtype + /// @param value The expression the variable is bound to + /// @param allow_override Whether an existing memo entry for var may be replaced + void Bind(const Var& var, const PrimExpr& value, bool allow_override = false) { + if (!IsValidDType(var->dtype)) return; + scope_stack_.back().push_back(Scope{Scope::BindValue, var, value}); + // we add the binding whenever the value is pure, + // because non-pure parts are handled by creating free variables in VisitExpr + auto bound = ConvertInt(value); + if (allow_override) { + // emplace() never replaces an existing entry, so an overriding re-bind must assign + // explicitly; otherwise later queries keep using the stale expression. + // NOTE(review): memoized compound expressions that mention var are not refreshed here. + memo_.insert_or_assign(var, bound); + } else { + memo_.emplace(var, bound); + } + } + + /// @brief Bind a variable to a range + void Bind(const Var& var, const Range& range, bool allow_override = false) { + if (!IsValidDType(var->dtype)) return; + scope_stack_.back().push_back( + Scope{Scope::BindRange, var, PrimExpr(), range->min, range->extent}); + // 1. 
Create a placeholder for the var, and save it in the memo + // if the var is overridden later, we update the memo, and the old placeholder will + // be ignored + auto var_expr = Create(var.as()); + if (allow_override) { + // emplace() never replaces an existing entry; assign explicitly so the new + // placeholder (and the range constraints added below) actually take effect. + memo_.insert_or_assign(var, var_expr); + } else { + memo_.emplace(var, var_expr); + } + + // 2. Add constraint on the placeholder + // when min_expr >= max_expr, the range is empty, which is under undefined behavior + // instead of adding an unsat constraint, we just skip the range constraint to leave it a + // free var + // + // NOTE: range->min + range->extent builds a fresh AddNode that is not folded, so we must + // test is_const_int on range->min and range->extent individually and add the two constants + // in C++. Otherwise this fast path is never taken and we always emit the more expensive + // symbolic constraint below. + if (tirx::is_const_int(range->min) && tirx::is_const_int(range->extent)) { + int64_t min_value = *tirx::as_const_int(range->min); + int64_t extent_value = *tirx::as_const_int(range->extent); + int64_t max_value = min_value + extent_value; + if (min_value < max_value) { + solver.add(ctx->int_val(min_value) <= var_expr); + solver.add(var_expr < ctx->int_val(max_value)); + } + } else { + solver.add(ConvertBool(range->extent <= 0 || + (range->min <= var && var < range->min + range->extent))); + } + } + + void CopyFrom(const Self& other_) { + // 1. create a new solver + // because this->solver depends on this->ctx + // we need to destroy the old solver, and create a new one depending on this->ctx + solver = CreateSolver(*ctx); + // 2. ctx is owned by this Impl and pins the underlying thread-local context for the lifetime + // of solver and memoized expressions. + // 3. copy other objects + ns = other_.ns; + for (auto& item : other_.memo_) { + memo_.emplace(item.first, item.second); + } + for (auto a : other_.solver.assertions()) { + solver.add(a); + } + // 4. copy timeout options + // but other solver options are not copied + SetTimeoutMs(other_.timeout_ms); + SetRLimit(other_.rlimit); + // 5. 
copy the scope stack, which containing comments for SMTLIB2 generation + scope_stack_ = other_.scope_stack_; + } + + /// @brief Set timeout in milliseconds + void SetTimeoutMs(unsigned timeout_ms) { + this->timeout_ms = timeout_ms; + solver.set("timeout", timeout_ms); + } + + /// @brief Set max steps + void SetRLimit(unsigned rlimit) { + this->rlimit = rlimit; + solver.set("rlimit", rlimit); + } + + /// @brief Get the SMTLIB2 representation of the current solver state + ffi::String GetSMTLIB2() { + std::stringstream ss; + ss << "(set-option :timeout " << timeout_ms << ")\n"; + AddScopeDebugMsg(ss); + ss << solver.to_smt2(); + return ss.str(); + } + + void AddScopeDebugMsg(std::ostream& ss) { + for (const auto& scope : scope_stack_) { + ss << "; Entering Scope\n"; + for (const auto& s : scope) { + switch (s.kind) { + case Scope::Constraint: + ss << "; constraint: " << s.constraint << "\n"; + break; + case Scope::BindValue: + ss << "; bind value: " << s.var << " = " << s.value << "\n"; + break; + case Scope::BindRange: + ss << "; bind range: " << s.var << " in [" << s.min << ", " << s.min + s.extent + << ")\n"; + break; + } + } + } + } + + /// @brief Get the SMTLIB2 representation of the current solver state with additional expr trying + /// to prove + ffi::String GetSMTLIB2(const PrimExpr& expr) { + std::stringstream ss; + ss << "(set-option :timeout " << timeout_ms << ")\n"; + AddScopeDebugMsg(ss); + ss << "; Trying to prove: " << expr << "\n"; + solver.push(); + solver.add(!ConvertBool(expr)); + ss << solver.to_smt2(); + solver.pop(); + return ss.str(); + } + + /// @brief Get the statistics of the solver + ffi::String GetStats() { + std::stringstream ss; + ss << solver.statistics(); + return ss.str(); + } + + ffi::String GetModel(const PrimExpr& expr) { + solver.set("model", true); + solver.push(); + solver.add(!ConvertBool(expr)); + auto result = solver.check(); + ffi::String model_str; + if (result == z3::sat) { + z3::model m = solver.get_model(); + std::map 
model_map; + for (unsigned i = 0; i < m.size(); i++) { + z3::func_decl d = m[i]; + model_map.emplace(d.name().str(), m.get_const_interp(d)); + } + std::stringstream ss; + for (const auto& [k, v] : model_map) { + ss << " " << k << " = " << v << "\n"; + } + model_str = ss.str(); + } + solver.pop(); + solver.set("model", false); + return model_str; + } + + /*! + * \brief Count the number of distinct integer values satisfying current constraints. + * + * Uses Z3's model enumeration (AllSAT pattern) to count solutions: + * 1. Find a satisfying assignment + * 2. Add a blocking clause to exclude it + * 3. Repeat until UNSAT + * + * \param var The variable to count values for + * \param max_count Safety limit on enumeration + * \param min_consecutive Minimum consecutive count requirement (0 to disable) + * \return Number of satisfying values, -1 on error, -2 if min_consecutive constraint not met + */ + int64_t CountSatisfyingValues(const Var& var, int64_t max_count, int64_t min_consecutive = 1) { + if (!IsValidDType(var->dtype)) { + return -1; + } + + solver.set("model", true); + solver.push(); + + // Convert the TVM variable to Z3 expression + z3::expr z3_var = VisitInt(var); + + int64_t count = 0; + std::vector found_values; + + while (count < max_count) { + auto result = solver.check(); + if (result != z3::sat) { + break; // No more solutions + } + + z3::model m = solver.get_model(); + z3::expr val_expr = m.eval(z3_var, true); + + // Extract the integer value from Z3 expression + int64_t val; + if (val_expr.is_numeral()) { + val = val_expr.get_numeral_int64(); + } else { + // If we can't get a concrete value, stop enumeration + break; + } + + found_values.push_back(val); + count++; + + // Add blocking clause: var != val (exclude this solution) + solver.add(z3_var != ctx->int_val(val)); + } + + solver.pop(); + solver.set("model", false); + + // Clear any side effects from visiting the variable + for (const auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + 
side_effect_exprs_.clear(); + + // Check minimum consecutive constraint if enabled + if (min_consecutive > 0 && count > 0) { + // Sort the values to check consecutive groups + std::sort(found_values.begin(), found_values.end()); + + // Check that all values form groups of at least min_consecutive consecutive numbers + int64_t consecutive_count = 1; + for (size_t i = 1; i < found_values.size(); i++) { + if (found_values[i] == found_values[i - 1] + 1) { + // Consecutive value + consecutive_count++; + } else { + // Gap found, check if the previous group meets the minimum + if (consecutive_count < min_consecutive) { + return -2; // Previous group too small + } + consecutive_count = 1; // Start new group + } + } + // Check the last group + if (consecutive_count < min_consecutive) { + return -2; // Last group too small + } + } + + return count; + } + + private: + using Z3BinOp = z3::expr (*)(const z3::expr&, const z3::expr&); + + std::vector side_effect_exprs_; + + z3::expr ConvertBool(const PrimExpr& e) { + auto res = VisitBool(e); + for (auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + side_effect_exprs_.clear(); + return res; + } + + z3::expr ConvertInt(const PrimExpr& e) { + auto res = VisitInt(e); + for (auto& expr : side_effect_exprs_) { + memo_.erase(expr); + } + side_effect_exprs_.clear(); + return res; + } + + /// @brief Visit expression with memoization + z3::expr VisitExpr(const PrimExpr& e) override { + if (memo_.count(e)) { + return memo_.at(e); + } + auto res = Base::VisitExpr(e); + auto side_effect = SideEffect(e); + if (side_effect <= CallEffectKind::kPure) { + memo_.emplace(e, res); + } else if (side_effect <= CallEffectKind::kReadState) { + memo_.emplace(e, res); + side_effect_exprs_.emplace_back(e); + } else { + side_effect_exprs_.emplace_back(e); + } + return res; + } + + /// @brief Check if the expression is a free node having no constraints + bool IsFreeNode(const PrimExpr& e) { + if (memo_.count(e)) { + return false; + } + return 
e->IsInstance() || e->IsInstance() || + e->IsInstance() || e->IsInstance() || + (e->IsInstance() && !IsValidDType(Downcast(e)->value->dtype)); + } + + /// @brief Check if the dtype is valid for z3 integer operations + static bool IsValidDType(const DataType& dtype) { + return (dtype.is_int() || dtype.is_uint() || dtype.is_bool()) && dtype.lanes() == 1; + } + + /// @brief Visit the expression and convert it into z3 integer expression + z3::expr VisitInt(const PrimExpr& expr) { + auto e = VisitExpr(expr); + if (e.is_bool()) { + return z3::ite(e, ctx->int_val(1), ctx->int_val(0)); + } else { + return e; + } + } + + /// @brief Visit the expression and convert it into z3 boolean expression + z3::expr VisitBool(const PrimExpr& e) { + auto expr = VisitExpr(e); + if (expr.is_bool()) { + return expr; + } else { + return expr != ctx->int_val(0); + } + } + + /// @brief Helper function to visit binary arithmetic operations + z3::expr VisitArith(Z3BinOp signed_op, const PrimExprNode* op, const PrimExpr& a, + const PrimExpr& b) { + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + return signed_op(VisitInt(a), VisitInt(b)); + } else { + return Create(op); + } + } + + z3::expr VisitExpr_(const LetNode* op) override { + if (IsValidDType(op->var->dtype)) { + memo_.emplace(op->var, VisitInt(op->value)); + } + return VisitExpr(op->body); + } + z3::expr VisitExpr_(const CastNode* op) override { + // if the inner dtype is valid, we just visit it + if (IsValidDType(op->value->dtype) && IsValidDType(op->dtype)) { + return VisitInt(op->value); + } else { + // otherwise, we create a new free z3 variable + return Create(op); + } + } + z3::expr VisitExpr_(const VarNode* op) override { return Create(op); } + z3::expr VisitExpr_(const BufferLoadNode* op) override { return Create(op); } + z3::expr VisitExpr_(const ProducerLoadNode* op) override { return Create(op); } + z3::expr VisitExpr_(const ReduceNode* op) override { return Create(op); } + z3::expr VisitExpr_(const MinNode* op) 
override { + auto a = VisitInt(op->a); + auto b = VisitInt(op->b); + return z3::ite(a < b, a, b); + } + z3::expr VisitExpr_(const MaxNode* op) override { + auto a = VisitInt(op->a); + auto b = VisitInt(op->b); + return z3::ite(a > b, a, b); + } + // TVM Div/Mod are truncated (round toward zero), while Z3's native operator/ + // and operator% are Euclidean. Using the raw operators is unsound once the + // dividend can be negative, so we implement truncating helpers explicitly. + static z3::expr truncdiv(const z3::expr& a, const z3::expr& b) { + z3::expr abs_a = z3::ite(a >= 0, a, -a); + z3::expr abs_b = z3::ite(b >= 0, b, -b); + // |a| / |b| is exact (Euclidean == truncated for non-negative operands). + z3::expr q = abs_a / abs_b; + return z3::ite((a >= 0) == (b >= 0), q, -q); + } + static z3::expr truncmod(const z3::expr& a, const z3::expr& b) { + // TVM Mod follows the sign of the dividend: a - b * truncdiv(a, b). + return a - b * truncdiv(a, b); + } + static z3::expr floordiv(const z3::expr& a, const z3::expr& b) { + return z3::ite(b > 0, a / b, -((-a) / b)); + } + static z3::expr floormod(const z3::expr& a, const z3::expr& b) { + return z3::ite(b > 0, a % b, -((-a) % b)); + } + z3::expr VisitExpr_(const AddNode* op) override { + return VisitArith(z3::operator+, op, op->a, op->b); + } + z3::expr VisitExpr_(const SubNode* op) override { + return VisitArith(z3::operator-, op, op->a, op->b); + } + z3::expr VisitExpr_(const MulNode* op) override { + return VisitArith(z3::operator*, op, op->a, op->b); + } + z3::expr VisitExpr_(const DivNode* op) override { return VisitArith(truncdiv, op, op->a, op->b); } + z3::expr VisitExpr_(const ModNode* op) override { return VisitArith(truncmod, op, op->a, op->b); } + z3::expr VisitExpr_(const FloorDivNode* op) override { + return VisitArith(floordiv, op, op->a, op->b); + } + z3::expr VisitExpr_(const FloorModNode* op) override { + return VisitArith(floormod, op, op->a, op->b); + } + z3::expr VisitExpr_(const EQNode* op) override 
{ + return VisitArith(z3::operator==, op, op->a, op->b); + } + z3::expr VisitExpr_(const NENode* op) override { + return VisitArith(z3::operator!=, op, op->a, op->b); + } + z3::expr VisitExpr_(const LTNode* op) override { + return VisitArith(z3::operator<, op, op->a, op->b); + } + z3::expr VisitExpr_(const LENode* op) override { + return VisitArith(z3::operator<=, op, op->a, op->b); + } + z3::expr VisitExpr_(const GTNode* op) override { + return VisitArith(z3::operator>, op, op->a, op->b); + } + z3::expr VisitExpr_(const GENode* op) override { + return VisitArith(z3::operator>=, op, op->a, op->b); + } + z3::expr VisitExpr_(const AndNode* op) override { return VisitBool(op->a) && VisitBool(op->b); } + z3::expr VisitExpr_(const OrNode* op) override { return VisitBool(op->a) || VisitBool(op->b); } + z3::expr VisitExpr_(const NotNode* op) override { return !VisitBool(op->a); } + z3::expr VisitExpr_(const SelectNode* op) override { + return z3::ite(VisitBool(op->condition), VisitInt(op->true_value), VisitInt(op->false_value)); + } + z3::expr VisitExpr_(const IntImmNode* op) override { return ctx->int_val(op->value); } + + // Bitwise operations + z3::expr VisitExpr_(const CallNode* op) override { + // Check if this is a bitwise operation + if (op->op.same_as(tirx::builtin::bitwise_and())) { + return VisitBitwiseOp(z3::operator&, op); + } else if (op->op.same_as(tirx::builtin::bitwise_or())) { + return VisitBitwiseOp(z3::operator|, op); + } else if (op->op.same_as(tirx::builtin::bitwise_xor())) { + return VisitBitwiseOp(z3::operator^, op); + } else if (op->op.same_as(tirx::builtin::bitwise_not())) { + return VisitBitwiseNotOp(op); + } else if (op->op.same_as(tirx::builtin::shift_left())) { + return VisitShiftOp(z3::shl, op); + } else if (op->op.same_as(tirx::builtin::shift_right())) { + return VisitShiftOp(z3::ashr, op); + } else if (op->op.same_as(tirx::builtin::if_then_else()) && op->args.size() == 3 && + IsValidDType(op->args[1]->dtype) && 
IsValidDType(op->args[2]->dtype)) { + // tir.if_then_else(cond, a, b) is a select-like ternary. + return z3::ite(VisitBool(op->args[0]), VisitInt(op->args[1]), VisitInt(op->args[2])); + } else { + // For other call nodes, create a free variable + return Create(op); + } + } + + /// @brief Helper function to visit binary bitwise operations + z3::expr VisitBitwiseOp(z3::expr (*op_func)(const z3::expr&, const z3::expr&), + const CallNode* op) { + if (op->args.size() != 2) { + LOG(FATAL) << "Binary bitwise operation expects 2 arguments, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr& a = op->args[0]; + const PrimExpr& b = op->args[1]; + unsigned bit_width = std::max(op->args[0].dtype().bits(), op->args[1].dtype().bits()); + + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + return z3::bv2int( + op_func(z3::int2bv(bit_width, VisitInt(a)), z3::int2bv(bit_width, VisitInt(b))), true); + } else { + return Create(op); + } + } + + /// @brief Helper function to visit unary bitwise not operation + z3::expr VisitBitwiseNotOp(const CallNode* op) { + if (op->args.size() != 1) { + LOG(FATAL) << "Bitwise not operation expects 1 argument, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr& a = op->args[0]; + + if (IsValidDType(a->dtype)) { + // Cast integer to bit-vector, apply bitwise not, then cast back. 
+ unsigned bit_width = a.dtype().bits(); + z3::expr a_int = VisitInt(a); + z3::expr a_bv = z3::int2bv(bit_width, a_int); + return z3::bv2int(~a_bv, true); + } else { + return Create(op); + } + } + + /// @brief Helper function to visit shift operations + z3::expr VisitShiftOp(z3::expr (*op_func)(const z3::expr&, const z3::expr&), const CallNode* op) { + if (op->args.size() != 2) { + LOG(FATAL) << "Shift operation expects 2 arguments, got " << op->args.size(); + TVM_FFI_UNREACHABLE(); + } + + const PrimExpr& a = op->args[0]; + const PrimExpr& b = op->args[1]; + + // Shift operations require integer types for both operands + if (IsValidDType(a->dtype) && IsValidDType(b->dtype)) { + z3::expr a_expr = VisitInt(a); + z3::expr b_expr = VisitInt(b); + + // Rely on Z3's native bit-vector shift behavior. We must NOT add hard + // assertions such as `b_expr >= 0` to the solver here: solver.add() has no + // matching push/pop in this path, so the assertion would permanently + // poison the shared solver and make all subsequent unrelated proofs about + // `b` unsound. + unsigned bit_width = std::max(a.dtype().bits(), b.dtype().bits()); + z3::expr a_bv = z3::int2bv(bit_width, a_expr); + z3::expr b_bv = z3::int2bv(bit_width, b_expr); + + // Perform the shift in bit-vector domain, then cast back to int. + z3::expr result_bv = op_func(a_bv, b_bv); + return z3::bv2int(result_bv, true); + } else { + return Create(op); + } + } + + z3::expr VisitExprDefault_(const Object* op) override { + // Z3 is a best-effort fallback that runs only after the native analyzers + // have already failed. An unsupported node must not crash the build, so we + // model it as a fresh unconstrained free variable, which keeps the proof + // sound (it can only make CanProve more conservative). 
+ return Create(static_cast(op)); + } +}; + +TVM_DLL bool Z3Prover::IsEnabled() const { return true; } +TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return impl_->CanProve(expr); } +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) { + return impl_->Bind(var, new_range, allow_override); +} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) { + return impl_->Bind(var, expr, allow_override); +} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint) { + return impl_->EnterConstraint(constraint); +} +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + if (expr.has_value()) { + return impl_->GetSMTLIB2(expr.value()); + } else { + return impl_->GetSMTLIB2(); + } +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) { impl_->SetTimeoutMs(timeout_ms); } +void Z3Prover::SetRLimit(unsigned max_step) { impl_->SetRLimit(max_step); } +void Z3Prover::CopyFrom(const Z3Prover& other) { impl_->CopyFrom(*other.impl_); } +ffi::String Z3Prover::GetStats() { return impl_->GetStats(); } +ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return impl_->GetModel(expr); } +TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, + int64_t min_consecutive) { + return impl_->CountSatisfyingValues(var, max_count, min_consecutive); +} +Z3Prover::Z3Prover(AnalyzerObj* parent) : impl_(new Impl{parent}) {} +TVM_DLL Z3Prover::~Z3Prover() { delete impl_; } + +} // namespace tvm::arith + +#else // TVM_USE_Z3 + +#include +#include +#include + +#include "tvm/ffi/string.h" +#include "tvm/ir/expr.h" + +namespace tvm::arith { + +using namespace tirx; +using namespace ffi; + +// Stub implementation used when Z3 support is not built. All proving queries +// conservatively report "cannot prove" while keeping the public API available. 
+class Z3Prover::Impl {}; + +TVM_DLL bool Z3Prover::IsEnabled() const { return false; } +TVM_DLL bool Z3Prover::CanProve(const PrimExpr& expr) { return false; } +TVM_DLL void Z3Prover::Bind(const Var& var, const Range& new_range, bool allow_override) {} +TVM_DLL void Z3Prover::Bind(const Var& var, const PrimExpr& expr, bool allow_override) {} +std::function Z3Prover::EnterConstraint(const PrimExpr& constraint) { + return []() {}; +} +ffi::String Z3Prover::GetSMTLIB2(const ffi::Optional expr) { + return "; Z3 Prover is disabled."; +} +void Z3Prover::SetTimeoutMs(unsigned timeout_ms) {} +void Z3Prover::SetRLimit(unsigned rlimit) {} +ffi::String Z3Prover::GetModel(const PrimExpr& expr) { return "; Z3 Prover is disabled."; } +TVM_DLL int64_t Z3Prover::CountSatisfyingValues(const Var& var, int64_t max_count, + int64_t min_consecutive) { + return -1; // Z3 disabled, return error +} + +void Z3Prover::CopyFrom(const Z3Prover& other) {} +ffi::String Z3Prover::GetStats() { return "; Z3 Prover is disabled."; } +Z3Prover::Z3Prover(AnalyzerObj*) : impl_(nullptr) {} +TVM_DLL Z3Prover::~Z3Prover() {} + +} // namespace tvm::arith + +#endif // TVM_USE_Z3 diff --git a/src/relax/transform/run_codegen.cc b/src/relax/transform/run_codegen.cc index bc99196169e7..efd90d6696d7 100644 --- a/src/relax/transform/run_codegen.cc +++ b/src/relax/transform/run_codegen.cc @@ -155,7 +155,7 @@ class CodeGenRunner : ExprMutator { if (opt_codegen) { auto ext_symbol = GetExtSymbol(func); size_t count = 0; - PostOrderVisit(func->body, [=, &count](Expr e) { + PostOrderVisit(func->body, [=, this, &count](Expr e) { if (e->IsInstance()) { // Make sure to pick a unique name auto name = ext_symbol + "_" + opt_codegen.value() + "_const_" + std::to_string(count++); diff --git a/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc b/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc index df38f960d294..a7cf1ec2b318 100644 --- a/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc +++ 
b/src/runtime/extra/contrib/cudnn/cudnn_json_runtime.cc @@ -162,7 +162,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { conv_dtype, false, &best_algo); int algo = best_algo.cast(); - std::function op_exec = [=]() { + std::function op_exec = [=, this]() { int device_id; CUDA_CALL(cudaGetDevice(&device_id)); cudaStream_t stream = static_cast(TVMFFIEnvGetStream(kDLCUDA, device_id)); @@ -223,7 +223,7 @@ class cuDNNJSONRuntime : public JSONRuntimeBase { auto runner = tvm::contrib::CuDNNSDPARunner::Create(); runner->Init(batch, seq_len, num_heads, num_kv_heads, head_size, head_size_v, scale, dtype, layout); - return [=]() { + return [=, this]() { auto qkv = GetInput(node, 0); auto workspace = const_cast(GetInput(node, 1)); auto out = const_cast(data_entry_[EntryID(outputs_[0])]); diff --git a/src/runtime/extra/disco/protocol.h b/src/runtime/extra/disco/protocol.h index 25662051dcb4..a26b3060bc2a 100644 --- a/src/runtime/extra/disco/protocol.h +++ b/src/runtime/extra/disco/protocol.h @@ -28,6 +28,7 @@ #include #include +#include #include #include @@ -78,7 +79,8 @@ struct DiscoProtocol { /*!\ brief Arena used by RPCReference to allocate POD memory */ template T* ArenaAlloc(int count) { - static_assert(std::is_pod::value, "need to be trival"); + static_assert(std::is_standard_layout::value && std::is_trivial::value, + "need to be trivial"); return arena_.template allocate_(count); } diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index a6950117d611..0402430251e5 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -34,6 +34,7 @@ #include #include #include +#include #include #include @@ -312,7 +313,8 @@ class RPCEndpoint::EventHandler : public support::Stream { template T* ArenaAlloc(int count) { - static_assert(std::is_pod::value, "need to be trival"); + static_assert(std::is_standard_layout::value && std::is_trivial::value, + "need to be trivial"); return arena_.template allocate_(count); } diff --git 
a/src/s_tir/transform/compact_buffer_region.cc b/src/s_tir/transform/compact_buffer_region.cc index d02e90701696..640a18d594cf 100644 --- a/src/s_tir/transform/compact_buffer_region.cc +++ b/src/s_tir/transform/compact_buffer_region.cc @@ -129,6 +129,8 @@ class BufferAccessRegionCollector : public StmtExprVisitor { } private: + using StmtExprVisitor::VisitBufferDef; + struct BufferAccessInfo { /*! \brief The buffer. */ Buffer buffer; diff --git a/src/s_tir/transform/inject_software_pipeline.cc b/src/s_tir/transform/inject_software_pipeline.cc index d9da151f392f..5099f66cd030 100644 --- a/src/s_tir/transform/inject_software_pipeline.cc +++ b/src/s_tir/transform/inject_software_pipeline.cc @@ -707,7 +707,7 @@ class PipelineRewriter : public StmtExprMutator { } } - auto wait_count = [=, &ana_normalized]() { + auto wait_count = [=, this, &ana_normalized]() { auto sum = PrimExpr(0); for (auto producer_head : producer_head_per_commit) { if (producer_head && ana_normalized->CanProve(producer_head.value() >= 0)) { diff --git a/src/target/llvm/codegen_params.cc b/src/target/llvm/codegen_params.cc index fccc92a22830..6d8684a87eda 100644 --- a/src/target/llvm/codegen_params.cc +++ b/src/target/llvm/codegen_params.cc @@ -61,7 +61,8 @@ struct LLVMConstantGetter::value>> static llvm::Constant* getElement(llvm::Type* ty, T t) { return llvm::ConstantFP::get(ty, t); } }; -template ::value>> +template ::value && std::is_trivial::value>> void BuildLLVMVector(llvm::Type* element_type, void* tensor_data, size_t num_elements, std::vector* elements) { elements->resize(num_elements, nullptr); diff --git a/tests/python/arith/test_arith_z3.py b/tests/python/arith/test_arith_z3.py new file mode 100644 index 000000000000..a64afd76c5b1 --- /dev/null +++ b/tests/python/arith/test_arith_z3.py @@ -0,0 +1,756 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import gc +import queue +import threading + +import pytest + +import tvm +import tvm.testing +from tvm import tirx +from tvm.arith import Analyzer, ProofStrength + +# The Z3 prover is only consulted at the kSymbolicBound strength so the common +# default path never pays the prover cost. +SB = ProofStrength.SYMBOLIC_BOUND + + +def _require_z3(analyzer): + if not analyzer.is_z3_enabled: + pytest.skip("Z3 prover is disabled in this build") + + +def implies(x, y): + return tirx.Or(tirx.Not(x), y) + + +# --------------------------------------------------------------------------- +# API availability (works regardless of whether Z3 is built) +# --------------------------------------------------------------------------- + + +def test_z3_capability_query(): + # `is_z3_enabled` is the supported way to detect the build configuration. + # The Z3-specific debug/config methods work only when it is True, and raise + # a clear error otherwise. 
+ analyzer = Analyzer() + assert isinstance(analyzer.is_z3_enabled, bool) + + if analyzer.is_z3_enabled: + assert isinstance(analyzer.get_smtlib2(), str) + assert isinstance(analyzer.get_z3_stats(), str) + else: + with pytest.raises(RuntimeError): + analyzer.get_smtlib2() + with pytest.raises(RuntimeError): + analyzer.get_z3_stats() + with pytest.raises(RuntimeError): + analyzer.set_z3_timeout_ms(1000) + with pytest.raises(RuntimeError): + analyzer.set_z3_rlimit(0) + + +def test_z3_context_lifetime_outlives_worker_thread(): + _require_z3(Analyzer()) + + result_queue = queue.Queue() + + def worker(): + try: + analyzer = Analyzer() + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 16)) + assert analyzer.can_prove(x >= 0, SB) + result_queue.put(("analyzer", analyzer)) + except BaseException as err: # pylint: disable=broad-exception-caught + result_queue.put(("error", err)) + + thread = threading.Thread(target=worker) + thread.start() + thread.join() + + kind, payload = result_queue.get_nowait() + if kind == "error": + raise payload + + del payload + gc.collect() + + +# --------------------------------------------------------------------------- +# Examples the native analyzer cannot prove but Z3 can. +# +# Each case asserts both that the native analyzers (kDefault, Z3 gated off) +# fail and that Z3 (kSymbolicBound) succeeds. This demonstrates the added value +# of the Z3 backend and that it is correctly gated behind kSymbolicBound. 
+# --------------------------------------------------------------------------- + + +def test_z3_floor_division_identity_constraint(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + + expr = ((b - a) // c) * c + a <= b + with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)): + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_floor_division_identity_via_bind_range(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + + analyzer.bind(a, tvm.ir.Range(1, 100000)) + analyzer.bind(b, tvm.ir.Range(1, 100000)) + analyzer.bind(c, tvm.ir.Range(1, 100000)) + + expr = ((b - a) // c) * c + a <= b + assert analyzer.can_prove(expr, SB) + + +def test_z3_multiplication_monotonicity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + d = tirx.Var("d", "int32") + + expr = implies(tirx.all(a < b, b < c, a * d < b * d), b * d < c * d) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_nested_floor_division_collapse(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + expr = implies( + tirx.all(a >= 0, a < 128), + a // 128 == (a // 64 * 32 + a % 32 // 16 * 8) // 64, + ) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_deeply_nested_floor_division_identity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + expr = implies( + tirx.all(a >= 0, a < 128), + ( + a % 16 * 64 + + a // 64 * 32 + + a % 8 // 4 * 32 + + (a % 32 // 16 + a % 2) % 2 * 8 + + 16 + - (a // 64 + a % 8 // 4) // 2 * 64 + ) + // 512 + == ( + a % 16 * 64 + + a // 64 * 32 + + a % 8 // 4 * 32 + + (a % 32 // 16 + a % 2) % 2 * 8 + - (a // 64 + a % 8 // 4) // 2 * 64 + ) + 
// 512, + ) + assert analyzer.can_prove(expr, SB) + + +def test_z3_min_max_sum_identity(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + expr = tirx.max(x, y) + tirx.min(x, y) == x + y + assert analyzer.can_prove(expr, SB) + + +def test_z3_select_absolute_value_nonneg(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + expr = tirx.Select(x >= 0, x, -x) >= 0 + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_transitive_inequality(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + expr = implies(tirx.all(a <= b, b <= c), a <= c) + assert analyzer.can_prove(expr, SB) + + +def test_z3_square_expansion_nonneg(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = (a + b) * (a + b) >= a * a + b * b + with analyzer.constraint_scope(tirx.all(a >= 0, b >= 0)): + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_square_monotonicity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = implies(tirx.all(0 <= a, a <= b), a * a <= b * b) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_strict_multiplication(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + d = tirx.Var("d", "int32") + expr = implies(tirx.all(a < b, d > 0), a * d < b * d) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_floor_division_monotonicity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + expr = implies(tirx.all(a <= b, c > 0), tirx.floordiv(a, c) <= tirx.floordiv(b, c)) + 
assert not analyzer.can_prove(expr) + analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(expr, SB) + + +def test_z3_floor_division_lower_bound(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = implies(b > 0, tirx.floordiv(a, b) * b <= a) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_floor_modulo_range(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = implies(b > 0, tirx.all(0 <= tirx.floormod(a, b), tirx.floormod(a, b) < b)) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_flattened_index_bound(): + # Classic index-flattening bound used throughout TVM: for a row index i in + # [0, m) and a column index j in [0, n), the flattened index i * n + j stays + # within [0, m * n). + analyzer = Analyzer() + _require_z3(analyzer) + + i = tirx.Var("i", "int32") + j = tirx.Var("j", "int32") + m = tirx.Var("m", "int32") + n = tirx.Var("n", "int32") + expr = tirx.all(0 <= i * n + j, i * n + j < m * n) + with analyzer.constraint_scope(tirx.all(0 <= i, i < m, 0 <= j, j < n, m > 0, n > 0)): + assert not analyzer.can_prove(expr) + analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(expr, SB) + + +def test_z3_modular_combination(): + # Native modular_set tracks single-variable moduli, but combining two + # independent modular facts to reason about their sum is left to Z3. 
+ analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + expr = tirx.floormod(x + y, 2) == 0 + with analyzer.constraint_scope(tirx.all(tirx.floormod(x, 6) == 0, tirx.floormod(y, 6) == 0)): + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_square_non_negative(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + assert not analyzer.can_prove(a * a >= 0) + assert analyzer.can_prove(a * a >= 0, SB) + + +def test_z3_min_max_average_bounds(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + assert not analyzer.can_prove(tirx.max(a, b) * 2 >= a + b) + assert analyzer.can_prove(tirx.max(a, b) * 2 >= a + b, SB) + assert analyzer.can_prove(tirx.min(a, b) * 2 <= a + b, SB) + + +def test_z3_symbolic_bind_range_with_constraint(): + # Combine a symbolic range binding (x in [0, n)) with a constraint on the + # extent to derive a concrete bound on x. + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + n = tirx.Var("n", "int32") + analyzer.bind(x, tvm.ir.Range(0, n)) + with analyzer.constraint_scope(n <= 8): + assert not analyzer.can_prove(x < 8) + assert analyzer.can_prove(x < 8, SB) + + +def test_z3_equality_congruence(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = implies(a == b, a * a == b * b) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_integer_strict_transitivity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + # Over the integers, a < b and b < c implies a + 1 < c. 
+ expr = implies(tirx.all(a < b, b < c), a + 1 < c) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_if_then_else_absolute_value(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + expr = tirx.if_then_else(x >= 0, x, -x) >= 0 + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_unsigned_non_negative(): + analyzer = Analyzer() + _require_z3(analyzer) + + u = tirx.Var("u", "uint32") + assert not analyzer.can_prove(u >= 0) + assert analyzer.can_prove(u >= 0, SB) + + +def test_z3_unsigned64_non_negative(): + # Exercises the special-cased uint64 range handling (UINT64_MAX bound). + analyzer = Analyzer() + _require_z3(analyzer) + + u = tirx.Var("u", "uint64") + assert not analyzer.can_prove(u >= 0) + assert analyzer.can_prove(u >= 0, SB) + + +def test_z3_int64_square_expansion(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int64") + b = tirx.Var("b", "int64") + expr = (a + b) * (a + b) >= a * a + b * b + with analyzer.constraint_scope(tirx.all(a >= 0, b >= 0)): + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_boolean_variable_reasoning(): + analyzer = Analyzer() + _require_z3(analyzer) + + p = tirx.Var("p", "bool") + q = tirx.Var("q", "bool") + expr = implies(tirx.And(p, q), tirx.Or(p, q)) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_not_equal_from_strict_less(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + expr = implies(x < y, tirx.NE(x, y)) + assert not analyzer.can_prove(expr) + assert analyzer.can_prove(expr, SB) + + +def test_z3_let_expression(): + analyzer = Analyzer() + _require_z3(analyzer) + + y = tirx.Var("y", "int32") + t = tirx.Var("t", "int32") + let = tirx.Let(t, y * 2, t) + assert not analyzer.can_prove(let == y * 2) + assert analyzer.can_prove(let == y * 2, 
SB) + + +def test_z3_cast_preserves_bounds(): + analyzer = Analyzer() + _require_z3(analyzer) + + s = tirx.Var("s", "int16") + widened = tirx.Cast("int32", s) + assert analyzer.can_prove(widened <= 32767, SB) + assert analyzer.can_prove(widened >= -32768, SB) + + +def test_z3_bitwise_and_mask_bound(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + assert analyzer.can_prove(tirx.bitwise_and(x, tirx.IntImm("int32", 7)) < 8, SB) + + +def test_z3_bitwise_and_le_operand(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + analyzer.bind(y, tvm.ir.Range(0, 256)) + # Bit-vector reasoning over two variables exceeds the default deterministic + # rlimit; lift it (0 == unlimited, still deterministic) for this proof. + analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(tirx.bitwise_and(x, y) <= x, SB) + + +def test_z3_bitwise_or_ge_operand(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + analyzer.bind(y, tvm.ir.Range(0, 256)) + analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(tirx.bitwise_or(x, y) >= x, SB) + + +def test_z3_bitwise_xor_bound(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + y = tirx.Var("y", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + analyzer.bind(y, tvm.ir.Range(0, 256)) + analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(tirx.bitwise_xor(x, y) < 256, SB) + + +def test_z3_bitwise_not_identity(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + analyzer.set_z3_rlimit(0) + # Two's complement: ~x == -x - 1. 
+ assert analyzer.can_prove(tirx.bitwise_not(x) == -x - 1, SB) + + +def test_z3_shift_right_halves(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + analyzer.bind(x, tvm.ir.Range(0, 256)) + analyzer.set_z3_rlimit(0) + # For non-negative x, (x >> 1) * 2 <= x. + assert analyzer.can_prove(tirx.shift_right(x, tirx.IntImm("int32", 1)) * 2 <= x, SB) + + +def test_z3_shift_left_lower_bound(): + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + n = tirx.Var("n", "int32") + # Keep operands small so the 32-bit left shift cannot overflow; then + # x << n == x * 2 ** n >= x for x >= 1. + analyzer.bind(x, tvm.ir.Range(1, 16)) + analyzer.bind(n, tvm.ir.Range(0, 4)) + # Bit-vector shift reasoning exceeds the default deterministic rlimit. + analyzer.set_z3_rlimit(0) + assert analyzer.can_prove(tirx.shift_left(x, n) >= x, SB) + + +# --------------------------------------------------------------------------- +# Soundness / negative tests (Z3 must NOT prove false predicates) +# --------------------------------------------------------------------------- + + +def test_z3_negative_unprovable_inequality(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + # a < b does not hold for arbitrary a, b. + assert not analyzer.can_prove(a < b, SB) + # a * a > a is false (e.g. a == 0). + assert not analyzer.can_prove(a * a > a, SB) + + +def test_z3_truncmod_can_be_negative(): + # Regression test for truncated div/mod semantics: TVM Div/Mod round toward + # zero, so truncmod(a, 4) can be negative. A solver that modeled them as + # Euclidean would unsoundly "prove" truncmod(a, 4) >= 0. 
+ analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + assert not analyzer.can_prove(tirx.truncmod(a, 4) >= 0, SB) + + +def test_z3_truncdiv_truncmod_identity(): + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + expr = tirx.truncdiv(a, b) * b + tirx.truncmod(a, b) == a + with analyzer.constraint_scope(b != 0): + assert analyzer.can_prove(expr, SB) + + +def test_z3_floormod_nested_identities(): + # Ported from TileLang's test_divmod. Here `%` is floormod: nested floormod + # by opposite-sign divisors collapses to the single-divisor result, while + # the mixed case does not. + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + assert not analyzer.can_prove(a % 2 % -2 - a % 2 == 0, SB) + assert analyzer.can_prove(a % -2 % 2 - a % 2 == 0, SB) + + +def test_z3_floormod_nonnegative(): + # In contrast to truncmod, floormod with a positive divisor is always in + # [0, divisor), which Z3 should be able to prove. + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + assert analyzer.can_prove(tirx.floormod(a, 4) >= 0, SB) + assert analyzer.can_prove(tirx.floormod(a, 4) < 4, SB) + + +def test_z3_shift_does_not_poison_solver(): + # Regression test: evaluating a shift expression must not add permanent + # assertions (such as `b >= 0` / `b < 64`) to the shared solver. Otherwise + # an unrelated, unbounded `b` would be wrongly provable to be < 100. + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + + # Touch a shift expression so the prover visits the shift amount `b`. + analyzer.can_prove(tirx.shift_left(a, b) >= 0, SB) + + # `b` is otherwise unconstrained, so this must remain unprovable. 
+ assert not analyzer.can_prove(b < 100, SB) + assert not analyzer.can_prove(b >= 0, SB) + + +def test_z3_constraint_scope_is_popped(): + # Constraints entered through a scope must be removed once the scope exits, + # i.e. EnterConstraint's solver.push()/pop() must be balanced. + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + with analyzer.constraint_scope(x > 5): + assert analyzer.can_prove(x > 0, SB) + # The constraint is gone; x is unconstrained again. + assert not analyzer.can_prove(x > 0, SB) + + +def test_z3_opaque_call_is_safe(): + # An opaque/unsupported sub-expression is modeled as a fresh free variable. + # It must neither crash nor be provable on its own, yet still be usable as a + # constraint. + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + call = tirx.call_extern("int32", "foo", x) + assert not analyzer.can_prove(call > 0, SB) + with analyzer.constraint_scope(call > 0): + assert analyzer.can_prove(call > 0, SB) + assert not analyzer.can_prove(call > 0, SB) + + +def test_z3_shift_overflow_is_not_proven(): + # Z3 models fixed-width shifts via bit-vectors, so it correctly refuses to + # prove `x << n >= x` for an unbounded `x` (a large `x` overflows int32 and + # wraps to a negative value). This guards against unsound shift modeling. + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + n = tirx.Var("n", "int32") + analyzer.set_z3_rlimit(0) + expr = implies(tirx.all(x >= 1, n >= 0, n < 8), tirx.shift_left(x, n) >= x) + assert not analyzer.can_prove(expr, SB) + + +def test_z3_analyzers_are_isolated(): + # Analyzers share a thread-local Z3 context but own separate solvers, so + # constraints and bindings in one must never leak into another. 
+ analyzer_a = Analyzer() + analyzer_b = Analyzer() + _require_z3(analyzer_a) + + x = tirx.Var("x", "int32") + with analyzer_a.constraint_scope(x > 100): + assert analyzer_a.can_prove(x > 50, SB) + assert not analyzer_b.can_prove(x > 50, SB) + + analyzer_c = Analyzer() + analyzer_d = Analyzer() + analyzer_c.bind(x, tvm.ir.Range(0, 10)) + assert analyzer_c.can_prove(x < 10, SB) + assert not analyzer_d.can_prove(x < 10, SB) + + +def test_z3_repeated_can_prove_is_consistent(): + # Repeated queries must be stateless: a CanProve call must not pollute the + # solver and change the result of a subsequent call. + analyzer = Analyzer() + _require_z3(analyzer) + + x = tirx.Var("x", "int32") + assert analyzer.can_prove(x > 0, SB) == analyzer.can_prove(x > 0, SB) + + analyzer.bind(x, tvm.ir.Range(5, 10)) + assert analyzer.can_prove(x >= 5, SB) + assert analyzer.can_prove(x >= 5, SB) + + +def test_z3_is_gated_behind_symbolic_bound(): + # The Z3 fallback must not run at the default strength. + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + expr = ((b - a) // c) * c + a <= b + with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)): + assert not analyzer.can_prove(expr, ProofStrength.DEFAULT) + assert analyzer.can_prove(expr, SB) + + +# --------------------------------------------------------------------------- +# SMT-LIB2 export +# --------------------------------------------------------------------------- + + +def test_z3_smtlib2_roundtrip(): + z3 = pytest.importorskip("z3") + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + expr = ((b - a) // c) * c + a <= b + + solver = z3.Solver() + with analyzer.constraint_scope(tirx.all(a > 0, b > 0, c > 0)): + solver.from_string(analyzer.get_smtlib2(expr)) + assert solver.check() == z3.unsat + + +def test_z3_smtlib2_roundtrip_with_timeout(): + z3 = 
pytest.importorskip("z3") + analyzer = Analyzer() + _require_z3(analyzer) + + a = tirx.Var("a", "int32") + b = tirx.Var("b", "int32") + c = tirx.Var("c", "int32") + analyzer.set_z3_timeout_ms(1000) + + expr = implies(tirx.all(a > 0, b > 0, c > 0), ((b - a) // c) * c + a <= b) + solver = z3.Solver() + solver.from_string(analyzer.get_smtlib2(expr)) + assert solver.check() == z3.unsat + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py index 8eb3961ef8dd..2013a05210be 100644 --- a/tests/python/relax/test_utils.py +++ b/tests/python/relax/test_utils.py @@ -16,8 +16,6 @@ # under the License. # ruff: noqa: F841 -import re - import pytest import tvm @@ -143,7 +141,7 @@ def func_2(A: R.Tensor([16, 16], "float32")): with pytest.raises( ValueError, - match=re.escape(".body.blocks[0].bindings[0].value.op"), + match=r"body.*blocks.*bindings.*value.*op", ): assert_structural_equal(func_1, func_2) @@ -251,25 +249,13 @@ def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"): return recursive_lambda(n) - # The path to the first mismatch, which should appear within the - # error message. - mismatch_path = [ - "", - "body", - "blocks[0]", - "bindings[0]", - "value", - "body", - "blocks[0]", - "bindings[0]", - "value", - "true_branch", - "body", - "value", - "value", - ] - - with pytest.raises(ValueError, match=re.escape(".".join(mismatch_path))): + mismatch_path = ( + r"body.*blocks.*bindings.*value" + r".*body.*blocks.*bindings.*value" + r".*true_branch.*body.*value.*value" + ) + + with pytest.raises(ValueError, match=mismatch_path): tvm.ir.assert_structural_equal(func_a, func_b)