Finds all spconv keys whose weights need to be transposed
(model: nn.Module, prefix="")
| 9 | |
| 10 | |
def find_all_spconv_keys(model: nn.Module, prefix="") -> Set[str]:
    """
    Finds all spconv keys whose weights need to be transposed.

    Recursively walks the child modules of ``model`` and collects the dotted
    ``"<path>.weight"`` name of every ``spconv.conv.SparseConvolution``
    submodule it encounters.

    Args:
        model: Module whose descendants are searched. ``model`` itself is not
            checked; only its children (recursively) are.
        prefix: Dotted path of ``model`` inside the root module, prepended to
            every collected key. Use the empty string for the root module.

    Returns:
        Set of dotted weight keys, one per SparseConvolution submodule
        (e.g. ``"a.b.weight"``).
    """
    found_keys: Set[str] = set()
    for name, child in model.named_children():
        child_prefix = f"{prefix}.{name}" if prefix else name

        if isinstance(child, spconv.conv.SparseConvolution):
            # Build the weight key separately so that the recursion below
            # still uses the module's own path, not "<path>.weight".
            found_keys.add(f"{child_prefix}.weight")

        found_keys.update(find_all_spconv_keys(child, prefix=child_prefix))

    return found_keys
| 26 | |
| 27 | |
| 28 | def replace_feature(out, new_features): |
no test coverage detected