(model_choice: str)
| 15 | |
| 16 | |
def load_model_if_needed(model_choice: str):
    """Return the globally cached Zonos model for ``model_choice``, loading it on demand.

    When ``model_choice`` differs from the currently cached model type, the old
    model is dropped and a new one is created with
    ``Zonos.from_pretrained(model_choice, device=device)``. The new model is
    frozen with ``requires_grad_(False)`` and put into eval mode. Calls that
    repeat the cached ``model_choice`` return the cached instance without
    reloading.

    Args:
        model_choice: Identifier passed to ``Zonos.from_pretrained``.

    Returns:
        The cached model (the value of the ``CURRENT_MODEL`` global).

    Raises:
        Whatever ``Zonos.from_pretrained`` raises. In that case the cache is left
        empty (``CURRENT_MODEL`` and ``CURRENT_MODEL_TYPE`` are both ``None``),
        so the next call retries the load instead of failing with ``NameError``.
    """
    global CURRENT_MODEL_TYPE, CURRENT_MODEL
    if CURRENT_MODEL_TYPE != model_choice:
        # Invalidate the cached type up front so a failed load below cannot leave
        # the cache claiming a model it no longer holds.
        CURRENT_MODEL_TYPE = None
        if CURRENT_MODEL is not None:
            import gc

            # Rebind rather than `del`: `del` on a global removes the name
            # entirely, so a failed load would leave CURRENT_MODEL undefined.
            CURRENT_MODEL = None
            # Collect first so the dropped model's tensors are really released
            # before asking the CUDA caching allocator to return its blocks.
            gc.collect()
            torch.cuda.empty_cache()
        print(f"Loading {model_choice} model...")
        model = Zonos.from_pretrained(model_choice, device=device)
        model.requires_grad_(False).eval()
        # Publish to the globals only once the model is fully prepared.
        CURRENT_MODEL = model
        CURRENT_MODEL_TYPE = model_choice
        print(f"{model_choice} model loaded successfully!")
    return CURRENT_MODEL
| 29 | |
| 30 | |
| 31 | def update_ui(model_choice): |
no test coverage detected