MCPcopy
hub / github.com/Rudrabha/Wav2Lip / Wav2Lip_disc_qual

Class Wav2Lip_disc_qual

models/wav2lip.py:127–184  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

125 return outputs
126
127class Wav2Lip_disc_qual(nn.Module):
128 def __init__(self):
129 super(Wav2Lip_disc_qual, self).__init__()
130
131 self.face_encoder_blocks = nn.ModuleList([
132 nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96
133
134 nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48
135 nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),
136
137 nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24
138 nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),
139
140 nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12
141 nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),
142
143 nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6
144 nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),
145
146 nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3
147 nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),
148
149 nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1
150 nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])
151
152 self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
153 self.label_noise = .0
154
155 def get_lower_half(self, face_sequences):
156 return face_sequences[:, :, face_sequences.size(2)//2:]
157
158 def to_2d(self, face_sequences):
159 B = face_sequences.size(0)
160 face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0)
161 return face_sequences
162
163 def perceptual_forward(self, false_face_sequences):
164 false_face_sequences = self.to_2d(false_face_sequences)
165 false_face_sequences = self.get_lower_half(false_face_sequences)
166
167 false_feats = false_face_sequences
168 for f in self.face_encoder_blocks:
169 false_feats = f(false_feats)
170
171 false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1),
172 torch.ones((len(false_feats), 1)).cuda())
173
174 return false_pred_loss
175
176 def forward(self, face_sequences):
177 face_sequences = self.to_2d(face_sequences)
178 face_sequences = self.get_lower_half(face_sequences)
179
180 x = face_sequences
181 for f in self.face_encoder_blocks:
182 x = f(x)
183
184 return self.binary_pred(x).view(len(x), -1)

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected