r"""A sequential container. Modules will be added to it in the order they are passed in the constructor. Alternatively, an ordered dict of modules can also be passed in.
| 184 | |
| 185 | |
| 186 | class PointSequential(PointModule): |
| 187 | r"""A sequential container. |
| 188 | Modules will be added to it in the order they are passed in the constructor. |
| 189 | Alternatively, an ordered dict of modules can also be passed in. |
| 190 | """ |
| 191 | |
def __init__(self, *args, **kwargs):
    """Register sub-modules in the order they are supplied.

    Args:
        *args: either a single ``OrderedDict`` mapping names to modules, or
            any number of modules, which are registered under their
            positional names ``"0"``, ``"1"``, ...
        **kwargs: additional modules registered under their keyword names,
            after all positional ones.

    Raises:
        ValueError: if keyword modules are given on Python < 3.6, or if a
            keyword name is already registered.
    """
    super().__init__()

    from_mapping = len(args) == 1 and isinstance(args[0], OrderedDict)
    if from_mapping:
        named_modules = args[0].items()
    else:
        named_modules = ((str(position), mod) for position, mod in enumerate(args))
    for key, mod in named_modules:
        self.add_module(key, mod)

    # Presumably guarded because keyword-argument order is only guaranteed
    # to be preserved from Python 3.6 on (PEP 468).
    if kwargs and sys.version_info < (3, 6):
        raise ValueError("kwargs only supported in py36+")
    for key, mod in kwargs.items():
        if key in self._modules:
            raise ValueError("name exists.")
        self.add_module(key, mod)
| 206 | |
def __getitem__(self, idx):
    """Return the sub-module at position ``idx``, or a sub-container for a slice.

    Args:
        idx: an integer position (negative values count from the end) or a
            ``slice``.

    Returns:
        For an integer, the module registered at that position. For a slice,
        a new instance of ``self.__class__`` built from an ``OrderedDict`` of
        the selected ``(name, module)`` pairs, so the original names are kept.

    Raises:
        IndexError: if an integer ``idx`` lies outside ``[-len(self), len(self))``.
    """
    from itertools import islice

    if isinstance(idx, slice):
        # Mirror torch.nn.Sequential: slicing yields a container of the same
        # type; the constructor accepts a single OrderedDict of named modules.
        return self.__class__(OrderedDict(list(self._modules.items())[idx]))
    size = len(self)
    if not (-size <= idx < size):
        raise IndexError("index {} is out of range".format(idx))
    if idx < 0:
        idx += size
    # Advance lazily to the requested position instead of materializing a list.
    return next(islice(self._modules.values(), idx, None))
| 216 | |
def __len__(self):
    """Return how many sub-modules are currently registered in the container."""
    registered = self._modules
    return len(registered)
| 219 | |
def add(self, module, name=None):
    """Append ``module`` to the end of the container.

    Args:
        module: the sub-module to register.
        name: registration name. When ``None``, the positional name
            ``str(len(self._modules))`` is used, advanced to the next unused
            integer string if a module already holds that name.

    Raises:
        KeyError: if an explicitly given ``name`` is already registered.
    """
    if name is None:
        # Auto-names follow the "0", "1", ... scheme used for positional
        # modules, but an explicitly named module (e.g. from an OrderedDict)
        # may already occupy that name; skip ahead instead of failing on a
        # name the caller never chose.
        index = len(self._modules)
        while str(index) in self._modules:
            index += 1
        name = str(index)
    elif name in self._modules:
        raise KeyError("name exists")
    self.add_module(name, module)
| 226 | |
| 227 | def forward(self, input): |
| 228 | for k, module in self._modules.items(): |
| 229 | # Point module |
| 230 | if isinstance(module, PointModule): |
| 231 | input = module(input) |
| 232 | # Spconv module |
| 233 | elif spconv.modules.is_spconv_module(module): |
| 234 | if isinstance(input, Point): |
| 235 | input.sparse_conv_feat = module(input.sparse_conv_feat) |
| 236 | input.feat = input.sparse_conv_feat.features |
| 237 | else: |
| 238 | input = module(input) |
| 239 | # PyTorch module |
| 240 | else: |
| 241 | if isinstance(input, Point): |
| 242 | input.feat = module(input.feat) |
| 243 | if "sparse_conv_feat" in input.keys(): |