hub / github.com/Robbyant/lingbot-world / sp_dit_forward

Function sp_dit_forward

wan/distributed/sequence_parallel.py:98–214

x: A list of videos each with shape [C, T, H, W].
t: [B].
context: A list of text embeddings each with shape [L, C].
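As a rough illustration of these shapes, the sketch below counts how many tokens one [C, T, H, W] input would produce once it is cut into patches, which is the quantity that must not exceed seq_len. The patch size (1, 2, 2), the example shapes, and the helper names token_count and padded_seq_len are assumptions for illustration and are not taken from this repository. Rounding seq_len up to a multiple of the number of ranks is likewise only an assumption about how a sequence-parallel caller might choose it.

```python
import math


def token_count(latent_shape, patch_size=(1, 2, 2)):
    """Tokens produced by one [C, T, H, W] input after patching.

    patch_size is an assumed value; read the real one from the model config.
    """
    _, t, h, w = latent_shape
    pt, ph, pw = patch_size
    if t % pt or h % ph or w % pw:
        raise ValueError("T, H, W must be divisible by the patch size")
    return (t // pt) * (h // ph) * (w // pw)


def padded_seq_len(latent_shape, world_size, patch_size=(1, 2, 2)):
    """Smallest length >= token count that divides evenly across world_size ranks."""
    n = token_count(latent_shape, patch_size)
    return math.ceil(n / world_size) * world_size


# Hypothetical shapes: 21 x 30 x 52 patches, then a small case that needs padding.
tokens = token_count((16, 21, 60, 104))
seq_len = padded_seq_len((16, 5, 6, 10), world_size=8)
```

Any seq_len passed to the function has to be at least the largest per-sample token count, since the listing asserts seq_lens.max() <= seq_len.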

(
    self,
    x,
    t,
    context,
    seq_len,
    y=None,
    dit_cond_dict=None,
)
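The x and seq_len parameters work together: each sample in x may produce a different number of tokens, and the function zero-pads every sample to seq_len before stacking them into one batch. A minimal plain-Python sketch of that padding, using nested lists in place of tensors, is given here. The helper name pad_tokens is hypothetical, and the real code operates on torch tensors.

```python
def pad_tokens(seqs, seq_len):
    """Zero-pad each token sequence (a list of feature rows) to seq_len rows.

    Returns the padded batch and the original lengths, mirroring the
    seq_lens bookkeeping in the source listing.
    """
    lengths = [len(s) for s in seqs]
    if max(lengths) > seq_len:
        raise ValueError("seq_len is shorter than the longest sequence")
    padded = []
    for s in seqs:
        width = len(s[0])  # feature dimension of this sample
        pad_rows = [[0.0] * width for _ in range(seq_len - len(s))]
        padded.append([list(row) for row in s] + pad_rows)
    return padded, lengths


# Two hypothetical samples with 1 and 2 tokens, feature width 2.
batch, lengths = pad_tokens(
    [[[1.0, 2.0]], [[3.0, 4.0], [5.0, 6.0]]],
    seq_len=3,
)
```

Keeping the original lengths alongside the padded batch lets later stages ignore the padded positions.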

Source from the content-addressed store, hash-verified

def sp_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    y=None,
    dit_cond_dict=None,
):
    """
    x: A list of videos each with shape [C, T, H, W].
    t: [B].
    context: A list of text embeddings each with shape [L, C].
    """
    if self.model_type == 'i2v':
        assert y is not None
    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in x
    ])

    # time embeddings
    if t.dim() == 1:
        t = t.expand(t.size(0), seq_len)
    with torch.amp.autocast('cuda', dtype=torch.float32):
        bt = t.size(0)
        t = t.flatten()
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim,
                                    t).unflatten(0, (bt, seq_len)).float())
        e0 = self.time_projection(e).unflatten(2, (6, self.dim))
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))

    # cam
    if dit_cond_dict is not None and "c2ws_plucker_emb" in dit_cond_dict:
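The time-embedding step in the listing broadcasts t to one value per token, flattens it, embeds it, and then restores the [B, seq_len] layout. The plain-Python sketch below follows that data flow with nested lists. It uses a common cosine/sine formulation with base 10000; the repository's own sinusoidal_embedding_1d is not shown on this page and may differ. The helper names are hypothetical, and one timestep per sample is assumed.

```python
import math


def sinusoidal_embedding(dim, positions):
    """One [cos..., sin...] row of length dim per position (assumed formulation)."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    half = dim // 2
    freqs = [10000.0 ** (-i / half) for i in range(half)]
    rows = []
    for p in positions:
        angles = [p * f for f in freqs]
        rows.append([math.cos(a) for a in angles] + [math.sin(a) for a in angles])
    return rows


def per_token_time_embedding(dim, t, seq_len):
    """Map B timesteps to a nested list shaped [B][seq_len][dim]."""
    batch = len(t)
    # Broadcast each timestep to every token position, then flatten.
    flat = [ti for ti in t for _ in range(seq_len)]
    emb = sinusoidal_embedding(dim, flat)
    # Restore the [B, seq_len] layout (the unflatten step in the listing).
    return [emb[b * seq_len:(b + 1) * seq_len] for b in range(batch)]


e = per_token_time_embedding(dim=4, t=[0.0, 1.0], seq_len=3)
```

Embedding per token rather than per sample costs more memory, but it leaves room for different timesteps at different token positions when t is passed as a 2-D tensor.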

Callers

nothing calls this directly

Calls 7

get_world_size · Function · 0.85
get_rank · Function · 0.85
gather_forward · Function · 0.85
to · Method · 0.80
size · Method · 0.80
sinusoidal_embedding_1d · Function · 0.50
unpatchify · Method · 0.45
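The names get_world_size, get_rank and gather_forward suggest the usual sequence-parallel pattern: each rank processes one contiguous slice of the token sequence, and the slices are gathered back in rank order before unpatchify. That part of the function is not shown in the listing above, so the following is only a single-process sketch of the pattern under that assumption. The helpers shard_sequence and gather_sequence are hypothetical stand-ins, not the repository's API, and real code would use torch.distributed collectives.

```python
def shard_sequence(tokens, world_size, rank):
    """Return the contiguous slice of tokens owned by the given rank."""
    if len(tokens) % world_size != 0:
        raise ValueError("sequence length must be divisible by world_size")
    chunk = len(tokens) // world_size
    return tokens[rank * chunk:(rank + 1) * chunk]


def gather_sequence(shards):
    """Concatenate per-rank shards in rank order (stand-in for an all-gather)."""
    full = []
    for shard in shards:
        full.extend(shard)
    return full


# Simulate 4 ranks on one process with an 8-token sequence.
tokens = list(range(8))
shards = [shard_sequence(tokens, world_size=4, rank=r) for r in range(4)]
restored = gather_sequence(shards)
```

The divisibility check is one reason a caller might pad seq_len to a multiple of the world size before calling the function.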

Tested by

no test coverage detected