Get a classification model from the configuration.
Signature: `get_classification_model(model_name, dataset, conf)`
| 77 | |
| 78 | |
def get_classification_model(model_name, dataset, conf):
    """Build the classification model named by ``model_name``.

    The model class (or factory) is looked up by name among this module's
    global names and called as ``factory(dataset, conf)``. If
    ``conf.device`` names a CUDA device, the model is moved there with
    ``.cuda(conf.device)``; otherwise it is returned as constructed.

    Args:
        model_name: Name of a model class/factory defined or imported at
            module level in this file.
        dataset: Passed unchanged to the model constructor.
        conf: Configuration object. It is passed to the model constructor
            and must expose ``device`` (e.g. ``"cpu"`` or ``"cuda:0"``).
            Any object whose ``str()`` starts with ``"cuda"`` is treated as
            a CUDA device; ``None`` means no device move.

    Returns:
        The constructed model, moved to ``conf.device`` when that is a CUDA
        device.

    Raises:
        KeyError: If no module-level name ``model_name`` exists.
    """
    try:
        factory = globals()[model_name]
    except KeyError:
        # A bare KeyError('Foo') from a globals() lookup gives no hint that
        # the configured model name was the problem; keep the type, fix the
        # message.
        raise KeyError(
            f"Unknown classification model {model_name!r}: no such name is "
            "defined in this module"
        ) from None

    model = factory(dataset, conf)

    device = conf.device
    # str() lets a non-string device object work as well as a plain string;
    # None means "no device requested" and keeps the model where it was built.
    if device is not None and str(device).startswith("cuda"):
        model = model.cuda(device)
    return model
| 85 | |
| 86 | |
| 87 | class ClassificationTrainer(object): |