MCPcopy
hub / github.com/yunjey/stargan / create_labels

Method create_labels

solver.py:149–173  ·  view source on GitHub ↗

Generate target domain labels for debugging and testing.

(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None)

Source from the content-addressed store, hash-verified

147 return out
148
def create_labels(self, c_org, c_dim=5, dataset='CelebA', selected_attrs=None):
    """Generate target domain labels for debugging and testing.

    Builds one target-label tensor per domain index ``i`` in ``range(c_dim)``.

    For ``'CelebA'`` each target starts as a copy of ``c_org``:
      * if column ``i`` is a hair-colour attribute ('Black_Hair',
        'Blond_Hair', 'Brown_Hair', 'Gray_Hair'), column ``i`` is set to 1
        and every other hair-colour column to 0, so exactly one hair colour
        is active;
      * otherwise column ``i`` is flipped (0 becomes 1, any non-zero value
        becomes 0).

    For ``'RaFD'`` the target is ``self.label2onehot`` applied to a
    batch-sized tensor filled with ``i``; ``c_org`` only supplies the batch
    size.

    Args:
        c_org: original label tensor. It is indexed as ``[:, i]`` for CelebA,
            so it is assumed to be 2-D (batch, c_dim) — TODO confirm at
            callers.
        c_dim: number of target labels to generate.
        dataset: either 'CelebA' or 'RaFD'.
        selected_attrs: attribute names, one per column of ``c_org``. Used
            only for CelebA to locate the hair-colour columns. ``None`` is
            treated as "no hair-colour columns".

    Returns:
        A list of ``c_dim`` tensors, each moved to ``self.device``.

    Raises:
        ValueError: if ``dataset`` is neither 'CelebA' nor 'RaFD'.
    """
    if dataset == 'CelebA':
        # Hair colours are mutually exclusive attributes; remember their columns.
        hair_colors = ('Black_Hair', 'Blond_Hair', 'Brown_Hair', 'Gray_Hair')
        hair_color_indices = [
            idx for idx, attr_name in enumerate(selected_attrs or [])
            if attr_name in hair_colors
        ]
    elif dataset != 'RaFD':
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(
            "Unsupported dataset {!r}; expected 'CelebA' or 'RaFD'.".format(dataset))

    c_trg_list = []
    for i in range(c_dim):
        if dataset == 'CelebA':
            c_trg = c_org.clone()  # never mutate the caller's labels
            if i in hair_color_indices:
                # Set one hair color to 1 and the rest to 0.
                c_trg[:, hair_color_indices] = 0
                c_trg[:, i] = 1
            else:
                c_trg[:, i] = (c_trg[:, i] == 0)  # Reverse attribute value.
        else:  # dataset == 'RaFD'
            c_trg = self.label2onehot(torch.ones(c_org.size(0)) * i, c_dim)

        c_trg_list.append(c_trg.to(self.device))
    return c_trg_list
174
175 def classification_loss(self, logit, target, dataset='CelebA'):
176 """Compute binary or softmax cross entropy loss."""

Callers 4

train · method · 0.95
train_multi · method · 0.95
test · method · 0.95
test_multi · method · 0.95

Calls 1

label2onehot · method · 0.95

Tested by 1

test_multi · method · 0.76