MCPcopy
hub / github.com/Robbyant/lingbot-world / causal_rope_apply

Function causal_rope_apply

wan/distributed/sequence_parallel.py:67–95  ·  view source on GitHub ↗
(x, grid_sizes, freqs, start_frame=0)

Source from the content-addressed store, hash-verified

65
@torch.amp.autocast('cuda', enabled=False)
def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
    """
    Multiply the leading ``f * h * w`` tokens of every sample by per-position
    rotary factors, with the frame-axis factors read from ``start_frame`` on.

    CUDA autocast is switched off for the whole call (see the decorator); the
    rotation itself is carried out on a float64 copy of the tokens.

    x:           indexed as ``x[i, :seq_len]`` and reshaped to
                 ``(seq_len, x.size(2), -1, 2)`` -- presumably [B, L, N, C]
                 with an even C; TODO confirm against the caller.
    grid_sizes:  tensor whose ``tolist()`` yields one ``(f, h, w)`` triple per
                 sample.
    freqs:       table split along dim 1 into frame / height / width column
                 groups of widths ``c - 2 * (c // 3)``, ``c // 3``, ``c // 3``
                 where ``c = x.size(3) // 2`` -- presumably complex-valued;
                 TODO confirm.
    start_frame: first row of the frame-axis group to use; rows
                 ``start_frame .. start_frame + f - 1`` are read.

    Returns the per-sample results stacked along a new leading dimension and
    cast back with ``type_as(x)``. Tokens after the first ``f * h * w`` of a
    sample are appended without being rotated.
    """
    num_heads = x.size(2)
    half_dim = x.size(3) // 2

    # carve the table into its frame / height / width column groups
    spatial_width = half_dim // 3
    frame_width = half_dim - 2 * spatial_width
    frame_freqs, height_freqs, width_freqs = freqs.split(
        [frame_width, spatial_width, spatial_width], dim=1)

    rotated = []
    for idx, (frames, height, width) in enumerate(grid_sizes.tolist()):
        grid = (frames, height, width)
        token_count = frames * height * width

        # leading tokens as complex numbers built from consecutive value pairs
        head = x[idx, :token_count].to(torch.float64)
        head = torch.view_as_complex(
            head.reshape(token_count, num_heads, -1, 2))

        # broadcast each axis' factors over the full (f, h, w) grid, then
        # flatten the grid so there is one row of factors per token
        frame_part = frame_freqs[start_frame:start_frame + frames]
        frame_part = frame_part.view(frames, 1, 1, -1).expand(*grid, -1)
        height_part = height_freqs[:height]
        height_part = height_part.view(1, height, 1, -1).expand(*grid, -1)
        width_part = width_freqs[:width]
        width_part = width_part.view(1, 1, width, -1).expand(*grid, -1)
        multipliers = torch.cat(
            [frame_part, height_part, width_part], dim=-1)
        multipliers = multipliers.reshape(token_count, 1, -1)

        # rotate, go back to interleaved real pairs, re-attach the remainder
        head = torch.view_as_real(head * multipliers).flatten(2)
        remainder = x[idx, token_count:]
        rotated.append(torch.cat([head, remainder]))

    return torch.stack(rotated).type_as(x)
96
97
98def sp_dit_forward(

Callers 1

sp_attn_forward_causal — Function · 0.70

Calls 3

size — Method · 0.80
to — Method · 0.80
type_as — Method · 0.80

Tested by

no test coverage detected