hub / github.com/TencentARC/Pixal3D / forward

Method forward

pixal3d/modules/attention/rope.py:35–48

Args: indices (torch.Tensor): [..., N, C] tensor of spatial positions

(self, indices: torch.Tensor)

Source from the content-addressed store, hash-verified

        return x_embed  # end of the preceding method

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        """
        Args:
            indices (torch.Tensor): [..., N, C] tensor of spatial positions
        """
        assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
        phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
        if phases.shape[-1] < self.head_dim // 2:
            padn = self.head_dim // 2 - phases.shape[-1]
            phases = torch.cat([phases, torch.polar(
                torch.ones(*phases.shape[:-1], padn, device=phases.device),
                torch.zeros(*phases.shape[:-1], padn, device=phases.device)
            )], dim=-1)
        return phases
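
The padding branch in the listing can be looked at in isolation. The sketch below is a minimal, standalone illustration and not the repository's code: `pad_phases` and the tensor sizes are hypothetical names and values chosen for the example. It shows that `torch.polar(ones, zeros)` produces the complex value 1+0j, so multiplying features by a padded channel leaves them unrotated.

```python
import torch


def pad_phases(phases: torch.Tensor, target: int) -> torch.Tensor:
    """Pad the last dim of a complex phase tensor up to `target` with 1+0j.

    Hypothetical helper that mirrors the padding branch of the listed forward().
    """
    if phases.shape[-1] < target:
        padn = target - phases.shape[-1]
        # Magnitude 1 and angle 0 give the unit phasor 1+0j (an identity rotation).
        ones = torch.ones(*phases.shape[:-1], padn, device=phases.device)
        zeros = torch.zeros(*phases.shape[:-1], padn, device=phases.device)
        phases = torch.cat([phases, torch.polar(ones, zeros)], dim=-1)
    return phases


# Illustrative sizes: 6 computed complex channels, padded up to 8.
phases = torch.polar(torch.ones(2, 5, 6), torch.rand(2, 5, 6))
padded = pad_phases(phases, 8)

# Multiplying by a padded channel leaves a feature value unchanged.
x = torch.randn(2, 5, 2, dtype=torch.complex64)
rotated = x * padded[..., 6:]
```

Padding with unit phasors rather than zeros matters if the phases are later multiplied into query/key features, as is usual for rotary embeddings: a zero would erase those channels, whereas 1+0j passes them through.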

Callers

nothing calls this directly

Calls 2

_get_phases  Method · 0.95
reshape  Method · 0.45
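
The body of `_get_phases` is not shown on this page, so the shape flow through `forward` is easiest to follow with a stand-in. The sketch below is an assumption-laden illustration: `RopeSketch`, its constructor arguments, the `freq_dim` split, and the outer-product form of `_get_phases` are all hypothetical (they follow the usual rotary-embedding recipe, not necessarily what Pixal3D does). The padding step from the listed source is left out on purpose.

```python
import torch


class RopeSketch:
    """Hypothetical stand-in for the class that owns forward()."""

    def __init__(self, head_dim: int, dim: int = 3, base: float = 10000.0):
        self.head_dim = head_dim
        self.dim = dim
        # Assumption: each spatial axis gets an equal share of the
        # head_dim // 2 complex channels.
        self.freq_dim = head_dim // 2 // dim
        exponents = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim
        self.freqs = 1.0 / (base ** exponents)

    def _get_phases(self, flat: torch.Tensor) -> torch.Tensor:
        # Assumption: angle = position * frequency, returned as unit phasors.
        angles = torch.outer(flat.float(), self.freqs)  # [M, freq_dim]
        return torch.polar(torch.ones_like(angles), angles)

    def forward(self, indices: torch.Tensor) -> torch.Tensor:
        assert indices.shape[-1] == self.dim, f"Last dim of indices must be {self.dim}"
        # Flatten every coordinate, compute phases per scalar, then regroup so
        # each position carries dim * freq_dim complex channels.
        flat = indices.reshape(-1)
        return self._get_phases(flat).reshape(*indices.shape[:-1], -1)


rope = RopeSketch(head_dim=64, dim=3)
indices = torch.randint(0, 32, (2, 100, 3))  # [batch, N, C] integer positions
phases = rope.forward(indices)
```

Under these assumptions `head_dim // 2` is 32 while `dim * freq_dim` is 30, so the result is two channels short of `head_dim // 2`, which is the situation the padding branch in the listed source handles.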

Tested by

no test coverage detected