MCPcopy
hub / github.com/Robbyant/lingbot-world / forward

Method forward

wan/modules/animate/model_animate.py:372–450  ·  view source on GitHub ↗
(
        self,
        x,
        t,
        clip_fea,
        context,
        seq_len,
        y=None,
        pose_latents=None, 
        face_pixel_values=None
    )

Source from the content-addressed store, hash-verified

370
371
372 def forward(
373 self,
374 x,
375 t,
376 clip_fea,
377 context,
378 seq_len,
379 y=None,
380 pose_latents=None,
381 face_pixel_values=None
382 ):
383 # params
384 device = self.patch_embedding.weight.device
385 if self.freqs.device != device:
386 self.freqs = self.freqs.to(device)
387
388 if y is not None:
389 x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
390
391 # embeddings
392 x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
393 x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
394
395 grid_sizes = torch.stack(
396 [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
397 x = [u.flatten(2).transpose(1, 2) for u in x]
398 seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
399 assert seq_lens.max() <= seq_len
400 x = torch.cat([
401 torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
402 dim=1) for u in x
403 ])
404
405 # time embeddings
406 with amp.autocast(dtype=torch.float32):
407 e = self.time_embedding(
408 sinusoidal_embedding_1d(self.freq_dim, t).float()
409 )
410 e0 = self.time_projection(e).unflatten(1, (6, self.dim))
411 assert e.dtype == torch.float32 and e0.dtype == torch.float32
412
413 # context
414 context_lens = None
415 context = self.text_embedding(
416 torch.stack([
417 torch.cat(
418 [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
419 for u in context
420 ]))
421
422 if self.use_img_emb:
423 context_clip = self.img_emb(clip_fea) # bs x 257 x dim
424 context = torch.concat([context_clip, context], dim=1)
425
426 # arguments
427 kwargs = dict(
428 e=e0,
429 seq_lens=seq_lens,

Callers

nothing calls this directly

Calls 9

after_patch_embedding · Method · 0.95
unpatchify · Method · 0.95
get_world_size · Function · 0.85
get_rank · Function · 0.85
gather_forward · Function · 0.85
to · Method · 0.80
size · Method · 0.80
sinusoidal_embedding_1d · Function · 0.50

Tested by

no test coverage detected