MCPcopy
hub / github.com/TencentARC/Pixal3D / forward

Method forward

pixal3d/modules/sparse/spatial/basic.py:23–68  ·  view source on GitHub ↗
(self, x: SparseTensor)

Source from the content-addressed store, hash-verified

21 assert self.mode in ['mean', 'max'], f'Invalid mode: {self.mode}'
22
def forward(self, x: SparseTensor) -> SparseTensor:
    """Pool ``x`` onto a grid coarsened by ``self.factor`` along every spatial axis.

    Input voxels whose spatial coordinates fall into the same ``factor``-sized
    cell (and share the same column-0 value) are merged into one output voxel.
    Their features are combined with ``torch.scatter_reduce`` using
    ``self.mode`` ('mean' or 'max').

    The coordinate mapping ``(new_coords, idx)`` is looked up in / stored to
    ``x``'s spatial cache under ``'downsample_{factor}'``. When it is computed
    fresh, the output additionally registers ``'upsample_{factor}'``,
    ``'shape'`` and, in training mode, ``'subdivision'``.

    Args:
        x: Input sparse tensor. Column 0 of ``x.coords`` is left unscaled
            (presumably the batch index); columns 1.. are the spatial
            coordinates that get divided by ``factor``.
            NOTE(review): the encoding assumes non-negative coordinates
            below ``x.spatial_shape`` -- confirm against callers.

    Returns:
        A new ``SparseTensor`` holding the pooled features and the coarse
        coordinates (same dtype as ``x.coords``). Its ``_scale`` is ``x``'s
        multiplied by ``factor``, and it reuses ``x``'s spatial-cache object.
    """
    cache = x.get_spatial_cache(f'downsample_{self.factor}')
    if cache is None:
        dim = x.coords.shape[-1] - 1

        # Encode in int64: the linear code below reaches
        # column0 * prod(max_extent), which can overflow if coords are stored
        # in a narrower integer dtype (e.g. int32) for large grids/batches.
        coord = list(x.coords.long().unbind(dim=-1))
        for i in range(1, dim + 1):
            coord[i] = coord[i] // self.factor

        # Extent of the coarse grid along each spatial axis (ceil division).
        max_extent = [(s + self.factor - 1) // self.factor for s in x.spatial_shape]
        # Row-major strides: offsets[0] = prod(max_extent) (for column 0),
        # ..., offsets[-1] = 1 (for the last spatial axis).
        offsets = [1] * (len(max_extent) + 1)
        for i in range(len(max_extent) - 1, -1, -1):
            offsets[i] = offsets[i + 1] * max_extent[i]
        code = sum(c * o for c, o in zip(coord, offsets))
        # Unique codes = occupied coarse cells; idx maps each input voxel to
        # the row of its cell in the output.
        code, idx = code.unique(return_inverse=True)

        # Decode the codes back to (column0, spatial...) and restore the
        # original coordinate dtype.
        new_coords = torch.stack(
            [code // offsets[0]]
            + [(code // offsets[i + 1]) % max_extent[i] for i in range(dim)],
            dim=-1,
        ).to(x.coords.dtype)
    else:
        new_coords, idx = cache

    n_channels = x.feats.shape[1]
    # include_self=False keeps the zero-initialised destination out of the
    # mean/max, so each output row reduces only over its own input voxels.
    new_feats = torch.scatter_reduce(
        torch.zeros(new_coords.shape[0], n_channels, device=x.feats.device, dtype=x.feats.dtype),
        dim=0,
        index=idx.unsqueeze(1).expand(-1, n_channels),
        src=x.feats,
        reduce=self.mode,
        include_self=False,
    )
    out = SparseTensor(new_feats, new_coords, x._shape)
    out._scale = tuple(s * self.factor for s in x._scale)
    # Reuse x's cache object for out instead of starting a fresh one.
    out._spatial_cache = x._spatial_cache

    if cache is None:
        x.register_spatial_cache(f'downsample_{self.factor}', (new_coords, idx))
        out.register_spatial_cache(f'upsample_{self.factor}', (x.coords, idx))
        out.register_spatial_cache('shape', torch.Size(max_extent))
        if self.training:
            # Boolean occupancy mask over the factor**dim child cells of each
            # output voxel; child index = sum_i local[i] * factor**i, with
            # spatial axis 0 as the least-significant digit.
            local = x.coords[:, 1:] % self.factor
            child = sum(local[..., i] * self.factor ** i for i in range(dim))
            subdivision = torch.zeros(
                (new_coords.shape[0], self.factor ** dim),
                device=x.device,
                dtype=torch.bool,
            )
            subdivision[idx, child] = True
            out.register_spatial_cache('subdivision', subdivision)

    return out
69
70
71class SparseUpsample(nn.Module):

Callers

nothing calls this directly

Calls 4

SparseTensor (class) · 0.85
get_spatial_cache (method) · 0.80
unbind (method) · 0.45

Tested by

no test coverage detected