| 636 | return sparse_unbind(self, dim) |
| 637 | |
def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
    """Return a new ``SparseTensor`` that carries ``feats`` in place of the current features.

    The backend container is rebuilt according to ``config.CONV``
    (``'torchsparse'``, ``'spconv'``, or a plain ``dict`` for any other
    value); everything except the features (and, optionally, the
    coordinates) is taken over from ``self.data``.

    Args:
        feats: Feature tensor to store in the new sparse tensor.
        coords: Optional replacement coordinates. When ``None`` the
            coordinates already held by ``self.data`` are reused.

    Returns:
        A new ``SparseTensor`` built with this tensor's ``_scale`` and
        ``_spatial_cache``. Its ``shape`` is ``None`` when ``self._shape``
        is ``None``; otherwise it is ``self._shape[0]`` followed by
        ``feats.shape[1:]``.
    """
    backend = config.CONV
    src = self.data

    if backend == 'torchsparse':
        if coords is None:
            picked_coords = src.coords
        else:
            picked_coords = coords
        payload = self.SparseTensorData(
            feats=feats,
            coords=picked_coords,
            stride=src.stride,
            spatial_range=src.spatial_range,
        )
        # The cache object is shared with the source container, not copied.
        payload._caches = src._caches
    elif backend == 'spconv':
        # The container is constructed from the *existing* features flattened
        # to 2-D, and the requested features are installed afterwards through
        # the private attribute.
        # NOTE(review): presumably this sidesteps constructor-side checks on
        # the feature tensor -- confirm against the spconv version in use.
        flat_feats = src.features.reshape(src.features.shape[0], -1)
        payload = self.SparseTensorData(
            flat_feats,
            src.indices,
            src.spatial_shape,
            src.batch_size,
            src.grid,
            src.voxel_num,
            src.indice_dict,
        )
        payload._features = feats
        # Carry over the remaining per-tensor settings one by one, in the
        # same order as before.
        for attr in (
            'benchmark',
            'benchmark_record',
            'thrust_allocator',
            '_timer',
            'force_algo',
            'int8_scale',
        ):
            setattr(payload, attr, getattr(src, attr))
        if coords is not None:
            payload.indices = coords
    else:
        # Fallback representation: a plain dict with 'feats' and 'coords'.
        if coords is None:
            picked_coords = src['coords']
        else:
            picked_coords = coords
        payload = {
            'feats': feats,
            'coords': picked_coords,
        }

    if self._shape is None:
        out_shape = None
    else:
        # Keep the leading dimension, take the trailing ones from ``feats``.
        out_shape = torch.Size([self._shape[0], *feats.shape[1:]])

    return SparseTensor(
        payload,
        shape=out_shape,
        scale=self._scale,
        spatial_cache=self._spatial_cache,
    )
| 678 | |
| 679 | def to_dense(self) -> torch.Tensor: |
| 680 | if config.CONV == 'torchsparse': |