| 45 | |
| 46 | class VisionRotaryEmbedding(nn.Module): |
def __init__(
    self,
    dim,
    pt_seq_len,
    ft_seq_len=None,
    custom_freqs=None,
    freqs_for='lang',
    theta=10000,
    max_freq=10,
    num_freqs=1,
):
    """Precompute cos/sin rotary tables for a square 2-D position grid.

    A single 1-D frequency vector is built, either from ``custom_freqs`` or
    according to ``freqs_for``. It is multiplied against the grid positions,
    and the resulting angle table is used for both the row (height) and
    column (width) axes. The two axis tables are broadcast against each other
    to an ``(ft_seq_len, ft_seq_len, ...)`` grid and concatenated on the last
    axis. Their cosine and sine are registered as the buffers ``freqs_cos``
    and ``freqs_sin``.

    With ``F`` base frequencies, both buffers have shape
    ``(ft_seq_len, ft_seq_len, 4 * F)``. This is ``F`` frequencies,
    duplicated twice, for each of the two axes.

    Args:
        dim: Rotary dimension; ``freqs_for='lang'`` and ``'pixel'`` produce
            ``dim // 2`` base frequencies.
        pt_seq_len: Grid side length at pre-training. Positions are rescaled
            into ``[0, pt_seq_len)`` whatever the actual grid size is.
        ft_seq_len: Grid side length actually used (e.g. at fine-tuning).
            Defaults to ``pt_seq_len``.
        custom_freqs: Optional explicit 1-D frequencies (tensor or sequence).
            When given, ``freqs_for``, ``theta``, ``max_freq`` and
            ``num_freqs`` are ignored.
        freqs_for: ``'lang'`` uses a geometric schedule based on ``theta``.
            ``'pixel'`` uses ``linspace(1, max_freq / 2, dim // 2) * pi``.
            ``'constant'`` uses ``num_freqs`` ones.
        theta: Base of the geometric schedule for ``'lang'``.
        max_freq: Upper end of the linear schedule for ``'pixel'``.
        num_freqs: Number of frequencies for ``'constant'``.

    Raises:
        ValueError: If ``freqs_for`` is not a recognised mode and no
            ``custom_freqs`` were given.
    """
    import logging

    super().__init__()
    if custom_freqs is not None:
        # Explicit None check: the truth value of a multi-element tensor is
        # ambiguous and raises, so ``if custom_freqs:`` would break on tensors.
        freqs = torch.as_tensor(custom_freqs).float()
    elif freqs_for == 'lang':
        freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    elif freqs_for == 'pixel':
        freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
    elif freqs_for == 'constant':
        freqs = torch.ones(num_freqs).float()
    else:
        raise ValueError(f'unknown modality {freqs_for}')

    if ft_seq_len is None:
        ft_seq_len = pt_seq_len
    # Positions 0..ft_seq_len-1 rescaled so they always span [0, pt_seq_len).
    # A grid of a different size therefore covers the same angular range as
    # the pre-training grid.
    t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

    # Outer product position x frequency, with each frequency duplicated along
    # the last axis. Height and width use identical positions, so this table
    # is computed once and reused for both axes.
    axis_freqs = torch.einsum('..., f -> ... f', t, freqs)
    axis_freqs = repeat(axis_freqs, '... n -> ... (n r)', r = 2)

    # Rows vary along dim 0 and columns along dim 1. After broadcasting to the
    # full grid, the two angle tables are concatenated on the feature axis.
    freqs = broadcat((axis_freqs[:, None, :], axis_freqs[None, :, :]), dim = -1)

    self.register_buffer("freqs_cos", freqs.cos())
    self.register_buffer("freqs_sin", freqs.sin())

    # Diagnostic only: log instead of printing to stdout on every construction.
    logging.getLogger(__name__).debug(
        'shape of rope freq: %s', tuple(self.freqs_cos.shape)
    )
| 85 | |
| 86 | def forward(self, t, start_index = 0): |
| 87 | rot_dim = self.freqs_cos.shape[-1] |