| 241 | return sparse_unbind(self, dim) |
| 242 | |
def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor':
    """Return a new SparseTensor with ``feats`` (and optionally ``coords``) swapped in.

    The result keeps this tensor's backend metadata, plus its ``layout``,
    ``_scale`` and ``_spatial_cache``. For torchsparse that metadata is
    stride, spatial range and caches. For spconv it is indices, spatial
    shape, batch size, grid, voxel count, indice dict and the tuning or
    benchmark attributes. The new logical shape is ``self.shape[0]``
    followed by ``feats.shape[1:]``.

    Args:
        feats: Replacement feature tensor. Presumably it has one row per
            active voxel of ``self``; the row count is not checked here.
        coords: Optional replacement coordinates. ``None`` keeps the
            current ones.

    Returns:
        A new SparseTensor that shares backend metadata with ``self``.

    Raises:
        ValueError: If ``BACKEND`` is neither ``'torchsparse'`` nor ``'spconv'``.
            Previously this case surfaced as an ``UnboundLocalError``.
    """
    new_shape = [self.shape[0]]
    new_shape.extend(feats.shape[1:])
    if BACKEND == 'torchsparse':
        new_data = SparseTensorData(
            feats=feats,
            coords=self.data.coords if coords is None else coords,
            stride=self.data.stride,
            spatial_range=self.data.spatial_range,
        )
        # Shares (not copies) the cache object with self.
        # NOTE(review): the caches are reused even when `coords` is replaced,
        # so they may be stale. Confirm callers only pass coords with the
        # same sparsity pattern.
        new_data._caches = self.data._caches
    elif BACKEND == 'spconv':
        # Build from the *old* features flattened to 2-D, then overwrite
        # `_features` directly so `feats` can keep any trailing shape.
        # This presumably sidesteps a 2-D check in the spconv constructor
        # (TODO confirm).
        new_data = SparseTensorData(
            self.data.features.reshape(self.data.features.shape[0], -1),
            self.data.indices,
            self.data.spatial_shape,
            self.data.batch_size,
            self.data.grid,
            self.data.voxel_num,
            self.data.indice_dict
        )
        new_data._features = feats
        # Carry over spconv's tuning/benchmark state attribute by attribute.
        new_data.benchmark = self.data.benchmark
        new_data.benchmark_record = self.data.benchmark_record
        new_data.thrust_allocator = self.data.thrust_allocator
        new_data._timer = self.data._timer
        new_data.force_algo = self.data.force_algo
        new_data.int8_scale = self.data.int8_scale
        if coords is not None:
            # NOTE(review): `indice_dict` is kept even though the indices change.
            new_data.indices = coords
    else:
        raise ValueError(f"Unsupported sparse backend: {BACKEND!r}")
    new_tensor = SparseTensor(new_data, shape=torch.Size(new_shape), layout=self.layout, scale=self._scale, spatial_cache=self._spatial_cache)
    return new_tensor
| 275 | |
| 276 | @staticmethod |
| 277 | def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': |