Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,11 @@ def wrapper(exctype, value, trbk):

sys.excepthook = tvm_wrap_excepthook(sys.excepthook)

# Autoload out-of-tree backends registered under the ``tvm.backends`` entry
# point group. Runs last, after the core runtime and the tvm namespace are
# fully initialized, so an extension can safely register into ``tvm.*`` and
# load extra libraries. Imported lazily here to avoid any import-cycle risk.
from ._autoload_backends import _autoload_backends
# Autoload backend runtime libraries and out-of-tree backends registered under
# the ``tvm.backends`` entry point group. Runs last, after the core runtime and
# the tvm namespace are fully initialized, so an extension can safely register
# into ``tvm.*`` and load extra libraries.
from .backend._autoload_backends import _autoload_backends

backend.autoload_backend_libs()
_autoload_backends()
26 changes: 3 additions & 23 deletions python/tvm/_autoload_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Autoload out-of-tree backends registered via ``tvm.backends`` entry points.
"""Compatibility route for backend autoload infrastructure.

Out-of-tree extensions opt into being loaded automatically at ``import tvm``
time by declaring an entry point in the ``tvm.backends`` group::
Expand All @@ -25,26 +25,6 @@
Autoload can be disabled via ``TVM_DEVICE_BACKEND_AUTOLOAD=0``.
"""

import os
import warnings
from importlib.metadata import entry_points
from .backend._autoload_backends import _autoload_backends

# Guard so autoload runs at most once per process, even if invoked again.
_AUTO_LOAD_DONE = False


def _autoload_backends():
"""Discover and invoke out-of-tree backends registered via entry points."""
global _AUTO_LOAD_DONE
if _AUTO_LOAD_DONE:
return
_AUTO_LOAD_DONE = True

if os.environ.get("TVM_DEVICE_BACKEND_AUTOLOAD", "1") == "0":
return

for entry_pt in entry_points(group="tvm.backends"):
try:
entry_pt.load()()
except Exception as e: # pylint: disable=broad-except
warnings.warn(f"Failed to autoload tvm backend '{entry_pt.name}': {e}")
__all__ = ["_autoload_backends"]
23 changes: 3 additions & 20 deletions python/tvm/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,10 @@
from pkgutil import extend_path
from typing import Any

from ._autoload_backends import autoload_backend_libs, load_all

__path__ = extend_path(__path__, __name__) # type: ignore[name-defined]

_BUILTIN_BACKENDS = (
"cuda",
"metal",
"rocm",
"trn",
"opencl",
"vulkan",
"webgpu",
"hexagon",
"adreno",
)
_LOADED_BACKENDS: dict[str, Any] = {}


Expand Down Expand Up @@ -192,18 +183,10 @@ def load(name: str) -> None:
return None


def load_all() -> None:
"""Load all in-tree backend Python hooks."""

for name in _BUILTIN_BACKENDS:
load(name)
return None


def is_loaded(name: str) -> bool:
"""Return whether a backend has been loaded."""

return name in _LOADED_BACKENDS


