Generate target domain labels for debugging and testing.
create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None)
| 147 | return out |
| 148 | |
def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
    """Generate target domain labels for debugging and testing.

    Produces one target-label tensor for every index ``i`` in
    ``range(c_dim)``.

    * ``dataset == 'CelebA'``: each target is a clone of ``c_org`` with
      column ``i`` altered. If column ``i`` is a hair-color attribute it is
      set to 1 and every other hair-color column is set to 0; otherwise
      column ``i`` is replaced by ``(column == 0)``, i.e. a 0 becomes 1 and
      any non-zero value becomes 0.
    * ``dataset == 'RaFD'``: each target is
      ``self.label2onehot(labels, c_dim)`` where ``labels`` holds the value
      ``i`` once per row of ``c_org``.

    Args:
        c_org: Original label tensor. It is indexed as ``[:, i]`` and sized
            with ``size(0)``; it is never modified in place (a clone is
            edited instead).
        c_dim: Number of target label tensors to generate.
        dataset: Either ``'CelebA'`` or ``'RaFD'``.
        selected_attrs: Attribute names used for CelebA; the position of a
            name is the column it refers to. ``None`` is treated as "no
            attribute names", so no column is handled as a hair color.

    Returns:
        A list of ``c_dim`` tensors, each moved to ``self.device``.

    Raises:
        ValueError: If ``dataset`` is neither ``'CelebA'`` nor ``'RaFD'``.
    """
    if dataset not in ('CelebA', 'RaFD'):
        # Previously this surfaced as an UnboundLocalError on `c_trg`.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CelebA' or 'RaFD'.".format(dataset))

    is_celeba = dataset == 'CelebA'

    # Columns holding mutually exclusive hair-color attributes (CelebA only).
    hair_color_names = ('Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair')
    hair_color_indices = []
    if is_celeba:
        # `selected_attrs` defaults to None; iterate an empty tuple instead
        # of crashing inside enumerate().
        hair_color_indices = [
            idx for idx, attr_name in enumerate(selected_attrs or ())
            if attr_name in hair_color_names
        ]

    c_trg_list = []
    for i in range(c_dim):
        if is_celeba:
            c_trg = c_org.clone()
            if i in hair_color_indices:
                # Set one hair color to 1 and the rest to 0.
                c_trg[:, i] = 1
                for j in hair_color_indices:
                    if j != i:
                        c_trg[:, j] = 0
            else:
                # Reverse attribute value.
                c_trg[:, i] = (c_trg[:, i] == 0)
        else:
            c_trg = self.label2onehot(torch.ones(c_org.size(0)) * i, c_dim)

        c_trg_list.append(c_trg.to(self.device))
    return c_trg_list
| 174 | |
| 175 | def classification_loss(self, logit, target, dataset='CelebA'): |
| 176 | """Compute binary or softmax cross entropy loss.""" |