(batch_size, cat_dim, cont_dim, noise_dim, noise_scale=0.5)
| 119 | |
| 120 | |
def get_gen_batch(batch_size, cat_dim, cont_dim, noise_dim, noise_scale=0.5):
    """Assemble one batch of generator inputs together with their targets.

    Args:
        batch_size: Number of samples requested from the sampling helpers.
        cat_dim: Size argument forwarded to ``sample_cat``.
        cont_dim: Size argument forwarded to ``sample_noise`` for the
            continuous code.
        noise_dim: Size argument forwarded to ``sample_noise`` for the
            generator noise.
        noise_scale: Scale argument forwarded to both ``sample_noise`` calls.

    Returns:
        A 5-tuple ``(X_gen, y_gen, y_cat, y_cont, y_cont_target)``:
        the sampled noise, a two-column ``uint8`` label array whose second
        column is 1, the sampled categorical code, the sampled continuous
        code, and the continuous code duplicated along a new axis 1.
    """
    # NOTE: the three sampling calls stay in this order so that any shared
    # random state is consumed in the same sequence as before.
    noise = sample_noise(noise_scale, batch_size, noise_dim)

    # One label row per noise sample; column 0 stays 0, column 1 is set to 1.
    labels = np.zeros((noise.shape[0], 2), dtype=np.uint8)
    labels[:, 1] = 1

    cat_code = sample_cat(batch_size, cat_dim)
    cont_code = sample_noise(noise_scale, batch_size, cont_dim)

    # Repeat the continuous code to accommodate Keras' loss function
    # conventions: insert a new axis 1 and duplicate along it.
    cont_target = np.repeat(np.expand_dims(cont_code, 1), 2, axis=1)

    return noise, labels, cat_code, cont_code, cont_target
| 135 | |
| 136 | |
| 137 | def plot_generated_batch(X_real, generator_model, batch_size, cat_dim, cont_dim, noise_dim, image_data_format, noise_scale=0.5): |
nothing calls this directly
no test coverage detected