(
ctx: click.Context,
ckpt_path: str,
attr_name: str,
truncation: float,
gen_video: bool,
combine: bool,
seeds: Optional[List[int]],
outdir: str,
real: str,
real_w_path: str,
real_img_path: str
)
| 53 | |
| 54 | |
| 55 | def main( |
| 56 | ctx: click.Context, |
| 57 | ckpt_path: str, |
| 58 | attr_name: str, |
| 59 | truncation: float, |
| 60 | gen_video: bool, |
| 61 | combine: bool, |
| 62 | seeds: Optional[List[int]], |
| 63 | outdir: str, |
| 64 | real: str, |
| 65 | real_w_path: str, |
| 66 | real_img_path: str |
| 67 | ): |
| 68 | ## convert pkl to pth |
| 69 | # if not os.path.exists(ckpt_path.replace('.pkl','.pth')): |
| 70 | legacy.convert(ckpt_path, ckpt_path.replace('.pkl','.pth'), G_only=real) |
| 71 | ckpt_path = ckpt_path.replace('.pkl','.pth') |
| 72 | print("start...", flush=True) |
| 73 | config = {"latent" : 512, "n_mlp" : 8, "channel_multiplier": 2} |
| 74 | generator = Generator( |
| 75 | size = 1024, |
| 76 | style_dim=config["latent"], |
| 77 | n_mlp=config["n_mlp"], |
| 78 | channel_multiplier=config["channel_multiplier"] |
| 79 | ) |
| 80 | |
| 81 | generator.load_state_dict(torch.load(ckpt_path)['g_ema']) |
| 82 | generator.eval().cuda() |
| 83 | |
| 84 | with torch.no_grad(): |
| 85 | mean_path = os.path.join('edit','mean_latent.pkl') |
| 86 | if not os.path.exists(mean_path): |
| 87 | mean_n = 3000 |
| 88 | mean_latent = generator.mean_latent(mean_n).detach() |
| 89 | legacy.save_obj(mean_latent, mean_path) |
| 90 | else: |
| 91 | mean_latent = legacy.load_pkl(mean_path).cuda() |
| 92 | finals = [] |
| 93 | |
| 94 | ## -- selected sample seeds -- ## |
| 95 | # seeds = [60948,60965,61174,61210,61511,61598,61610] #bottom -> long |
| 96 | # [60941,61064,61103,61313,61531,61570,61571] # bottom -> short |
| 97 | # [60941,60965,61064,61103,61174,61210,61531,61570,61571,61610] # upper --> long |
| 98 | # [60948,61313,61511,61598] # upper --> short |
| 99 | if real: seeds = [0] |
| 100 | |
| 101 | for t in seeds: |
| 102 | if real: # now assume process single real image only |
| 103 | if real_img_path: |
| 104 | real_image = cv2.imread(real_img_path) |
| 105 | real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB) |
| 106 | import torchvision.transforms as transforms |
| 107 | transform = transforms.Compose( # normalize to (-1, 1) |
| 108 | [transforms.ToTensor(), |
| 109 | transforms.Normalize(mean=(.5,.5,.5), std=(.5,.5,.5))] |
| 110 | ) |
| 111 | real_image = transform(real_image).unsqueeze(0).cuda() |
| 112 |
no test coverage detected