MCPcopy
hub / github.com/Pointcept/PointTransformerV3 / PointSequential

Class PointSequential

model.py:186–252  ·  view source on GitHub ↗

A sequential container. Modules are added to it in the order they are passed to the constructor; alternatively, an ordered dict of modules can be passed in.

Source from the content-addressed store, hash-verified

184
185
186class PointSequential(PointModule):
187 r"""A sequential container.
188 Modules will be added to it in the order they are passed in the constructor.
189 Alternatively, an ordered dict of modules can also be passed in.
190 """
191
def __init__(self, *args, **kwargs):
    """Register the given modules as ordered children of this container.

    Args:
        *args: Either a single mapping of ``name -> module`` (an
            ``OrderedDict``, or a plain ``dict``, which keeps insertion
            order on Python 3.7+), or any number of modules. Positional
            modules are registered under the names ``"0"``, ``"1"``, ...
        **kwargs: Additional named modules, registered after ``args`` in
            keyword order.

    Raises:
        ValueError: If keyword modules are given on Python < 3.6 (keyword
            order is not preserved there), or if a keyword name is already
            registered.
    """
    super().__init__()
    # OrderedDict is a dict subclass, so this accepts both the original
    # OrderedDict form and a plain, insertion-ordered dict.
    if len(args) == 1 and isinstance(args[0], dict):
        for key, module in args[0].items():
            self.add_module(key, module)
    else:
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)
    # Evaluated once instead of on every keyword iteration; only matters
    # when keyword modules were actually supplied.
    if kwargs and sys.version_info < (3, 6):
        raise ValueError("kwargs only supported in py36+")
    for name, module in kwargs.items():
        if name in self._modules:
            raise ValueError("name exists.")
        self.add_module(name, module)
206
def __getitem__(self, idx):
    """Return the child module at position ``idx``.

    Args:
        idx: An integer-like position (negative values count from the end,
            as for lists) or a ``slice``.

    Returns:
        The module at that position. For a slice, returns a new container
        of the same class that holds the selected modules under their
        original names.

    Raises:
        IndexError: If an integer ``idx`` is outside
            ``[-len(self), len(self))``.
        TypeError: If ``idx`` is neither a slice nor integer-like.
    """
    from itertools import islice
    import operator

    if isinstance(idx, slice):
        # Mirrors torch.nn.Sequential: slicing yields a sub-container.
        return self.__class__(OrderedDict(list(self._modules.items())[idx]))
    # Accept any integer-like index (e.g. numpy ints); reject others
    # with a clear TypeError instead of a comparison failure.
    idx = operator.index(idx)
    size = len(self)
    if not (-size <= idx < size):
        raise IndexError("index {} is out of range".format(idx))
    if idx < 0:
        idx += size
    # Skip ahead to position idx in the ordered values without copying.
    return next(islice(self._modules.values(), idx, None))
216
def __len__(self):
    """Return how many child modules are registered in this container."""
    registered = self._modules
    return len(registered)
219
def add(self, module, name=None):
    """Append ``module`` to the end of the container.

    Args:
        module: The module to register.
        name: Registration name. When omitted, the first unused integer
            name starting at ``str(len(self._modules))`` is chosen. This
            way, an explicitly named module such as ``"1"`` cannot cause
            an automatically chosen name to collide.

    Raises:
        KeyError: If an explicit ``name`` is already registered.
    """
    if name is None:
        position = len(self._modules)
        # The default str(len) name may already be used by an explicitly
        # named module; probe forward to the next free integer name.
        while str(position) in self._modules:
            position += 1
        name = str(position)
    elif name in self._modules:
        raise KeyError("name exists")
    self.add_module(name, module)
226
227 def forward(self, input):
228 for k, module in self._modules.items():
229 # Point module
230 if isinstance(module, PointModule):
231 input = module(input)
232 # Spconv module
233 elif spconv.modules.is_spconv_module(module):
234 if isinstance(input, Point):
235 input.sparse_conv_feat = module(input.sparse_conv_feat)
236 input.feat = input.sparse_conv_feat.features
237 else:
238 input = module(input)
239 # PyTorch module
240 else:
241 if isinstance(input, Point):
242 input.feat = module(input.feat)
243 if "sparse_conv_feat" in input.keys():

Callers 5

__init__ (method) · 0.85
__init__ (method) · 0.85
__init__ (method) · 0.85
__init__ (method) · 0.85
__init__ (method) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected