(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list,
out_q, initialized_events, work_env)
| 560 | return data |
| 561 | |
| 562 | def mp_worker(self, gpu, gpu_infer, pmi_rank, pmi_world_size, in_q_list, |
| 563 | out_q, initialized_events, work_env): |
| 564 | try: |
| 565 | world_size = pmi_world_size * gpu_infer |
| 566 | rank = pmi_rank * gpu_infer + gpu |
| 567 | print("world_size", world_size, "rank", rank, flush=True) |
| 568 | |
| 569 | torch.cuda.set_device(gpu) |
| 570 | dist.init_process_group( |
| 571 | backend='nccl', |
| 572 | init_method='env://', |
| 573 | rank=rank, |
| 574 | world_size=world_size) |
| 575 | |
| 576 | from xfuser.core.distributed import ( |
| 577 | init_distributed_environment, |
| 578 | initialize_model_parallel, |
| 579 | ) |
| 580 | init_distributed_environment( |
| 581 | rank=dist.get_rank(), world_size=dist.get_world_size()) |
| 582 | |
| 583 | initialize_model_parallel( |
| 584 | sequence_parallel_degree=dist.get_world_size(), |
| 585 | ring_degree=self.ring_size or 1, |
| 586 | ulysses_degree=self.ulysses_size or 1) |
| 587 | |
| 588 | num_train_timesteps = self.config.num_train_timesteps |
| 589 | param_dtype = self.config.param_dtype |
| 590 | shard_fn = partial(shard_model, device_id=gpu) |
| 591 | text_encoder = T5EncoderModel( |
| 592 | text_len=self.config.text_len, |
| 593 | dtype=self.config.t5_dtype, |
| 594 | device=torch.device('cpu'), |
| 595 | checkpoint_path=os.path.join(self.checkpoint_dir, |
| 596 | self.config.t5_checkpoint), |
| 597 | tokenizer_path=os.path.join(self.checkpoint_dir, |
| 598 | self.config.t5_tokenizer), |
| 599 | shard_fn=shard_fn if True else None) |
| 600 | text_encoder.model.to(gpu) |
| 601 | vae_stride = self.config.vae_stride |
| 602 | patch_size = self.config.patch_size |
| 603 | vae = WanVAE( |
| 604 | vae_pth=os.path.join(self.checkpoint_dir, |
| 605 | self.config.vae_checkpoint), |
| 606 | device=gpu) |
| 607 | logging.info(f"Creating VaceWanModel from {self.checkpoint_dir}") |
| 608 | model = VaceWanModel.from_pretrained(self.checkpoint_dir) |
| 609 | model.eval().requires_grad_(False) |
| 610 | |
| 611 | if self.use_usp: |
| 612 | from xfuser.core.distributed import get_sequence_parallel_world_size |
| 613 | |
| 614 | from .distributed.xdit_context_parallel import ( |
| 615 | usp_attn_forward, |
| 616 | usp_dit_forward, |
| 617 | usp_dit_forward_vace, |
| 618 | ) |
| 619 | for block in model.blocks: |
nothing calls this directly
no test coverage detected