MCP · copy
hub / github.com/XingangPan/DragGAN / generate_style_mix

Function generate_style_mix

stylegan_human/style_mixing.py:41–100  ·  view source on GitHub ↗
(
    network_pkl: str,
    row_seeds: List[int],
    col_seeds: List[int],
    col_styles: List[int],
    truncation_psi: float,
    noise_mode: str,
    outdir: str
)

Source from the content-addressed store, hash-verified

39@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
40@click.option('--outdir', type=str, required=True, default='outputs/stylemixing')
41def generate_style_mix(
42 network_pkl: str,
43 row_seeds: List[int],
44 col_seeds: List[int],
45 col_styles: List[int],
46 truncation_psi: float,
47 noise_mode: str,
48 outdir: str
49):
50
51 print('Loading networks from "%s"...' % network_pkl)
52 device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
53 dtype = torch.float32 if device.type == 'mps' else torch.float64
54 with dnnlib.util.open_url(network_pkl) as f:
55 G = legacy.load_network_pkl(f)['G_ema'].to(device, dtype=dtype)
56
57 os.makedirs(outdir, exist_ok=True)
58
59 print('Generating W vectors...')
60 all_seeds = list(set(row_seeds + col_seeds))
61 all_z = np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])
62 all_w = G.mapping(torch.from_numpy(all_z).to(device, dtype=dtype), None)
63 w_avg = G.mapping.w_avg
64 all_w = w_avg + (all_w - w_avg) * truncation_psi
65 w_dict = {seed: w for seed, w in zip(all_seeds, list(all_w))}
66
67 print('Generating images...')
68 all_images = G.synthesis(all_w, noise_mode=noise_mode)
69 all_images = (all_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()
70 image_dict = {(seed, seed): image for seed, image in zip(all_seeds, list(all_images))}
71
72 print('Generating style-mixed images...')
73 for row_seed in row_seeds:
74 for col_seed in col_seeds:
75 w = w_dict[row_seed].clone()
76 w[col_styles] = w_dict[col_seed][col_styles]
77 image = G.synthesis(w[np.newaxis], noise_mode=noise_mode)
78 image = (image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
79 image_dict[(row_seed, col_seed)] = image[0].cpu().numpy()
80
81 os.makedirs(outdir, exist_ok=True)
82 # print('Saving images...')
83 # for (row_seed, col_seed), image in image_dict.items():
84 # PIL.Image.fromarray(image, 'RGB').save(f'{outdir}/{row_seed}-{col_seed}.png')
85
86 print('Saving image grid...')
87 W = G.img_resolution // 2
88 H = G.img_resolution
89 canvas = PIL.Image.new('RGB', (W * (len(col_seeds) + 1), H * (len(row_seeds) + 1)), 'black')
90 for row_idx, row_seed in enumerate([0] + row_seeds):
91 for col_idx, col_seed in enumerate([0] + col_seeds):
92 if row_idx == 0 and col_idx == 0:
93 continue
94 key = (row_seed, col_seed)
95 if row_idx == 0:
96 key = (col_seed, col_seed)
97 if col_idx == 0:
98 key = (row_seed, row_seed)

Callers 1

style_mixing.py · File · 0.85

Calls 1

clone · Method · 0.80

Tested by

no test coverage detected