MCPcopy
hub / github.com/Tencent/NeuralNLP-NeuralClassifier / __init__

Method __init__

predict.py:45–56  ·  view source on GitHub ↗
(self, config)

Source from the content-addressed store, hash-verified

43
44class Predictor(object):
45 def __init__(self, config):
46 self.config = config
47 self.model_name = config.model_name
48 self.use_cuda = config.device.startswith("cuda")
49 self.dataset_name = "ClassificationDataset"
50 self.collate_name = "FastTextCollator" if self.model_name == "FastText" \
51 else "ClassificationCollator"
52 self.dataset = globals()[self.dataset_name](config, [], mode="infer")
53 self.collate_fn = globals()[self.collate_name](config, len(self.dataset.label_map))
54 self.model = Predictor._get_classification_model(self.model_name, self.dataset, config)
55 Predictor._load_checkpoint(config.eval.model_dir, self.model, self.use_cuda)
56 self.model.eval()
57
58 @staticmethod
59 def _get_classification_model(model_name, dataset, conf):

Callers

nothing calls this directly

Calls 3

_load_checkpointMethod · 0.80
evalMethod · 0.80

Tested by

no test coverage detected