()
| 27 | |
| 28 | |
| 29 | def main(): |
| 30 | parser = argparse.ArgumentParser(description="Decode shape latent to GLB and collect renders") |
| 31 | parser.add_argument("--root", type=str, required=True, help="Dataset root, e.g. /local-ssd/datasets/ObjaverseXL_sketchfab") |
| 32 | parser.add_argument("--sha256", type=str, required=True, help="SHA256 of the asset") |
| 33 | parser.add_argument("--resolution", type=int, default=1024, help="Decoder resolution (must match latent resolution)") |
| 34 | parser.add_argument("--view_idx", type=int, default=0, help="View index to decode") |
| 35 | parser.add_argument("--latent_name", type=str, default="shape_enc_next_dc_f16c32_fp16_1024_view", |
| 36 | help="Latent directory name under shape_latents/") |
| 37 | parser.add_argument("--decoder", type=str, default="microsoft/TRELLIS.2-4B/ckpts/shape_dec_next_dc_f16c32_fp16", |
| 38 | help="Pretrained shape decoder path (HuggingFace or local)") |
| 39 | parser.add_argument("--output_dir", type=str, default=None, help="Output directory (default: <root>/vis/<sha256>)") |
| 40 | args = parser.parse_args() |
| 41 | |
| 42 | sha256 = args.sha256 |
| 43 | root = args.root |
| 44 | view_idx = args.view_idx |
| 45 | |
| 46 | # Paths |
| 47 | latent_dir = os.path.join(root, "shape_latents", args.latent_name, sha256) |
| 48 | latent_file = os.path.join(latent_dir, f"view{view_idx:02d}.npz") |
| 49 | scale_file = os.path.join(latent_dir, f"view{view_idx:02d}_scale.json") |
| 50 | renders_dir = os.path.join(root, "renders_cond", sha256) |
| 51 | output_dir = args.output_dir or os.path.join(root, "vis", sha256) |
| 52 | |
| 53 | # Validate |
| 54 | assert os.path.exists(latent_file), f"Latent file not found: {latent_file}" |
| 55 | print(f"[Input] Latent: {latent_file}") |
| 56 | if os.path.exists(scale_file): |
| 57 | print(f"[Input] Scale: {scale_file}") |
| 58 | if os.path.exists(renders_dir): |
| 59 | print(f"[Input] Renders: {renders_dir}") |
| 60 | |
| 61 | # 1. Load latent |
| 62 | print("[Step 1] Loading shape latent...") |
| 63 | data = np.load(latent_file) |
| 64 | coords = torch.tensor(data['coords']).int() |
| 65 | feats = torch.tensor(data['feats']).float() |
| 66 | # Prepend batch dim (0) to coords |
| 67 | coords = torch.cat([torch.zeros_like(coords[:, :1]), coords], dim=1) |
| 68 | slat = sp.SparseTensor(feats.cuda(), coords.cuda()) |
| 69 | print(f" coords: {coords.shape}, feats: {feats.shape}") |
| 70 | |
| 71 | # 2. Load decoder |
| 72 | print(f"[Step 2] Loading shape decoder: {args.decoder}") |
| 73 | decoder = models.from_pretrained(args.decoder) |
| 74 | decoder.set_resolution(args.resolution) |
| 75 | decoder = decoder.cuda().eval() |
| 76 | |
| 77 | # 3. Decode |
| 78 | print("[Step 3] Decoding shape latent → mesh...") |
| 79 | with torch.no_grad(): |
| 80 | meshes, subs = decoder(slat, return_subs=True) |
| 81 | mesh = meshes[0] |
| 82 | print(f" vertices: {mesh.vertices.shape}, faces: {mesh.faces.shape}") |
| 83 | |
| 84 | # 4. Convert to trimesh and export GLB |
| 85 | print("[Step 4] Exporting GLB...") |
| 86 | vertices = mesh.vertices.cpu().numpy() |
no test coverage detected