MCPcopy
hub / github.com/hustvl/Vim / __init__

Method __init__

vim/rope.py:47–84  ·  view source on GitHub ↗
(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs = None,
        freqs_for = 'lang',
        theta = 10000,
        max_freq = 10,
        num_freqs = 1,
    )

Source from the content-addressed store, hash-verified

45
46class VisionRotaryEmbedding(nn.Module):
def __init__(
    self,
    dim,
    pt_seq_len,
    ft_seq_len=None,
    custom_freqs = None,
    freqs_for = 'lang',
    theta = 10000,
    max_freq = 10,
    num_freqs = 1,
):
    """Precompute the cos/sin rotary tables and register them as buffers.

    A 1-D frequency vector is chosen (either ``custom_freqs`` or one of the
    ``freqs_for`` presets), multiplied with a vector of ``ft_seq_len``
    positions, and the resulting per-axis table is combined for two axes
    through ``broadcat``. The cosine and sine of the combined table are
    stored as the buffers ``freqs_cos`` and ``freqs_sin``.

    Args:
        dim: Sizes the preset frequency vector; the 'lang' and 'pixel'
            presets produce ``dim // 2`` frequencies.
        pt_seq_len: Range the positions are rescaled to; positions are
            ``arange(ft_seq_len) / ft_seq_len * pt_seq_len``.
            NOTE(review): presumably the pre-training grid side length --
            confirm against the caller.
        ft_seq_len: Number of positions per axis. Defaults to
            ``pt_seq_len`` when ``None``.
        custom_freqs: Optional frequency tensor used as-is; when given,
            ``freqs_for`` is ignored.
        freqs_for: Preset name: 'lang', 'pixel' or 'constant'.
        theta: Base of the 'lang' preset.
        max_freq: The 'pixel' preset spans ``1 .. max_freq / 2``, times pi.
        num_freqs: Length of the all-ones 'constant' preset.

    Raises:
        ValueError: If ``custom_freqs`` is not given and ``freqs_for`` is
            not one of the known presets.
    """
    super().__init__()

    # Identity check rather than truthiness: `if custom_freqs:` raises
    # "Boolean value of Tensor with more than one value is ambiguous" for
    # any multi-element tensor and skips a one-element tensor holding 0.
    if custom_freqs is not None:
        freqs = custom_freqs
    elif freqs_for == 'lang':
        freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
    elif freqs_for == 'pixel':
        freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
    elif freqs_for == 'constant':
        freqs = torch.ones(num_freqs).float()
    else:
        raise ValueError(f'unknown modality {freqs_for}')

    if ft_seq_len is None:
        ft_seq_len = pt_seq_len
    # ft_seq_len positions rescaled so that they span [0, pt_seq_len).
    t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

    # Outer product of positions and frequencies, then every frequency is
    # duplicated into two adjacent entries along the last axis.
    axis_freqs = torch.einsum('..., f -> ... f', t, freqs)
    axis_freqs = repeat(axis_freqs, '... n -> ... (n r)', r = 2)

    # Height and width use the same per-axis table; they differ only in
    # which of the two leading axes they are laid along before broadcat
    # joins them on the last axis.
    freqs_h = axis_freqs
    freqs_w = axis_freqs
    freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)

    self.register_buffer("freqs_cos", freqs.cos())
    self.register_buffer("freqs_sin", freqs.sin())

    print('======== shape of rope freq', self.freqs_cos.shape, '========')
85
86 def forward(self, t, start_index = 0):
87 rot_dim = self.freqs_cos.shape[-1]

Callers 1

__init__Method · 0.45

Calls 2

broadcatFunction · 0.85
printFunction · 0.85

Tested by

no test coverage detected