| 348 | return model.eval().to(device) |
| 349 | |
def load_tokenizers(args, device):
    """Load every tokenizer model whose checkpoint id is set on ``args``.

    The checkpoint-id attributes listed in ``registry`` below are read from
    ``args`` in a fixed order. An attribute holding a falsy value (e.g. ``None``
    or an empty string) is skipped. Every truthy id is loaded with
    ``load_model(ckpt_id, model_cls, device)``.

    Args:
        args: Object exposing the checkpoint-id attributes named in
            ``registry`` (presumably a parsed argparse namespace).
        device: Forwarded unchanged to ``load_model``.

    Returns:
        dict mapping output keys (e.g. ``'tok_rgb'``, ``'controlnet'``) to the
        loaded models. Only keys whose checkpoint id was set are present.
    """
    # Each row: (attribute on args, output key(s), model class).
    # A row with several output keys stores the SAME loaded model under each.
    registry = (
        ('rgb_tok_id', ('tok_rgb',), DiVAE),                # RGB tokenizer
        ('controlnet_id', ('controlnet',), VQControlNet),   # optional RGB ControlNet
        ('depth_tok_id', ('tok_depth',), DiVAE),
        ('normal_tok_id', ('tok_normal',), DiVAE),
        # Canny edges and SAM edges share a single tokenizer instance.
        ('edges_tok_id', ('tok_canny_edge', 'tok_sam_edge'), DiVAE),
        ('semseg_tok_id', ('tok_semseg',), VQVAE),
        ('clip_tok_id', ('tok_clip',), VQVAE),
        ('dinov2_tok_id', ('tok_dinov2',), VQVAE),
        ('imagebind_tok_id', ('tok_imagebind',), VQVAE),
        ('dinov2_glob_tok_id', ('tok_dinov2_global',), VQVAE),
        ('imagebind_glob_tok_id', ('tok_imagebind_global',), VQVAE),
        # NOTE: unlike its neighbours this key has no 'tok_' prefix; kept as-is
        # because callers may look it up by this exact name.
        ('sam_instance_tok_id', ('sam_instance',), VQVAE),
        ('human_poses_tok_id', ('tok_pose',), VQVAE),       # human poses
    )

    loaded = {}
    for attr_name, out_keys, model_cls in registry:
        ckpt_id = getattr(args, attr_name)
        if not ckpt_id:
            continue
        model = load_model(ckpt_id, model_cls, device)
        for key in out_keys:
            loaded[key] = model
    return loaded
| 407 | |