Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion projects/PTv3/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,16 @@ Export the model:
python projects/PTv3/tools/export.py --config-file projects/PTv3/configs/semseg-pt-v3m1-0-t4dataset.py --num-gpus 1 \
--options \
save_path=work_dirs/experiment \
weight=work_dirs/ptv3/model/model_best.pth
weight=work_dirs/ptv3/model/model_best.pth \
spconv_do_sort=False
```

which will generate a file called `ptv3.onnx`

- `spconv_do_sort` controls the ONNX attribute on `GetIndicePairsImplicitGemm`:
- `True` (default): export with sorting enabled (`do_sort=1`)
- `False`: export with sorting disabled (`do_sort=0`)

## Reference

- [Pointcept's PTv3](https://github.com/Pointcept/Pointcept)
4 changes: 4 additions & 0 deletions projects/PTv3/configs/_base_/default_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@

mix_prob = 0
param_dicts = None # example: param_dicts = [dict(keyword="block", lr_scale=0.1)]
# Sparse pair-gen sort for ONNX export:
# - True (default): keep pair-mask argsort.
# - False : export GetIndicePairsImplicitGemm with do_sort=0.
spconv_do_sort = True

# hook
hooks = [
Expand Down
10 changes: 10 additions & 0 deletions projects/PTv3/tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from models.utils.structure import Point, bit_length_tensor
from torch.nn import functional as F

from projects.SparseConvolution.sparse_functional import set_do_sort

# NOTE: keep this import last; it overrides sparse conv registration for export.
import SparseConvolution # isort: skip

Expand Down Expand Up @@ -70,11 +72,19 @@ def forward(
return pred_label, pred_probs


def _apply_spconv_do_sort(cfg) -> None:
value = bool(cfg.get("spconv_do_sort", True))

set_do_sort(value)
print("[PTv3][export] spconv_do_sort=" f"{value} (baked into GetIndicePairsImplicitGemm.do_sort_i at ONNX export)")


def main():
args = default_argument_parser().parse_args()
cfg = default_config_parser(args.config_file, args.options)

cfg = default_setup(cfg)
_apply_spconv_do_sort(cfg)
cfg.num_worker = 1
cfg.num_worker_per_gpu = 1

Expand Down
19 changes: 17 additions & 2 deletions projects/SparseConvolution/sparse_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from cumm import tensorview as tv
from spconv import constants
from spconv.algo import CONV_CPP
from spconv.constants import SPCONV_DO_SORT, SPCONV_USE_DIRECT_TABLE, AllocKeys
from spconv.constants import SPCONV_USE_DIRECT_TABLE, AllocKeys
from spconv.core import ConvAlgo
from spconv.core_cc.csrc.sparse.all import SpconvOps
from spconv.core_cc.csrc.sparse.convops.spops import ConvGemmOps
Expand All @@ -18,6 +18,14 @@
from torch.onnx.symbolic_helper import _get_tensor_sizes


def set_do_sort(value: bool) -> None:
"""Set `do_sort` used by GetIndicePairsImplicitGemm symbolic/forward.

Kept as a module-level API so existing callers do not need to change.
"""
GetIndicePairsImplicitGemm.set_do_sort(value)


class GetIndicePairs(Function):

@staticmethod
Expand Down Expand Up @@ -212,6 +220,12 @@ def backward(ctx: Any, grad_output: torch.Tensor) -> tuple:


class GetIndicePairsImplicitGemm(Function):
# Controls `do_sort` in both ONNX symbolic and PyTorch forward paths.
_do_sort: bool = True

@classmethod
def set_do_sort(cls, value: bool) -> None:
cls._do_sort = bool(value)

@staticmethod
def symbolic(
Expand Down Expand Up @@ -245,6 +259,7 @@ def symbolic(
subm_i=subm,
transpose_i=transpose,
is_train_i=is_train,
do_sort_i=int(GetIndicePairsImplicitGemm._do_sort),
outputs=5,
)
indices_shape = _get_tensor_sizes(indices)
Expand Down Expand Up @@ -301,7 +316,7 @@ def forward(

num_out_act_bound: int = -1
direct_table: bool = SPCONV_USE_DIRECT_TABLE
do_sort = SPCONV_DO_SORT
do_sort = GetIndicePairsImplicitGemm._do_sort

stream = get_current_stream()

Expand Down