hub / github.com/Robbyant/lingbot-world / rope_apply

Function rope_apply

wan/distributed/sequence_parallel.py:26–63

x: [B, L, N, C]. grid_sizes: [B, 3]. freqs: [M, C // 2].

(x, grid_sizes, freqs)
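The shape summary is easier to follow with concrete numbers. The sketch below, in plain Python, reproduces the chunk widths that the listing's `freqs.split(...)` call uses to divide the `C // 2` frequency channels among the frame, height, and width axes, and shows how a flat token index maps back to a grid position under a row-major `(f, h, w)` layout. The helper names `split_sizes` and `axis_indices` are hypothetical and used only for illustration; they are not part of the repository.

```python
def split_sizes(c):
    """Chunk widths for the (frame, height, width) axes; c is C // 2."""
    return [c - 2 * (c // 3), c // 3, c // 3]


def axis_indices(token, f, h, w):
    """Map a flat token index to its (frame, row, column) position,
    matching a row-major reshape of an (f, h, w) grid."""
    assert 0 <= token < f * h * w
    return token // (h * w), (token // w) % h, token % w


# Example: head dimension C = 128 gives c = 64 complex frequency channels.
sizes = split_sizes(64)

# Example: token 7 in a grid with 2 frames, 2 rows, 3 columns.
grid_pos = axis_indices(7, f=2, h=2, w=3)
```

The first chunk absorbs the remainder when `c` is not divisible by 3, so the three widths always sum to `c`.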

Source from the content-addressed store, hash-verified

@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """
    x: [B, L, N, C].
    grid_sizes: [B, 3].
    freqs: [M, C // 2].
    """
    s, n, c = x.size(1), x.size(2), x.size(3) // 2
    # split freqs
    freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

    # loop over samples
    output = []
    for i, (f, h, w) in enumerate(grid_sizes.tolist()):
        seq_len = f * h * w

        # precompute multipliers
        x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
            s, n, -1, 2))
        freqs_i = torch.cat([
            freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
            freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
            freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
        ],
                            dim=-1).reshape(seq_len, 1, -1)

        # apply rotary embedding
        sp_size = get_world_size()
        sp_rank = get_rank()
        freqs_i = pad_freqs(freqs_i, s * sp_size)
        s_per_rank = s
        freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
                                                       s_per_rank), :, :]
        x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
        x_i = torch.cat([x_i, x[i, s:]])

        # append to collection
        output.append(x_i)
    return torch.stack(output).float()
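The "apply rotary embedding" step is the sequence-parallel part: the multipliers are built for the full sequence, padded to `s * sp_size`, and each rank keeps only the slice that matches its `s` local tokens. A minimal pure-Python sketch of that slicing follows. The listing does not show what `pad_freqs` pads with, so filling with `1 + 0j` (an identity rotation) is an assumption here, and `pad_to` and `rank_slice` are hypothetical names for illustration.

```python
def pad_to(freqs, target_len, fill=1 + 0j):
    """Pad a list of per-token multipliers up to target_len.
    ASSUMPTION: padding uses 1+0j, i.e. a rotation that changes nothing."""
    return freqs + [fill] * (target_len - len(freqs))


def rank_slice(freqs, s, sp_rank, sp_size):
    """Return the s multipliers that belong to this rank's local tokens."""
    padded = pad_to(freqs, s * sp_size)
    return padded[sp_rank * s:(sp_rank + 1) * s]


# 10 global tokens spread over 4 ranks with 3 local tokens each (12 slots).
# The values are stand-ins that make the slicing visible, not real RoPE phases.
global_freqs = [complex(k, 0) for k in range(10)]
chunks = [rank_slice(global_freqs, s=3, sp_rank=r, sp_size=4) for r in range(4)]
```

In this example the last rank receives one real multiplier followed by two padding entries, which is why the padding value matters: under the assumption above, padded tokens pass through unrotated.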

Callers 1

sp_attn_forward · Function · 0.70

Calls 5

get_world_size · Function · 0.85
get_rank · Function · 0.85
pad_freqs · Function · 0.85
size · Method · 0.80
to · Method · 0.80

Tested by

no test coverage detected
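
One property a unit test could check without any distributed setup is that the rotation step preserves the magnitude of each channel pair. `rope_apply` treats consecutive channel pairs as complex numbers and multiplies them by per-position multipliers. The sketch below mirrors that idea in plain Python, assuming unit-magnitude multipliers of the form `exp(i * angle)`. It does not call the repository code, and `rotate_pairs` is a hypothetical stand-in.

```python
import cmath
import math


def rotate_pairs(values, angles):
    """Rotate consecutive (even, odd) pairs of `values` by the given angles,
    treating each pair as one complex number multiplied by exp(i * angle)."""
    assert len(values) == 2 * len(angles)
    out = []
    for k, theta in enumerate(angles):
        z = complex(values[2 * k], values[2 * k + 1]) * cmath.exp(1j * theta)
        out.extend([z.real, z.imag])
    return out


def test_rotation_preserves_pair_norms():
    values = [1.0, 2.0, -3.0, 0.5]
    rotated = rotate_pairs(values, [0.3, 1.7])
    for k in range(2):
        before = math.hypot(values[2 * k], values[2 * k + 1])
        after = math.hypot(rotated[2 * k], rotated[2 * k + 1])
        assert math.isclose(before, after, rel_tol=1e-12)


def test_zero_angle_is_identity():
    values = [1.0, 2.0, -3.0, 0.5]
    assert rotate_pairs(values, [0.0, 0.0]) == values
```

A test against the real function would additionally need to stub `get_world_size`, `get_rank`, and `pad_freqs`, or run inside an initialized process group.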