Repository: github.com/TencentARC/Pixal3D

Class SparseTensor

pixal3d/modules/sparse/basic.py:343–795

Sparse tensor with support for both the torchsparse and spconv backends.
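The constructor listed below picks its backend class at run time from `config.CONV`: `torchsparse.SparseTensor` for `'torchsparse'` and `spconv.pytorch.SparseConvTensor` for `'spconv'`. It defers the import until a `SparseTensor` is constructed, so the unselected library is never imported. A minimal pure-Python sketch of that lazy lookup follows. The `BACKENDS` table and `resolve_backend` function are hypothetical names for illustration, not part of Pixal3D, and `import_module` is injectable so the logic can be exercised without either library installed.

```python
import importlib
from typing import Any, Callable

# Hypothetical table mirroring the two branches in SparseTensor.__init__:
# config.CONV value -> (module to import, attribute holding the tensor class).
BACKENDS = {
    "torchsparse": ("torchsparse", "SparseTensor"),
    "spconv": ("spconv.pytorch", "SparseConvTensor"),
}


def resolve_backend(conv: str,
                    import_module: Callable[[str], Any] = importlib.import_module) -> Any:
    """Import only the selected backend module and return its tensor class."""
    try:
        module_name, class_name = BACKENDS[conv]
    except KeyError:
        # Fail fast instead of leaving the backend class unset.
        raise ValueError(f"unsupported sparse conv backend: {conv!r}") from None
    return getattr(import_module(module_name), class_name)


# Example: with spconv installed, resolve_backend("spconv") returns
# spconv.pytorch.SparseConvTensor and never touches the torchsparse module.
```

One design difference: in the lines shown, an unrecognized `config.CONV` value simply leaves `SparseTensorData` as `None`, whereas this sketch raises a `ValueError` immediately.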


```python
class SparseTensor(VarLenTensor):
    """
    Sparse tensor with support for both torchsparse and spconv backends.

    Parameters:
    - feats (torch.Tensor): Features of the sparse tensor.
    - coords (torch.Tensor): Coordinates of the sparse tensor.
    - shape (torch.Size): Shape of the sparse tensor.
    - layout (List[slice]): Layout of the sparse tensor for each batch
    - data (SparseTensorData): Sparse tensor data used for convolution

    NOTE:
    - Data corresponding to the same batch should be contiguous.
    - Coords should be in [0, 1023]
    """
    SparseTensorData = None

    @overload
    def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ...

    @overload
    def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ...

    def __init__(self, *args, **kwargs):
        # Lazy import of sparse tensor backend
        if self.SparseTensorData is None:
            import importlib
            if config.CONV == 'torchsparse':
                self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor
            elif config.CONV == 'spconv':
                self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor

        method_id = 0
        if len(args) != 0:
            method_id = 0 if isinstance(args[0], torch.Tensor) else 1
        else:
            method_id = 1 if 'data' in kwargs else 0

        if method_id == 0:
            feats, coords, shape = args + (None,) * (3 - len(args))
            if 'feats' in kwargs:
                feats = kwargs['feats']
                del kwargs['feats']
            if 'coords' in kwargs:
                coords = kwargs['coords']
                del kwargs['coords']
            if 'shape' in kwargs:
                shape = kwargs['shape']
                del kwargs['shape']

            if config.CONV == 'torchsparse':
                self.data = self.SparseTensorData(feats, coords, **kwargs)
            elif config.CONV == 'spconv':
                spatial_shape = list(coords.max(0)[0] + 1)
                self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs)
                self.data._features = feats
        else:
            self.data = {
                # ... (excerpt ends here; see basic.py:343–795 for the full class)
```
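The feats/coords path of the constructor does two things worth spelling out. First, it picks an overload: a `torch.Tensor` first argument, or no positional arguments and no `data` keyword, selects the feats/coords form; positional arguments are then padded to `(feats, coords, shape)` and keyword arguments override them. Second, for spconv it derives the backend's extra arguments from `coords`: column 0 is the batch index (the code passes `spatial_shape[0]` as spconv's `batch_size`), so the column-wise maximum plus one gives `[batch_size, *spatial_extent]`. The pure-Python sketch below mirrors that logic. `normalize_init_args`, `spconv_shape_args`, and the `is_tensor` predicate (standing in for `isinstance(x, torch.Tensor)`) are hypothetical names, and `coords` is a plain list of rows rather than a tensor.

```python
from typing import Any, Callable, Dict, List, Sequence, Tuple


def normalize_init_args(args: Tuple[Any, ...], kwargs: Dict[str, Any],
                        is_tensor: Callable[[Any], bool]):
    """Mirror the overload dispatch and argument handling of SparseTensor.__init__.

    Returns (method_id, feats, coords, shape, remaining_kwargs), where
    method_id 0 is the (feats, coords[, shape]) form and 1 the (data[, shape]) form.
    """
    kwargs = dict(kwargs)  # work on a copy; the caller's dict is left untouched
    if args:
        method_id = 0 if is_tensor(args[0]) else 1
    else:
        method_id = 1 if "data" in kwargs else 0
    if method_id == 1:
        return method_id, None, None, None, kwargs
    # Pad positionals to exactly (feats, coords, shape); keywords take precedence.
    feats, coords, shape = tuple(args) + (None,) * (3 - len(args))
    feats = kwargs.pop("feats", feats)
    coords = kwargs.pop("coords", coords)
    shape = kwargs.pop("shape", shape)
    return method_id, feats, coords, shape, kwargs


def spconv_shape_args(coords: Sequence[Sequence[int]]) -> Tuple[List[int], int]:
    """Return (spatial_shape, batch_size) the way the spconv branch derives them.

    Equivalent to `spatial_shape = list(coords.max(0)[0] + 1)` followed by
    `spatial_shape[1:]` and `spatial_shape[0]`, with column 0 the batch index.
    """
    extent = [max(column) + 1 for column in zip(*coords)]  # column-wise max + 1
    return extent[1:], extent[0]


def is_list(x: Any) -> bool:
    return isinstance(x, list)  # stand-in for a torch.Tensor check


print(normalize_init_args(([1.0, 2.0], [[0, 1, 2, 3]]), {"extra": 5}, is_list))
# -> (0, [1.0, 2.0], [[0, 1, 2, 3]], None, {'extra': 5})
print(spconv_shape_args([[0, 1, 2, 3], [0, 4, 0, 1], [1, 0, 5, 2]]))
# -> ([5, 6, 4], 2)
```

Because `batch_size` is computed as the largest batch index plus one, both the sketch and the original line assume batch indices start at 0, consistent with the non-negative coordinate range in the docstring.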

Callers (15)

unpack_state · Function · 0.90
from_tensor_list · Method · 0.85
replace · Method · 0.85
full · Method · 0.85
__getitem__ · Method · 0.85
sparse_cat · Function · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
forward · Method · 0.85
sparse_conv3d_forward · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected
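The two invariants in the docstring's NOTE (rows of the same batch are contiguous; coordinates lie in [0, 1023]) can be checked without either backend. The pure-Python sketch below builds a per-batch `layout` as a list of slices, the form the docstring gives for `layout`, from a batch-index column, and validates the coordinate range. `compute_layout` and `check_coord_range` are hypothetical helpers, not Pixal3D functions. The sketch also assumes that batches appear in ascending order, that column 0 holds the batch index, and that the [0, 1023] bound applies to the spatial columns only.

```python
from typing import List, Sequence


def compute_layout(batch_indices: Sequence[int]) -> List[slice]:
    """Return layout[b] = slice of the rows belonging to batch b, for b in 0..max.

    Raises ValueError if a batch's rows are not contiguous or batches are out
    of ascending order (an assumption of this sketch). Batch ids with no rows
    get an empty slice.
    """
    layout: List[slice] = []
    n = len(batch_indices)
    i = 0
    while i < n:
        b = batch_indices[i]
        if b < len(layout):  # batch b (or a negative id) was already passed
            raise ValueError(f"rows of batch {b} are not contiguous or out of order")
        while len(layout) < b:  # batch ids skipped entirely -> empty slices
            layout.append(slice(i, i))
        j = i
        while j < n and batch_indices[j] == b:
            j += 1
        layout.append(slice(i, j))
        i = j
    return layout


def check_coord_range(coords: Sequence[Sequence[int]], lo: int = 0, hi: int = 1023) -> None:
    """Raise ValueError if a spatial coordinate (columns 1 onward) is outside [lo, hi]."""
    for r, row in enumerate(coords):
        for c in row[1:]:  # column 0 is the batch index and is not checked here
            if not lo <= c <= hi:
                raise ValueError(f"row {r}: coordinate {c} outside [{lo}, {hi}]")


coords = [[0, 5, 5, 5], [0, 6, 5, 5], [1, 1023, 0, 7]]
check_coord_range(coords)  # passes silently
print(compute_layout([row[0] for row in coords]))
# -> [slice(0, 2, None), slice(2, 3, None)]
```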