MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / mp_worker

Method mp_worker

wan/vace.py:562–771  ·  view source on GitHub ↗
(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
                  out_q, initialized_events, work_env)

Source from the content-addressed store, hash-verified

560 return data
561
562 def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
563 out_q, initialized_events, work_env):
564 try:
565 world_size = pmi_world_size * gpu_infer
566 rank = pmi_rank * gpu_infer + gpu
567 print("world_size", world_size, "rank", rank, flush=True)
568
569 torch.cuda.set_device(gpu)
570 dist.init_process_group(
571 backend='nccl',
572 init_method='env://',
573 rank=rank,
574 world_size=world_size)
575
576 from xfuser.core.distributed import (
577 init_distributed_environment,
578 initialize_model_parallel,
579 )
580 init_distributed_environment(
581 rank=dist.get_rank(), world_size=dist.get_world_size())
582
583 initialize_model_parallel(
584 sequence_parallel_degree=dist.get_world_size(),
585 ring_degree=self.ring_size or 1,
586 ulysses_degree=self.ulysses_size or 1)
587
588 num_train_timesteps = self.config.num_train_timesteps
589 param_dtype = self.config.param_dtype
590 shard_fn = partial(shard_model, device_id=gpu)
591 text_encoder = T5EncoderModel(
592 text_len=self.config.text_len,
593 dtype=self.config.t5_dtype,
594 device=torch.device('cpu'),
595 checkpoint_path=os.path.join(self.checkpoint_dir,
596 self.config.t5_checkpoint),
597 tokenizer_path=os.path.join(self.checkpoint_dir,
598 self.config.t5_tokenizer),
599 shard_fn=shard_fn if True else None)
600 text_encoder.model.to(gpu)
601 vae_stride = self.config.vae_stride
602 patch_size = self.config.patch_size
603 vae = WanVAE(
604 vae_pth=os.path.join(self.checkpoint_dir,
605 self.config.vae_checkpoint),
606 device=gpu)
607 logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}")
608 model = VaceWanModel.from_pretrained(self.checkpoint_dir)
609 model.eval().requires_grad_(False)
610
611 if self.use_usp:
612 from xfuser.core.distributed import get_sequence_parallel_world_size
613
614 from .distributed.xdit_context_parallel import (
615 usp_attn_forward,
616 usp_dit_forward,
617 usp_dit_forward_vace,
618 )
619 for block in model.blocks:

Callers

nothing calls this directly

Calls 14

transfer_data_to_cuda (Method) · 0.95
set_timesteps (Method) · 0.95
step (Method) · 0.95
T5EncoderModel (Class) · 0.85
WanVAE (Class) · 0.85
get_sampling_sigmas (Function) · 0.85
retrieve_timesteps (Function) · 0.85
device (Method) · 0.80
vace_encode_frames (Method) · 0.80
vace_encode_masks (Method) · 0.80

Tested by

no test coverage detected