(self, args)
| 75 | |
| 76 | class WanInferencePipeline(nn.Module): |
def __init__(self, args):
    """Set up the inference pipeline from a parsed argument namespace.

    Reads the following attributes from ``args``:

    * ``rank`` -- used to build the ``cuda:<rank>`` device.
    * ``dtype`` -- ``'bf16'`` selects ``torch.bfloat16``, ``'fp16'``
      selects ``torch.float16``; any other value selects
      ``torch.float32``.
    * ``i2v`` -- when truthy, ``self.transform`` is created.
    * ``use_audio`` -- when truthy, the wav2vec feature extractor and
      audio encoder are loaded from ``args.wav2vec_path``.

    ``self.transform``, ``self.wav_feature_extractor`` and
    ``self.audio_encoder`` are only defined when their flag is set.
    """
    super().__init__()
    self.args = args
    self.device = torch.device(f"cuda:{args.rank}")

    # Map the precision label to a torch dtype; unknown labels keep float32.
    precision = torch.float32
    for label, candidate in (("bf16", torch.bfloat16), ("fp16", torch.float16)):
        if args.dtype == label:
            precision = candidate
            break
    self.dtype = precision

    # The model is loaded after device and dtype have been set on self.
    self.pipe = self.load_model()

    if args.i2v:
        # Single-step transform chain: ToTensor only.
        self.transform = TT.Compose([TT.ToTensor()])

    if args.use_audio:
        # Imported here so the module is only required when audio is used.
        from OmniAvatar.models.wav2vec import Wav2VecModel

        self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            args.wav2vec_path
        )
        encoder = Wav2VecModel.from_pretrained(
            args.wav2vec_path, local_files_only=True
        )
        self.audio_encoder = encoder.to(device=self.device)
        # Freeze the encoder's feature-extractor parameters.
        self.audio_encoder.feature_extractor._freeze_parameters()
| 99 | |
| 100 | def load_model(self): |
| 101 | dist.init_process_group( |
nothing calls this directly
no test coverage detected