MCPcopy
hub / github.com/tdeboissiere/DeepLearningImplementations / get_disc_batch

Function get_disc_batch

InfoGAN/src/utils/data_utils.py:80–118  ·  view source on GitHub ↗
(X_real_batch, generator_model, batch_counter, batch_size, cat_dim, cont_dim, noise_dim,
                   noise_scale=0.5, label_smoothing=False, label_flipping=0)

Source from the content-addressed store, hash-verified

78
79
def get_disc_batch(X_real_batch, generator_model, batch_counter, batch_size, cat_dim, cont_dim, noise_dim,
                   noise_scale=0.5, label_smoothing=False, label_flipping=0):
    """Build one discriminator training batch, alternating fake and real images.

    Even ``batch_counter`` values yield images produced by ``generator_model``;
    odd values yield ``X_real_batch`` unchanged.

    Args:
        X_real_batch: array of real images, used when ``batch_counter`` is odd.
        generator_model: model whose ``predict`` is called with
            ``[y_cat, y_cont, noise_input]`` to produce fake images.
        batch_counter: current batch index; its parity picks fake (even) or
            real (odd) images.
        batch_size: number of samples drawn for a generated batch.
        cat_dim: size forwarded to ``sample_cat`` for the categorical code.
        cont_dim: size forwarded to ``sample_noise`` for the continuous code.
        noise_dim: size forwarded to ``sample_noise`` for the generator noise.
        noise_scale: scale forwarded to ``sample_noise``.
        label_smoothing: if True, real targets are drawn uniformly from
            [0.9, 1) instead of being exactly 1. Fake targets are never smoothed.
        label_flipping: probability (must lie in [0, 1] when > 0) that the two
            label columns of the whole batch are swapped.

    Returns:
        Tuple ``(X_disc, y_disc, y_cat, y_cont)``. ``y_disc`` is a float32
        array with two columns: column 0 marks generated samples, column 1
        marks real samples. ``y_cont`` gets a new axis 1 of length 2, holding
        the sampled codes twice.
    """
    if batch_counter % 2 == 0:
        # Fake batch: sample latent codes and noise, then let the generator
        # produce the images.
        y_cat = sample_cat(batch_size, cat_dim)
        y_cont = sample_noise(noise_scale, batch_size, cont_dim)
        noise_input = sample_noise(noise_scale, batch_size, noise_dim)
        X_disc = generator_model.predict([y_cat, y_cont, noise_input], batch_size=batch_size)
        # float32 rather than an integer dtype, so that non-integer targets
        # (label smoothing) are representable.
        y_disc = np.zeros((X_disc.shape[0], 2), dtype=np.float32)
        y_disc[:, 0] = 1
    else:
        X_disc = X_real_batch
        n_real = X_disc.shape[0]
        y_disc = np.zeros((n_real, 2), dtype=np.float32)
        # Size the codes to the real batch so all returned arrays have the
        # same number of rows, even if the real batch is shorter than
        # batch_size.
        y_cat = sample_cat(n_real, cat_dim)
        y_cont = sample_noise(noise_scale, n_real, cont_dim)
        if label_smoothing:
            # One-sided smoothing of the real targets. In an integer array
            # these values in [0.9, 1) would be truncated to 0.
            y_disc[:, 1] = np.random.uniform(low=0.9, high=1, size=n_real)
        else:
            y_disc[:, 1] = 1

    if label_flipping > 0:
        # With probability label_flipping, swap the fake/real columns for the
        # whole batch. The fancy-indexed right-hand side is a copy, so the
        # in-place swap is safe.
        if np.random.binomial(1, label_flipping) > 0:
            y_disc[:, [0, 1]] = y_disc[:, [1, 0]]

    # Repeat y_cont along a new axis 1 (twice) to accommodate Keras' loss
    # function conventions. NOTE(review): presumably the custom
    # continuous-code loss expects this layout; confirm against its definition.
    y_cont = np.expand_dims(y_cont, 1)
    y_cont = np.repeat(y_cont, 2, axis=1)

    return X_disc, y_disc, y_cat, y_cont
119
120
121def get_gen_batch(batch_size, cat_dim, cont_dim, noise_dim, noise_scale=0.5):

Callers

nothing calls this directly

Calls 2

sample_cat · Function · 0.85
sample_noise · Function · 0.70

Tested by

no test coverage detected