(model_name, dataset, conf)
| 46 | |
| 47 | |
def get_classification_model(model_name, dataset, conf):
    """Instantiate a classification model by name and place it on ``conf.device``.

    The model factory is looked up by name among this module's global names
    and called as ``factory(dataset, conf)``.

    Args:
        model_name: Name of a callable (typically a model class) defined or
            imported at module level in this file.
        dataset: Passed through unchanged as the first factory argument.
        conf: Configuration object, passed as the second factory argument.
            ``conf.device`` selects placement: when its string form starts
            with ``"cuda"`` the model is moved with ``.cuda(conf.device)``;
            otherwise the model is returned as constructed.

    Returns:
        The constructed model, moved to the CUDA device when requested.

    Raises:
        KeyError: If ``model_name`` does not name a callable module-level
            global (same exception type as the previous ``globals()[...]``).
    """
    factory = globals().get(model_name)
    if not callable(factory):
        # Keep KeyError for backward compatibility, but say what was missing.
        raise KeyError(f"unknown classification model: {model_name!r}")
    model = factory(dataset, conf)
    # str() so both plain strings ("cuda:0") and device objects are accepted;
    # a missing/None device falls through to the un-moved model.
    if str(conf.device).startswith("cuda"):
        model = model.cuda(conf.device)
    return model
| 52 | |
| 53 | |
| 54 | def load_checkpoint(file_name, conf, model, optimizer): |