MCPcopy
hub / github.com/Robbyant/lingbot-world / rope_apply

Function rope_apply

wan/modules/model.py:40–67  ·  view source on GitHub ↗
(x, grid_sizes, freqs)

Source from the content-addressed store, hash-verified

38
@torch.amp.autocast('cuda', enabled=False)
def rope_apply(x, grid_sizes, freqs):
    """Rotate the leading f*h*w tokens of each sample with 3-axis rotary factors.

    The last dimension of ``x`` is read as ``c = D // 2`` adjacent
    (real, imag) pairs. Those complex channels are split into three groups:
    the first ``c - 2 * (c // 3)`` take the frame-axis factors, the next
    ``c // 3`` the height-axis factors, and the last ``c // 3`` the
    width-axis factors. Each grid token is multiplied by the factors at its
    (frame, row, column) position; the math is carried out in float64.

    Args:
        x: 4-D tensor ``(B, L, N, D)``. Dim 1 is the token axis; dim 2 is
            kept as its own axis (presumably attention heads -- confirm
            against the caller). ``D`` must be even.
        grid_sizes: integer tensor with one ``(f, h, w)`` row per sample.
            Sample ``i`` has ``f * h * w`` grid tokens at the front of
            ``x[i]``.
        freqs: 2-D tensor indexed ``[position, channel]`` with ``c``
            channels and at least ``max(f, h, w)`` rows. Presumably complex
            rotation factors (e.g. built with ``torch.polar``) -- confirm.

    Returns:
        float32 tensor of shape ``(len(grid_sizes), L, N, D)``. Tokens past
        ``f * h * w`` in each sample are copied through without rotation.
    """
    num_heads = x.size(2)
    half_dim = x.size(3) // 2
    axis_chunk = half_dim // 3
    # The frame axis takes the remainder so the three groups sum to half_dim.
    freqs_frame, freqs_row, freqs_col = freqs.split(
        [half_dim - 2 * axis_chunk, axis_chunk, axis_chunk], dim=1)

    per_sample = []
    for idx, (frames, rows, cols) in enumerate(grid_sizes.tolist()):
        sample = x[idx]
        grid_tokens = frames * rows * cols
        grid_shape = (frames, rows, cols)

        # Promote to float64, then pair adjacent channels as (real, imag).
        as_complex = torch.view_as_complex(
            sample[:grid_tokens].to(torch.float64).reshape(
                grid_tokens, num_heads, -1, 2))

        # Broadcast each axis' factors over the whole grid, then flatten the
        # grid in row-major (frame, row, column) order. The singleton dim
        # broadcasts the factors across dim N of x.
        factors = torch.cat(
            [
                freqs_frame[:frames].view(frames, 1, 1, -1).expand(
                    *grid_shape, -1),
                freqs_row[:rows].view(1, rows, 1, -1).expand(
                    *grid_shape, -1),
                freqs_col[:cols].view(1, 1, cols, -1).expand(
                    *grid_shape, -1),
            ],
            dim=-1,
        ).reshape(grid_tokens, 1, -1)

        rotated = torch.view_as_real(as_complex * factors).flatten(2)
        # Tokens beyond the grid are appended unchanged.
        per_sample.append(torch.cat([rotated, sample[grid_tokens:]]))

    return torch.stack(per_sample).float()
68
69
70class WanRMSNorm(nn.Module):

Callers 1

forward (method) · 0.70

Calls 2

size (method) · 0.80
to (method) · 0.80

Tested by

no test coverage detected