MCPcopy
hub / github.com/Robbyant/lingbot-world / rope_apply

Function rope_apply

wan/modules/s2v/motioner.py:41–110  ·  view source on GitHub ↗
(x, grid_sizes, freqs, start=None)

Source from the content-addressed store, hash-verified

39
40@amp.autocast(enabled=False)
41def rope_apply(x, grid_sizes, freqs, start=None):
42 n, c = x.size(2), x.size(3) // 2
43
44 # split freqs
45 if type(freqs) is list:
46 trainable_freqs = freqs[1]
47 freqs = freqs[0]
48 freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
49
50 # loop over samples
51 output = []
52 output = x.clone()
53 seq_bucket = [0]
54 if not type(grid_sizes) is list:
55 grid_sizes = [grid_sizes]
56 for g in grid_sizes:
57 if not type(g) is list:
58 g = [torch.zeros_like(g), g]
59 batch_size = g[0].shape[0]
60 for i in range(batch_size):
61 if start is None:
62 f_o, h_o, w_o = g[0][i]
63 else:
64 f_o, h_o, w_o = start[i]
65
66 f, h, w = g[1][i]
67 t_f, t_h, t_w = g[2][i]
68 seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
69 seq_len = int(seq_f * seq_h * seq_w)
70 if seq_len > 0:
71 if t_f > 0:
72 factor_f, factor_h, factor_w = (t_f / seq_f).item(), (
73 t_h / seq_h).item(), (t_w / seq_w).item()
74
75 if f_o >= 0:
76 f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
77 seq_f).astype(int).tolist()
78 else:
79 f_sam = np.linspace(-f_o.item(),
80 (-t_f - f_o).item() + 1,
81 seq_f).astype(int).tolist()
82 h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
83 seq_h).astype(int).tolist()
84 w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
85 seq_w).astype(int).tolist()
86
87 assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
88 freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
89 f_sam].conj()
90 freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
91
92 freqs_i = torch.cat([
93 freqs_0.expand(seq_f, seq_h, seq_w, -1),
94 freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
95 seq_f, seq_h, seq_w, -1),
96 freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
97 seq_f, seq_h, seq_w, -1),
98 ],

Callers 5

forwardMethod · 0.70
forwardMethod · 0.70
forwardMethod · 0.70
__init__Method · 0.70
forwardMethod · 0.50

Calls 2

sizeMethod · 0.80
toMethod · 0.80

Tested by

no test coverage detected