| 125 | return outputs |
| 126 | |
class Wav2Lip_disc_qual(nn.Module):
    """Per-frame discriminator that scores the lower half of face images.

    Inputs are 5-D tensors whose dim 1 holds 3 channels (the first
    convolution expects 3 input channels) and whose dim 2 is unrolled into
    the batch dimension before any convolution runs (presumably the
    frame/time axis -- confirm against the data loader).  Each resulting
    image is cropped to its lower half, pushed through
    ``face_encoder_blocks`` and mapped to a sigmoid score by
    ``binary_pred``.
    """

    def __init__(self):
        super(Wav2Lip_disc_qual, self).__init__()

        # Trailing comments are the original author's feature-map sizes
        # (H, W) after each stage; they correspond to a 48x96 lower-half crop,
        # assuming nonorm_Conv2d forwards these arguments to a 2-D convolution.
        self.face_encoder_blocks = nn.ModuleList([
            nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)),  # 48,96

            nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2),  # 48,48
                          nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2),  # 24,24
                          nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2),  # 12,12
                          nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)),

            nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1),  # 6,6
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1),  # 3,3
                          nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),),

            nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0),  # 1, 1
                          nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),])

        self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid())
        # NOTE(review): not read anywhere inside this class; kept because
        # external code may set or inspect it.
        self.label_noise = .0

    def get_lower_half(self, face_sequences):
        """Return the second half of ``face_sequences`` along dim 2.

        Both callers in this class apply it after ``to_2d``, where dim 2 is
        the image height, so the result is the lower half of each image.
        For an odd size the extra row goes to the returned half.
        """
        return face_sequences[:, :, face_sequences.size(2)//2:]

    def to_2d(self, face_sequences):
        """Fold dim 2 of a 5-D tensor into the batch dimension.

        The slices taken along dim 2 are concatenated along dim 0 in slice
        order, so the output holds every batch item for index 0 first, then
        every batch item for index 1, and so on.
        """
        return torch.cat(torch.unbind(face_sequences, dim=2), dim=0)

    def _predict(self, face_sequences):
        """Run the shared flatten -> crop -> encode -> score pipeline.

        Returns a tensor with one row per flattened image; every remaining
        dimension of the ``binary_pred`` output is flattened into dim 1.
        """
        x = self.get_lower_half(self.to_2d(face_sequences))
        for block in self.face_encoder_blocks:
            x = block(x)
        return self.binary_pred(x).view(len(x), -1)

    def perceptual_forward(self, false_face_sequences):
        """Return the binary cross-entropy of the scores against a target of 1.

        The loss is small when the discriminator scores the given images
        close to 1 (presumably the "real" label -- confirm with the training
        loop).

        The target is allocated from the prediction tensor so it always lives
        on the same device (and has the same dtype) as the model output,
        instead of being forced onto CUDA.
        """
        pred = self._predict(false_face_sequences)
        target = pred.new_ones((pred.size(0), 1))
        return F.binary_cross_entropy(pred, target)

    def forward(self, face_sequences):
        """Return the sigmoid score for every image in ``face_sequences``.

        The output has one row per flattened image (see ``to_2d`` for the
        row ordering).
        """
        return self._predict(face_sequences)