diff --git a/projects/PTv3/README.md b/projects/PTv3/README.md index 06ed1c204..3ebe061c2 100644 --- a/projects/PTv3/README.md +++ b/projects/PTv3/README.md @@ -58,11 +58,16 @@ Export the model: python projects/PTv3/tools/export.py --config-file projects/PTv3/configs/semseg-pt-v3m1-0-t4dataset.py --num-gpus 1 \ --options \ save_path=work_dirs/experiment \ - weight=work_dirs/ptv3/model/model_best.pth + weight=work_dirs/ptv3/model/model_best.pth \ + spconv_do_sort=False ``` which will generate a file called `ptv3.onnx` +- `spconv_do_sort` controls the ONNX attribute on `GetIndicePairsImplicitGemm`: + - `True` (default): export with sorting enabled (`do_sort=1`) + - `False`: export with sorting disabled (`do_sort=0`) + ## Reference - [Pointcept's PTv3](https://github.com/Pointcept/Pointcept) diff --git a/projects/PTv3/configs/_base_/default_runtime.py b/projects/PTv3/configs/_base_/default_runtime.py index cca5bd789..21048be37 100644 --- a/projects/PTv3/configs/_base_/default_runtime.py +++ b/projects/PTv3/configs/_base_/default_runtime.py @@ -23,6 +23,10 @@ mix_prob = 0 param_dicts = None # example: param_dicts = [dict(keyword="block", lr_scale=0.1)] +# Sparse pair-gen sort for ONNX export: +# - True (default): keep pair-mask argsort. +# - False : export GetIndicePairsImplicitGemm with do_sort=0. +spconv_do_sort = True # hook hooks = [ diff --git a/projects/PTv3/tools/export.py b/projects/PTv3/tools/export.py index ae24c23b3..aea6848f9 100644 --- a/projects/PTv3/tools/export.py +++ b/projects/PTv3/tools/export.py @@ -11,6 +11,8 @@ from models.utils.structure import Point, bit_length_tensor from torch.nn import functional as F +from projects.SparseConvolution.sparse_functional import set_do_sort + # NOTE: keep this import last; it overrides sparse conv registration for export. import SparseConvolution # isort: skip @@ -70,11 +72,19 @@ def forward( return pred_label, pred_probs +def _apply_spconv_do_sort(cfg) -> None: + value = bool(cfg.get("spconv_do_sort", True)) + + set_do_sort(value) + print("[PTv3][export] spconv_do_sort=" f"{value} (baked into GetIndicePairsImplicitGemm.do_sort_i at ONNX export)") + + def main(): args = default_argument_parser().parse_args() cfg = default_config_parser(args.config_file, args.options) cfg = default_setup(cfg) + _apply_spconv_do_sort(cfg) cfg.num_worker = 1 cfg.num_worker_per_gpu = 1 diff --git a/projects/SparseConvolution/sparse_functional.py b/projects/SparseConvolution/sparse_functional.py index efb72f7e1..528f83a72 100644 --- a/projects/SparseConvolution/sparse_functional.py +++ b/projects/SparseConvolution/sparse_functional.py @@ -6,7 +6,7 @@ from cumm import tensorview as tv from spconv import constants from spconv.algo import CONV_CPP -from spconv.constants import SPCONV_DO_SORT, SPCONV_USE_DIRECT_TABLE, AllocKeys +from spconv.constants import SPCONV_USE_DIRECT_TABLE, AllocKeys from spconv.core import ConvAlgo from spconv.core_cc.csrc.sparse.all import SpconvOps from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps @@ -18,6 +18,14 @@ from torch.onnx.symbolic_helper import _get_tensor_sizes +def set_do_sort(value: bool) -> None: + """Set `do_sort` used by GetIndicePairsImplicitGemm symbolic/forward. + + Kept as a module-level API so existing callers do not need to change. + """ + GetIndicePairsImplicitGemm.set_do_sort(value) + + class GetIndicePairs(Function): @staticmethod @@ -212,6 +220,12 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> tuple: class GetIndicePairsImplicitGemm(Function): + # Controls `do_sort` in both ONNX symbolic and PyTorch forward paths. + _do_sort: bool = True + + @classmethod + def set_do_sort(cls, value: bool) -> None: + cls._do_sort = bool(value) @staticmethod def symbolic( @@ -245,6 +259,7 @@ def symbolic( subm_i=subm, transpose_i=transpose, is_train_i=is_train, + do_sort_i=int(GetIndicePairsImplicitGemm._do_sort), outputs=5, ) indices_shape = _get_tensor_sizes(indices) @@ -301,7 +316,7 @@ def forward( num_out_act_bound: int = -1 direct_table: bool = SPCONV_USE_DIRECT_TABLE - do_sort = SPCONV_DO_SORT + do_sort = GetIndicePairsImplicitGemm._do_sort stream = get_current_stream()