| 95 | |
| 96 | class VisionRotaryEmbeddingFast(nn.Module): |
def __init__(
    self,
    dim,
    pt_seq_len=16,
    ft_seq_len=None,
    custom_freqs=None,
    freqs_for='lang',
    theta=10000,
    max_freq=10,
    num_freqs=1,
):
    """Precompute the rotary cos/sin tables and register them as buffers.

    Args:
        dim: Sizes the frequency vector (``dim // 2`` entries) in the
            ``'lang'`` and ``'pixel'`` modes.
        pt_seq_len: Extent the position axis is rescaled to span.
            NOTE(review): presumably the pre-training grid side length
            -- confirm against callers.
        ft_seq_len: Number of positions generated per axis; defaults to
            ``pt_seq_len`` when ``None``.
            NOTE(review): presumably the fine-tuning grid side length
            -- confirm against callers.
        custom_freqs: Optional tensor of base frequencies. When given
            (i.e. not ``None``) it is used as-is and ``freqs_for`` is
            ignored.
        freqs_for: Frequency scheme, one of ``'lang'``, ``'pixel'`` or
            ``'constant'``.
        theta: Base of the geometric progression in ``'lang'`` mode.
        max_freq: Upper bound (before the ``/ 2`` and ``* pi`` scaling)
            of the linear ramp in ``'pixel'`` mode.
        num_freqs: Number of unit frequencies in ``'constant'`` mode.

    Raises:
        ValueError: If ``custom_freqs`` is ``None`` and ``freqs_for`` is
            not one of the supported schemes.
    """
    super().__init__()

    # Compare against None explicitly: ``if custom_freqs:`` asks for the
    # truth value of a tensor, which raises for tensors with more than
    # one element and silently drops a one-element tensor equal to zero.
    if custom_freqs is not None:
        freqs = custom_freqs
    elif freqs_for == 'lang':
        freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    elif freqs_for == 'pixel':
        freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
    elif freqs_for == 'constant':
        freqs = torch.ones(num_freqs).float()
    else:
        raise ValueError(f'unknown modality {freqs_for}')

    if ft_seq_len is None:
        ft_seq_len = pt_seq_len

    # ft_seq_len positions, rescaled so they cover the range [0, pt_seq_len).
    t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

    # Outer product: every position multiplied by every base frequency.
    freqs = torch.einsum('..., f -> ... f', t, freqs)
    # Duplicate each frequency so consecutive pairs share the same angle.
    freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
    # Combine the per-axis tables into one 2-D grid table.
    # NOTE(review): assumes `broadcat` broadcasts the two inputs against
    # each other and concatenates along `dim` -- confirm with its definition.
    freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)

    # Flatten all leading (grid) dimensions into a single axis.
    freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
    freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

    self.register_buffer("freqs_cos", freqs_cos)
    self.register_buffer("freqs_sin", freqs_sin)

    print('======== shape of rope freq', self.freqs_cos.shape, '========')
| 134 | |
| 135 | def forward(self, t): |
| 136 | if t.shape[1] % 2 != 0: |