| 29 | raise NotImplementedError(name) |
| 30 | |
def get_models(args):
    """Build (or load) the Latte model selected by ``args.model``.

    Dispatch is by substring match on ``args.model``:

    * contains ``'LatteIMG'`` -> ``LatteIMG_models[args.model](...)``, built
      from ``latent_size``, ``num_classes``, ``num_frames``, ``learn_sigma``
      and ``extras``.
    * contains ``'LatteT2V'`` -> ``LatteT2V.from_pretrained`` loading the
      ``"transformer"`` subfolder of ``args.pretrained_model_path`` with
      ``video_length=args.video_length``.
    * contains ``'Latte'`` -> ``Latte_models[args.model](...)`` with the
      same keyword arguments as the LatteIMG branch.

    Args:
        args: A namespace-like object (presumably argparse/OmegaConf config
            -- confirm against callers) exposing ``model`` plus the
            attributes read by the selected branch.

    Returns:
        The model instance produced by the selected constructor/loader.

    Raises:
        NotImplementedError: If ``args.model`` matches none of the
            supported families.
    """
    # Order matters: 'Latte' is a substring of 'LatteIMG' and 'LatteT2V',
    # so the more specific names must be tested first.
    if 'LatteIMG' in args.model:
        return LatteIMG_models[args.model](
            input_size=args.latent_size,
            num_classes=args.num_classes,
            num_frames=args.num_frames,
            learn_sigma=args.learn_sigma,
            extras=args.extras
        )
    elif 'LatteT2V' in args.model:
        return LatteT2V.from_pretrained(args.pretrained_model_path, subfolder="transformer", video_length=args.video_length)
    elif 'Latte' in args.model:
        return Latte_models[args.model](
            input_size=args.latent_size,
            num_classes=args.num_classes,
            num_frames=args.num_frames,
            learn_sigma=args.learn_sigma,
            extras=args.extras
        )
    else:
        # Raising a bare str is a TypeError in Python 3; raise a real
        # exception so the intended message reaches the caller.
        raise NotImplementedError('{} Model Not Supported!'.format(args.model))
| 52 | |