MCPcopy
hub / github.com/XingangPan/DragGAN / main

Function main

stylegan_human/edit.py:55–190  ·  view source on GitHub ↗
(
    ctx: click.Context,
    ckpt_path: str,
    attr_name: str,
    truncation: float,
    gen_video: bool,
    combine: bool,
    seeds: Optional[List[int]],
    outdir: str,
    real: str,
    real_w_path: str,
    real_img_path: str
)

Source from the content-addressed store, hash-verified

53
54
55def main(
56 ctx: click.Context,
57 ckpt_path: str,
58 attr_name: str,
59 truncation: float,
60 gen_video: bool,
61 combine: bool,
62 seeds: Optional[List[int]],
63 outdir: str,
64 real: str,
65 real_w_path: str,
66 real_img_path: str
67):
68 ## convert pkl to pth
69 # if not os.path.exists(ckpt_path.replace('.pkl','.pth')):
70 legacy.convert(ckpt_path, ckpt_path.replace('.pkl','.pth'), G_only=real)
71 ckpt_path = ckpt_path.replace('.pkl','.pth')
72 print("start...", flush=True)
73 config = {"latent" : 512, "n_mlp" : 8, "channel_multiplier": 2}
74 generator = Generator(
75 size = 1024,
76 style_dim=config["latent"],
77 n_mlp=config["n_mlp"],
78 channel_multiplier=config["channel_multiplier"]
79 )
80
81 generator.load_state_dict(torch.load(ckpt_path)['g_ema'])
82 generator.eval().cuda()
83
84 with torch.no_grad():
85 mean_path = os.path.join('edit','mean_latent.pkl')
86 if not os.path.exists(mean_path):
87 mean_n = 3000
88 mean_latent = generator.mean_latent(mean_n).detach()
89 legacy.save_obj(mean_latent, mean_path)
90 else:
91 mean_latent = legacy.load_pkl(mean_path).cuda()
92 finals = []
93
94 ## -- selected sample seeds -- ##
95 # seeds = [60948,60965,61174,61210,61511,61598,61610] #bottom -> long
96 # [60941,61064,61103,61313,61531,61570,61571] # bottom -> short
97 # [60941,60965,61064,61103,6117461210,61531,61570,61571,61610] # upper --> long
98 # [60948,61313,61511,61598] # upper --> short
99 if real: seeds = [0]
100
101 for t in seeds:
102 if real: # now assume process single real image only
103 if real_img_path:
104 real_image = cv2.imread(real_img_path)
105 real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
106 import torchvision.transforms as transforms
107 transform = transforms.Compose( # normalize to (-1, 1)
108 [transforms.ToTensor(),
109 transforms.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5))]
110 )
111 real_image = transform(real_image).unsqueeze(0).cuda()
112

Callers 1

edit.py · File · 0.70

Calls 8

mean_latent · Method · 0.95
Generator · Class · 0.90
encoder_ifg · Function · 0.90
decoder · Function · 0.90
encoder_ss · Function · 0.90
encoder_sefa · Function · 0.90
load · Method · 0.80
convert · Method · 0.45

Tested by

no test coverage detected