diff --git a/pcdet/utils/spconv_utils.py b/pcdet/utils/spconv_utils.py new file mode 100644 index 0000000..c38f899 --- /dev/null +++ b/pcdet/utils/spconv_utils.py @@ -0,0 +1,38 @@ +from typing import Set + +import spconv +if float(spconv.__version__[2:]) >= 2.2: + spconv.constants.SPCONV_USE_DIRECT_TABLE = False + +try: + import spconv.pytorch as spconv +except: + import spconv as spconv + +import torch.nn as nn + + +def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]: + """ + Finds all spconv keys that need to have weight's transposed + """ + found_keys: Set[str] = set() + for name, child in model.named_children(): + new_prefix = f"{prefix}.{name}" if prefix != "" else name + + if isinstance(child, spconv.conv.SparseConvolution): + new_prefix = f"{new_prefix}.weight" + found_keys.add(new_prefix) + + found_keys.update(find_all_spconv_keys(child, prefix=new_prefix)) + + return found_keys + + +def replace_feature(out, new_features): + if "replace_feature" in out.__dir__(): + # spconv 2.x behaviour + return out.replace_feature(new_features) + else: + out.features = new_features + return out