(self, config)
| 43 | |
| 44 | class Predictor(object): |
def __init__(self, config):
    """Assemble an inference-ready classifier from ``config``.

    Builds a dataset in ``mode="infer"`` (given an empty list as its second
    argument) and a collate function sized by the dataset's ``label_map``.
    Then builds the model via ``Predictor._get_classification_model``, loads
    weights from ``config.eval.model_dir`` via ``Predictor._load_checkpoint``,
    and finally calls ``eval()`` on the model.

    Args:
        config: Project configuration object. This method reads
            ``config.model_name``, ``config.device`` and
            ``config.eval.model_dir``, and passes the whole object on to the
            dataset, collator and model builders.
    """
    self.config = config
    self.model_name = config.model_name
    # Any device string that begins with "cuda" is treated as a GPU request.
    self.use_cuda = config.device.startswith("cuda")

    # Component classes are chosen by *name* and resolved from this module's
    # global namespace below; FastText gets its own collator.
    self.dataset_name = "ClassificationDataset"
    if self.model_name == "FastText":
        self.collate_name = "FastTextCollator"
    else:
        self.collate_name = "ClassificationCollator"

    namespace = globals()

    dataset_cls = namespace[self.dataset_name]
    self.dataset = dataset_cls(config, [], mode="infer")

    # Look the collator class up before touching label_map, matching the
    # evaluation order of a single call expression.
    collate_cls = namespace[self.collate_name]
    self.collate_fn = collate_cls(config, len(self.dataset.label_map))

    self.model = Predictor._get_classification_model(
        self.model_name, self.dataset, config)
    Predictor._load_checkpoint(
        config.eval.model_dir, self.model, self.use_cuda)
    # NOTE(review): presumably a torch nn.Module, so this switches it to
    # inference behaviour — confirm against _get_classification_model.
    self.model.eval()
| 57 | |
| 58 | @staticmethod |
| 59 | def _get_classification_model(model_name, dataset, conf): |
nothing calls this directly
no test coverage detected