MCPcopy
hub / github.com/hustvl/Vim / __init__

Method __init__

vim/rope.py:97–133  ·  view source on GitHub ↗
(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    )

Source from the content-addressed store, hash-verified

95
96class VisionRotaryEmbeddingFast(nn.Module):
def __init__(
    self,
    dim,
    pt_seq_len=16,
    ft_seq_len=None,
    custom_freqs=None,
    freqs_for='lang',
    theta=10000,
    max_freq=10,
    num_freqs=1,
):
    """Precompute cos/sin tables for a 2-D (grid) rotary position embedding.

    A 1-D frequency vector is chosen (custom, 'lang', 'pixel' or
    'constant') and multiplied by per-position coordinates. Each frequency
    is then duplicated, and a row-indexed and a column-indexed copy are
    combined on the last dimension. The result is flattened and stored as
    the ``freqs_cos`` / ``freqs_sin`` buffers.

    Args:
        dim: Sets the number of frequencies for the 'lang' and 'pixel'
            schemes (``dim // 2`` of them).
        pt_seq_len: Grid side length at pre-training; positions are
            rescaled onto ``[0, pt_seq_len)``.
        ft_seq_len: Grid side length actually used; defaults to
            ``pt_seq_len``.
        custom_freqs: Explicit frequency vector, used whenever it is not
            ``None`` (presumably a 1-D float tensor, since it is fed to
            ``torch.einsum`` with float positions -- TODO confirm).
        freqs_for: Frequency scheme when ``custom_freqs`` is None:
            'lang' -> ``1 / theta ** (arange(0, dim, 2) / dim)``;
            'pixel' -> ``linspace(1, max_freq / 2, dim // 2) * pi``;
            'constant' -> ``num_freqs`` ones.
        theta: Base of the 'lang' scheme.
        max_freq: Twice the upper end of the 'pixel' linspace.
        num_freqs: Number of frequencies for the 'constant' scheme.

    Raises:
        ValueError: If ``custom_freqs`` is None and ``freqs_for`` is not
            one of 'lang', 'pixel', 'constant'.
    """
    super().__init__()
    # Compare against None explicitly: `if custom_freqs:` calls bool() on
    # the argument, which raises for a tensor with more than one element
    # and silently discards a zero-valued or empty one.
    if custom_freqs is not None:
        freqs = custom_freqs
    elif freqs_for == 'lang':
        freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    elif freqs_for == 'pixel':
        freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
    elif freqs_for == 'constant':
        freqs = torch.ones(num_freqs).float()
    else:
        raise ValueError(f'unknown modality {freqs_for}')

    if ft_seq_len is None:
        ft_seq_len = pt_seq_len
    # Positions 0..ft_seq_len-1 rescaled onto the pre-training range, so a
    # different grid size reuses the same span of angles.
    t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

    # Outer product of positions and frequencies: (ft_seq_len, F).
    freqs = torch.einsum('..., f -> ... f', t, freqs)
    # Duplicate each frequency consecutively: [a, b] -> [a, a, b, b].
    freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
    # Join a table that varies along axis 0 (rows) with one that varies
    # along axis 1 (columns) on the last dim; presumably broadcat broadcasts
    # both to a common (ft_seq_len, ft_seq_len, ...) shape first -- TODO confirm.
    freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)

    # Collapse all leading (grid) dims so each row holds one position's angles.
    freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
    freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

    self.register_buffer("freqs_cos", freqs_cos)
    self.register_buffer("freqs_sin", freqs_sin)

    print('======== shape of rope freq', self.freqs_cos.shape, '========')
134
135 def forward(self, t):
136 if t.shape[1] % 2 != 0:

Callers

nothing calls this directly

Calls 3

broadcat (function) · 0.85
print (function) · 0.85
__init__ (method) · 0.45

Tested by

no test coverage detected