__all__ = ["is_loaded", "load", "load_all"]
__all__ = ["autoload_backend_libs", "is_loaded", "load", "load_all"]
111 changes: 111 additions & 0 deletions python/tvm/backend/_autoload_backends.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Autoload backend libraries and Python backend registration hooks."""

from __future__ import annotations

import os
import warnings
from importlib import import_module
from importlib.metadata import entry_points
from pathlib import Path
from typing import Any

from tvm_ffi.libinfo import load_lib_ctypes

_BUILTIN_BACKENDS = (
"cuda",
"metal",
"rocm",
"trn",
"opencl",
"vulkan",
"webgpu",
"hexagon",
"adreno",
)
_LEGACY_RUNTIME_LIBS_WITHOUT_BACKEND_PACKAGE = ("extra",)

# Guard so autoload runs at most once per process, even if invoked again.
_BACKEND_LIBS_LOADED = False
_AUTO_LOAD_DONE = False


def autoload_backend_libs(loaded_libs: dict[str, Any] | None = None) -> None:
"""Load each known backend runtime DSO into the process-global symbol namespace."""
global _BACKEND_LIBS_LOADED
if _BACKEND_LIBS_LOADED:
return

if loaded_libs is None:
from tvm.base import _LOADED_LIBS # pylint: disable=import-outside-toplevel

loaded_libs = _LOADED_LIBS

runtime_lib = loaded_libs.get("tvm_runtime")
if runtime_lib is None:
return

_BACKEND_LIBS_LOADED = True

runtime_dir = Path(runtime_lib._name).resolve().parent
for runtime_lib_name in _backend_runtime_lib_names():
target_name = f"tvm_runtime_{runtime_lib_name}"
try:
loaded_libs[target_name] = load_lib_ctypes(
package="tvm",
target_name=target_name,
mode="RTLD_GLOBAL",
extra_lib_paths=[runtime_dir],
)
except (OSError, FileNotFoundError, RuntimeError):
pass


def _backend_runtime_lib_names() -> tuple[str, ...]:
runtime_libs = []
for backend in _BUILTIN_BACKENDS:
module = import_module(f"tvm.backend.{backend}")
runtime_libs.extend(getattr(module, "RUNTIME_LIBS", ()))
runtime_libs.extend(_LEGACY_RUNTIME_LIBS_WITHOUT_BACKEND_PACKAGE)
return tuple(runtime_libs)


def load_all() -> None:
"""Load all in-tree backend Python hooks."""
from . import load # pylint: disable=import-outside-toplevel

for name in _BUILTIN_BACKENDS:
load(name)
return None


def _autoload_backends() -> None:
"""Discover and invoke out-of-tree backends registered via entry points."""
global _AUTO_LOAD_DONE
if _AUTO_LOAD_DONE:
return
_AUTO_LOAD_DONE = True

if os.environ.get("TVM_DEVICE_BACKEND_AUTOLOAD", "1") == "0":
return

for entry_pt in entry_points(group="tvm.backends"):
try:
entry_pt.load()()
except Exception as e: # pylint: disable=broad-except
warnings.warn(f"Failed to autoload tvm backend '{entry_pt.name}': {e}")
18 changes: 18 additions & 0 deletions python/tvm/backend/cuda/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,29 @@
from importlib import import_module

_LAZY_SUBMODULES = {"lang", "op", "operator", "script", "target_tags"}
RUNTIME_LIBS = ("cuda",)


def _detect_target_from_device(dev):
from tvm.target import Target # pylint: disable=import-outside-toplevel

return Target(
{
"kind": "cuda",
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
"arch": "sm_" + dev.compute_version.replace(".", ""),
}
)


def register_backend():
"""Register CUDA-owned Python semantics."""
from tvm.target.detect_target import register_device_target_detector
from tvm.tirx.script.builder import ir as builder_ir # pylint: disable=import-outside-toplevel

register_device_target_detector("cuda", _detect_target_from_device)
for name, namespace in script_namespaces().items():
builder_ir.register_script_namespace(name, namespace)

Expand Down Expand Up @@ -64,6 +81,7 @@ def __getattr__(name: str):
"op",
"operator",
"register_backend",
"RUNTIME_LIBS",
"script",
"script_namespace",
"script_namespaces",
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/backend/hexagon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from importlib import import_module

_LAZY_SUBMODULES = {"target_tags"}
RUNTIME_LIBS = ("hexagon",)


def register_backend():
Expand All @@ -32,4 +33,4 @@ def __getattr__(name: str):
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


__all__ = ["register_backend", "target_tags"]
__all__ = ["register_backend", "RUNTIME_LIBS", "target_tags"]
17 changes: 17 additions & 0 deletions python/tvm/backend/metal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,28 @@
from importlib import import_module

_LAZY_SUBMODULES = {"op", "script", "target_tags"}
RUNTIME_LIBS = ("metal",)


def _detect_target_from_device(dev):
from tvm.target import Target # pylint: disable=import-outside-toplevel

return Target(
{
"kind": "metal",
"max_shared_memory_per_block": 32768,
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
}
)


def register_backend():
"""Register Metal-owned Python semantics."""
from tvm.target.detect_target import register_device_target_detector
from tvm.tirx.script.builder import ir as builder_ir # pylint: disable=import-outside-toplevel

register_device_target_detector("metal", _detect_target_from_device)
for name, namespace in script_namespaces().items():
builder_ir.register_script_namespace(name, namespace)
import_module(f"{__name__}.target_tags")
Expand All @@ -51,6 +67,7 @@ def __getattr__(name: str):
__all__ = [
"op",
"register_backend",
"RUNTIME_LIBS",
"script",
"script_namespace",
"script_namespaces",
Expand Down
20 changes: 19 additions & 1 deletion python/tvm/backend/opencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,28 @@
# under the License.
"""OpenCL-owned backend hooks."""

RUNTIME_LIBS = ("opencl",)


def _detect_target_from_device(dev):
from tvm.target import Target # pylint: disable=import-outside-toplevel

return Target(
{
"kind": "opencl",
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
}
)


def register_backend():
"""Register OpenCL-owned Python semantics."""
from tvm.target.detect_target import register_device_target_detector

register_device_target_detector("opencl", _detect_target_from_device)
return None


__all__ = ["register_backend"]
__all__ = ["register_backend", "RUNTIME_LIBS"]
21 changes: 20 additions & 1 deletion python/tvm/backend/rocm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,29 @@
# under the License.
"""ROCm-owned TIRx modules."""

RUNTIME_LIBS = ("rocm",)


def _detect_target_from_device(dev):
from tvm.target import Target # pylint: disable=import-outside-toplevel

return Target(
{
"kind": "rocm",
"mtriple": "amdgcn-amd-amdhsa-hcc",
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"max_threads_per_block": dev.max_threads_per_block,
"thread_warp_size": dev.warp_size,
}
)


def register_backend():
"""Register ROCm-owned Python semantics."""
from tvm.target.detect_target import register_device_target_detector

register_device_target_detector("rocm", _detect_target_from_device)
return None


__all__ = ["register_backend"]
__all__ = ["register_backend", "RUNTIME_LIBS"]
31 changes: 30 additions & 1 deletion python/tvm/backend/vulkan/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,39 @@
# under the License.
"""Vulkan-owned backend hooks."""

RUNTIME_LIBS = ("vulkan",)


def _detect_target_from_device(dev):
from tvm import get_global_func # pylint: disable=import-outside-toplevel
from tvm.target import Target # pylint: disable=import-outside-toplevel

f_get_target_property = get_global_func("device_api.vulkan.get_target_property")
return Target(
{
"kind": "vulkan",
"max_threads_per_block": dev.max_threads_per_block,
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"thread_warp_size": dev.warp_size,
"supports_float16": f_get_target_property(dev, "supports_float16"),
"supports_int8": f_get_target_property(dev, "supports_int8"),
"supports_int16": f_get_target_property(dev, "supports_int16"),
"supports_int64": f_get_target_property(dev, "supports_int64"),
"supports_8bit_buffer": f_get_target_property(dev, "supports_8bit_buffer"),
"supports_16bit_buffer": f_get_target_property(dev, "supports_16bit_buffer"),
"supports_storage_buffer_storage_class": f_get_target_property(
dev, "supports_storage_buffer_storage_class"
),
}
)


def register_backend():
"""Register Vulkan-owned Python semantics."""
from tvm.target.detect_target import register_device_target_detector

register_device_target_detector("vulkan", _detect_target_from_device)
return None


__all__ = ["register_backend"]
__all__ = ["register_backend", "RUNTIME_LIBS"]
Loading
Loading