(G, mp4: str, seeds, pose_cond, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, truncation_cutoff=14, cfg='FFHQ', image_mode='image', gen_shapes=False, device=torch.device('cuda'), **video_kwargs)
| 67 | #---------------------------------------------------------------------------- |
| 68 | |
| 69 | def gen_interp_video(G, mp4: str, seeds, pose_cond, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, truncation_cutoff=14, cfg='FFHQ', image_mode='image', gen_shapes=False, device=torch.device('cuda'), **video_kwargs): |
| 70 | grid_w = grid_dims[0] |
| 71 | grid_h = grid_dims[1] |
| 72 | |
| 73 | if num_keyframes is None: |
| 74 | if len(seeds) % (grid_w*grid_h) != 0: |
| 75 | raise ValueError('Number of input seeds must be divisible by grid W*H') |
| 76 | num_keyframes = len(seeds) // (grid_w*grid_h) |
| 77 | |
| 78 | all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64) |
| 79 | for idx in range(num_keyframes*grid_h*grid_w): |
| 80 | all_seeds[idx] = seeds[idx % len(seeds)] |
| 81 | |
| 82 | if shuffle_seed is not None: |
| 83 | rng = np.random.RandomState(seed=shuffle_seed) |
| 84 | rng.shuffle(all_seeds) |
| 85 | |
| 86 | camera_lookat_point = torch.tensor([0, 0, 0.2], device=device) if cfg == 'FFHQ' else torch.tensor([0, 0, 0], device=device) |
| 87 | |
| 88 | zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device) |
| 89 | pose_cond_rad = pose_cond/180*np.pi |
| 90 | cam2world_pose = LookAtPoseSampler.sample(pose_cond_rad, 3.14/2, camera_lookat_point, radius=2.7, device=device) |
| 91 | intrinsics = torch.tensor([[4.2647, 0, 0.5], [0, 4.2647, 0.5], [0, 0, 1]], device=device) |
| 92 | c = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1) |
| 93 | c = c.repeat(len(zs), 1) |
| 94 | ws = G.mapping(z=zs, c=c, truncation_psi=psi, truncation_cutoff=truncation_cutoff) |
| 95 | _ = G.synthesis(ws[:1], c[:1]) # warm up |
| 96 | ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:]) |
| 97 | |
| 98 | # Interpolation. |
| 99 | grid = [] |
| 100 | for yi in range(grid_h): |
| 101 | row = [] |
| 102 | for xi in range(grid_w): |
| 103 | x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1)) |
| 104 | y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]) |
| 105 | interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0) |
| 106 | row.append(interp) |
| 107 | grid.append(row) |
| 108 | |
| 109 | # Render video. |
| 110 | max_batch = 10000000 |
| 111 | voxel_resolution = 512 |
| 112 | video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs) |
| 113 | |
| 114 | if gen_shapes: |
| 115 | outdir = 'interpolation_{}_{}/'.format(all_seeds[0], all_seeds[1]) |
| 116 | os.makedirs(outdir, exist_ok=True) |
| 117 | all_poses = [] |
| 118 | for frame_idx in tqdm(range(num_keyframes * w_frames)): |
| 119 | imgs = [] |
| 120 | for yi in range(grid_h): |
| 121 | for xi in range(grid_w): |
| 122 | if cfg == "Head": |
| 123 | pitch_range = 0.5 |
| 124 | cam2world_pose = LookAtPoseSampler.sample(3.14/2 + 2 * 3.14 * frame_idx / (num_keyframes * w_frames), 3.14/2 -0.05 + pitch_range * np.sin(2 * 3.14 * frame_idx / (num_keyframes * w_frames)), |
| 125 | camera_lookat_point, radius=2.7, device=device) |
| 126 | else: |
no test coverage detected