Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions lighthouse/ingress/torch/importer.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import importlib
import importlib.util
from pathlib import Path
from typing import Iterable, Mapping

from lighthouse.ingress.torch.utils import (
load_and_run_callable,
maybe_load_and_run_callable,
)
from lighthouse.utils.importer import import_python_module

try:
import torch
Expand Down Expand Up @@ -124,11 +123,8 @@ def get_inputs():
"""
if isinstance(filepath, str):
filepath = Path(filepath)
module_name = filepath.stem

spec = importlib.util.spec_from_file_location(module_name, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
module = import_python_module(filepath)

model = getattr(module, model_class_name, None)
if model is None:
Expand Down
2 changes: 1 addition & 1 deletion lighthouse/pipeline/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from mlir import ir
import lighthouse.pipeline.stage as lhs
from lighthouse.pipeline.helper import import_mlir_module
from lighthouse.pipeline.descriptor import PipelineDescriptor, Descriptor
from lighthouse.utils.importer import import_mlir_module
import lighthouse.dialects as lh_dialects


Expand Down
12 changes: 0 additions & 12 deletions lighthouse/pipeline/helper.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,8 @@
import os

from mlir import ir
from mlir.dialects import transform
from mlir.dialects.transform import structured


def import_mlir_module(path: str, context: ir.Context) -> ir.Module:
"""Import an MLIR text file into an MLIR module"""
if path is None:
raise ValueError("Path to the module must be provided.")
if not os.path.exists(path):
raise ValueError(f"Path to the module does not exist: {path}")
with open(path, "r") as f:
return ir.Module.parse(f.read(), context=context)


def apply_registered_pass(*args, **kwargs):
"""Utility function to add a bundle of passes to a Transform Schedule"""
return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs)
Expand Down
12 changes: 2 additions & 10 deletions lighthouse/pipeline/stage.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from abc import abstractmethod
import importlib
from enum import Enum
from pathlib import Path
import os

from mlir import ir
from mlir.passmanager import PassManager
from mlir.dialects import transform
from lighthouse.pipeline.helper import import_mlir_module
from lighthouse.pipeline.descriptor import Descriptor, PipelineDescriptor
from lighthouse.utils.importer import import_mlir_module, import_python_module


class Pass:
Expand Down Expand Up @@ -170,12 +167,7 @@ def __init__(self, transform: Transform | ir.Module, context: ir.Context):
elif transform.type == Transform.Type.Python:
# For Python transforms, we expect the file to define a function
# that takes an MLIR module and returns a transformed MLIR module.
module_name = Path(os.path.basename(transform.filename)).stem
spec = importlib.util.spec_from_file_location(
module_name, transform.filename
)
transform_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(transform_module)
transform_module = import_python_module(transform.filename)
Comment thread
rengolin marked this conversation as resolved.
if not hasattr(transform_module, transform.generator):
raise ValueError(
f"Transform module '{transform.filename}' does not define a '{transform.generator}' generator function."
Expand Down
79 changes: 79 additions & 0 deletions lighthouse/utils/importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import importlib
import importlib.util
import os
import sys
from functools import lru_cache
from pathlib import Path
from types import ModuleType

from mlir import ir


def import_mlir_module(path: str, context: ir.Context) -> ir.Module:
"""
Import an MLIR text file into an MLIR module

Args:
path: Path to the MLIR file
context: MLIR context
Returns:
MLIR module
"""
if path is None:
raise ValueError("Path to the module must be provided.")
if not os.path.exists(path):
raise ValueError(f"Path to the module does not exist: {path}")
with open(path, "r") as f:
return ir.Module.parse(f.read(), context=context)


@lru_cache(maxsize=None)
def _resolve_package(directory: Path) -> tuple[str, str]:
"""
Resolve the enclosing package for a directory.

Results are cached per directory so the walk runs at most once
for each directory.
"""
package_parts: list[str] = []
parent = directory
while (parent / "__init__.py").exists():
package_parts.insert(0, parent.name)
parent = parent.parent
return str(parent), ".".join(package_parts)


def import_python_module(path: str) -> ModuleType:
"""
Import a Python module from a file.

Args:
path: Path to the Python file
Returns:
Imported Python module
"""
if path is None:
raise ValueError("Path to the module must be provided.")
if not os.path.exists(path):
raise ValueError(f"Path to the module does not exist: {path}")

filepath = Path(path).resolve()

# Resolve the enclosing package to enable relative imports.
package_root, dotted_prefix = _resolve_package(filepath.parent)

module_name = filepath.stem
if dotted_prefix:
# The file is part of a package: build its dotted name and make sure
# the package root is importable so relative imports resolve.
qualified_name = f"{dotted_prefix}.{module_name}"
if package_root not in sys.path:
sys.path.insert(0, package_root)
return importlib.import_module(qualified_name)

# Standalone file: load it directly from its location.
spec = importlib.util.spec_from_file_location(module_name, str(filepath))
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return module
Loading