MCPcopy
hub / github.com/Omni-Avatar/OmniAvatar / __init__

Method __init__

scripts/inference.py:77–98  ·  view source on GitHub ↗
(self, args)

Source from the content-addressed store, hash-verified

75
76class WanInferencePipeline(nn.Module):
def __init__(self, args):
    """Build the inference pipeline from a parsed argument namespace.

    Resolves the target CUDA device and floating-point dtype, loads the
    main model through ``self.load_model()``, and, depending on flags in
    ``args``, prepares an image-to-tensor transform and a wav2vec-based
    audio feature extractor / encoder.

    Args:
        args: Configuration object, stored unchanged on ``self.args``.
            This method reads ``args.rank``, ``args.dtype``, ``args.i2v``,
            ``args.use_audio`` and, when ``use_audio`` is truthy,
            ``args.wav2vec_path``.
    """
    # nn.Module's own __init__ must run before any attribute assignment.
    super().__init__()
    self.args = args
    # CUDA device index is taken directly from args.rank.
    self.device = torch.device(f"cuda:{args.rank}")

    # 'bf16' -> bfloat16, 'fp16' -> float16; any other value falls back
    # to float32.
    precision_by_name = {
        'bf16': torch.bfloat16,
        'fp16': torch.float16,
    }
    self.dtype = precision_by_name.get(args.dtype, torch.float32)

    # The model is loaded only after self.device and self.dtype exist.
    # NOTE(review): presumably load_model() reads them — keep this order.
    self.pipe = self.load_model()

    # NOTE(review): self.transform is only defined when args.i2v is truthy;
    # any code reading it must be guarded by the same flag.
    if args.i2v:
        self.transform = TT.Compose([TT.ToTensor()])

    if args.use_audio:
        # Imported inside the branch so the wav2vec module is only loaded
        # when audio conditioning is enabled.
        from OmniAvatar.models.wav2vec import Wav2VecModel
        wav_path = args.wav2vec_path
        # NOTE(review): unlike the model load below, this call does not pass
        # local_files_only=True — confirm the asymmetry is intended.
        self.wav_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
            wav_path
        )
        encoder = Wav2VecModel.from_pretrained(wav_path, local_files_only=True)
        # Moved to self.device only; it is NOT cast to self.dtype.
        self.audio_encoder = encoder.to(device=self.device)
        # NOTE(review): presumably disables gradients for the encoder's
        # convolutional front end (HF Wav2Vec2 convention) — confirm against
        # OmniAvatar.models.wav2vec.
        self.audio_encoder.feature_extractor._freeze_parameters()
99
100 def load_model(self):
101 dist.init_process_group(

Callers

nothing calls this directly

Calls 2

load_model (Method) · 0.95
to (Method) · 0.80

Tested by

no test coverage detected