(self, x: SparseTensor)
| 21 | assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}' |
| 22 | |
def forward(self, x: SparseTensor) -> SparseTensor:
    """Pool ``x`` onto a grid that is ``self.factor`` times coarser per spatial axis.

    Input voxels whose spatial coordinates fall into the same coarse cell (and
    share the same value in coordinate column 0) are merged. Their features are
    combined with ``torch.scatter_reduce`` using ``self.mode`` (``'mean'`` or
    ``'max'``).

    The coarse coordinates and the fine->coarse inverse index are looked up in,
    and on a miss stored into, ``x``'s spatial cache under
    ``'downsample_{factor}'``. On a cache miss the output additionally records
    ``'upsample_{factor}'`` (the fine coords plus inverse index) and ``'shape'``
    (the coarse grid extent). In training mode it also records ``'subdivision'``,
    a boolean mask of which child cells of each coarse voxel are occupied.

    Args:
        x: Input sparse tensor. Column 0 of ``x.coords`` is not rescaled
            (presumably the batch index); the remaining columns are spatial
            coordinates divided by ``self.factor``.

    Returns:
        A new ``SparseTensor`` holding the reduced features and coarse
        coordinates. Its ``_scale`` is ``x._scale`` multiplied by ``self.factor``,
        and it shares the same spatial-cache object as ``x``.
    """
    cache = x.get_spatial_cache(f'downsample_{self.factor}')
    if cache is None:
        DIM = x.coords.shape[-1] - 1

        # Build the packed cell key in int64: it is roughly
        # batch * prod(MAX), which can exceed the int32 range for large grids
        # or batches and would otherwise wrap silently.
        # NOTE(review): assumes coords may be stored as int32 (dtype is not
        # visible here); for int64 coords `.long()` returns the tensor as-is.
        coord = [c.long() for c in x.coords.unbind(dim=-1)]
        for i in range(DIM):
            coord[i + 1] = coord[i + 1] // self.factor

        # Coarse grid extent per spatial axis (ceil division of the fine extent).
        MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
        # Row-major strides: OFFSET[k] == prod(MAX[k:]), so OFFSET[DIM] == 1.
        OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
        code = sum(c * o for c, o in zip(coord, OFFSET))
        # unique() deduplicates the coarse cells and, via return_inverse, maps
        # every input voxel to the row of the coarse cell it falls into.
        code, idx = code.unique(return_inverse=True)

        # Unpack the keys back into per-column coordinates, restoring the
        # input's coordinate dtype so downstream consumers see no change.
        new_coords = torch.stack(
            [code // OFFSET[0]] +
            [(code // OFFSET[i + 1]) % MAX[i] for i in range(DIM)],
            dim=-1,
        ).to(x.coords.dtype)
    else:
        new_coords, idx = cache

    new_feats = torch.scatter_reduce(
        torch.zeros(new_coords.shape[0], x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype),
        dim=0,
        index=idx.unsqueeze(1).expand(-1, x.feats.shape[1]),
        src=x.feats,
        reduce=self.mode,
        # Exclude the zero-initialised destination from the mean/max so only
        # the scattered source rows contribute.
        include_self=False,
    )
    out = SparseTensor(new_feats, new_coords, x._shape)
    out._scale = tuple(s * self.factor for s in x._scale)
    # Same cache object as the input (not a copy).
    out._spatial_cache = x._spatial_cache

    if cache is None:
        x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx))
        out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx))
        out.register_spatial_cache('shape', torch.Size(MAX))
        if self.training:
            # Encode each voxel's position inside its coarse cell as one integer
            # in [0, factor ** DIM), with spatial axis 0 least significant.
            subidx = x.coords[:, 1:] % self.factor
            subidx = sum(subidx[..., i] * self.factor ** i for i in range(DIM))
            subdivision = torch.zeros((new_coords.shape[0], self.factor ** DIM), device=x.device, dtype=torch.bool)
            subdivision[idx, subidx] = True
            out.register_spatial_cache('subdivision', subdivision)

    return out
| 69 | |
| 70 | |
| 71 | class SparseUpsample(nn.Module): |
nothing calls this directly
no test coverage detected