Sparse tensor with support for both torchsparse and spconv backends. Parameters: - feats (torch.Tensor): Features of the sparse tensor. - coords (torch.Tensor): Coordinates of the sparse tensor. - shape (torch.Size): Shape of the sparse tensor. - layout (List[slice]): Layout of the sparse tensor for each batch. - data (SparseTensorData): Sparse tensor data used for convolution.
| 341 | |
| 342 | |
| 343 | class SparseTensor(VarLenTensor): |
| 344 | """ |
| 345 | Sparse tensor with support for both torchsparse and spconv backends. |
| 346 | |
| 347 | Parameters: |
| 348 | - feats (torch.Tensor): Features of the sparse tensor. |
| 349 | - coords (torch.Tensor): Coordinates of the sparse tensor. |
| 350 | - shape (torch.Size): Shape of the sparse tensor. |
| 351 | - layout (List[slice]): Layout of the sparse tensor for each batch |
| 352 | - data (SparseTensorData): Sparse tensor data used for convolution |
| 353 | |
| 354 | NOTE: |
| 355 | - Data corresponding to the same batch should be contiguous. |
| 356 | - Coords should be in [0, 1023] |
| 357 | """ |
| 358 | SparseTensorData = None |
| 359 | |
| 360 | @overload |
| 361 | def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ... |
| 362 | |
| 363 | @overload |
| 364 | def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ... |
| 365 | |
| 366 | def __init__(self, *args, **kwargs): |
| 367 | # Lazy import of sparse tensor backend |
| 368 | if self.SparseTensorData is None: |
| 369 | import importlib |
| 370 | if config.CONV == 'torchsparse': |
| 371 | self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor |
| 372 | elif config.CONV == 'spconv': |
| 373 | self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor |
| 374 | |
| 375 | method_id = 0 |
| 376 | if len(args) != 0: |
| 377 | method_id = 0 if isinstance(args[0], torch.Tensor) else 1 |
| 378 | else: |
| 379 | method_id = 1 if 'data' in kwargs else 0 |
| 380 | |
| 381 | if method_id == 0: |
| 382 | feats, coords, shape = args + (None,) * (3 - len(args)) |
| 383 | if 'feats' in kwargs: |
| 384 | feats = kwargs['feats'] |
| 385 | del kwargs['feats'] |
| 386 | if 'coords' in kwargs: |
| 387 | coords = kwargs['coords'] |
| 388 | del kwargs['coords'] |
| 389 | if 'shape' in kwargs: |
| 390 | shape = kwargs['shape'] |
| 391 | del kwargs['shape'] |
| 392 | |
| 393 | if config.CONV == 'torchsparse': |
| 394 | self.data = self.SparseTensorData(feats, coords, **kwargs) |
| 395 | elif config.CONV == 'spconv': |
| 396 | spatial_shape = list(coords.max(0)[0] + 1) |
| 397 | self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs) |
| 398 | self.data._features = feats |
| 399 | else: |
| 400 | self.data = { |
no outgoing calls
no test coverage detected