| 39 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) |
| 40 | @click.option('--outdir', type=str, required=True, default='outputs/stylemixing') |
| 41 | def generate_style_mix( |
| 42 | network_pkl: str, |
| 43 | row_seeds: List[int], |
| 44 | col_seeds: List[int], |
| 45 | col_styles: List[int], |
| 46 | truncation_psi: float, |
| 47 | noise_mode: str, |
| 48 | outdir: str |
| 49 | ): |
| 50 | |
| 51 | print('Loading networks from "%s"...' % network_pkl) |
| 52 | device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') |
| 53 | dtype = torch.float32 if device.type == 'mps' else torch.float64 |
| 54 | with dnnlib.util.open_url(network_pkl) as f: |
| 55 | G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype) |
| 56 | |
| 57 | os.makedirs(outdir, exist_ok=True) |
| 58 | |
| 59 | print('Generating W vectors...') |
| 60 | all_seeds = list(set(row_seeds + col_seeds)) |
| 61 | all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds]) |
| 62 | all_w = G.mapping(torch.from_numpy(all_z).to(device, dtype=dtype), None) |
| 63 | w_avg = G.mapping.w_avg |
| 64 | all_w = w_avg + (all_w - w_avg) * truncation_psi |
| 65 | w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))} |
| 66 | |
| 67 | print('Generating images...') |
| 68 | all_images = G.synthesis(all_w, noise_mode=noise_mode) |
| 69 | all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy() |
| 70 | image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))} |
| 71 | |
| 72 | print('Generating style-mixed images...') |
| 73 | for row_seed in row_seeds: |
| 74 | for col_seed in col_seeds: |
| 75 | w = w_dict[row_seed].clone() |
| 76 | w[col_styles] = w_dict[col_seed][col_styles] |
| 77 | image = G.synthesis(w[np.newaxis], noise_mode=noise_mode) |
| 78 | image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) |
| 79 | image_dict[(row_seed, col_seed)] = image[0].cpu().numpy() |
| 80 | |
| 81 | os.makedirs(outdir, exist_ok=True) |
| 82 | # print('Saving images...') |
| 83 | # for (row_seed, col_seed), image in image_dict.items(): |
| 84 | # PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png') |
| 85 | |
| 86 | print('Saving image grid...') |
| 87 | W = G.img_resolution // 2 |
| 88 | H = G.img_resolution |
| 89 | canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black') |
| 90 | for row_idx, row_seed in enumerate([0] + row_seeds): |
| 91 | for col_idx, col_seed in enumerate([0] + col_seeds): |
| 92 | if row_idx == 0 and col_idx == 0: |
| 93 | continue |
| 94 | key = (row_seed, col_seed) |
| 95 | if row_idx == 0: |
| 96 | key = (col_seed, col_seed) |
| 97 | if col_idx == 0: |
| 98 | key = (row_seed, row_seed) |