MCP · Index your code
hub / github.com/Turing-Project/WriteGPT / __init__

Method __init__

LanguageNetwork/BERT/predict.py:23–59  ·  view source on GitHub ↗

:param args: parser.parse_args()
:param device_id: 0 or -1

(self, args, device_id)

Source from the content-addressed store, hash-verified

21 """Run Model"""
22
def __init__(self, args, device_id):
    """Build the extractive summarizer and load its weights for inference.

    :param args: parsed command-line options (``parser.parse_args()``).
        Read here: ``log_file``, ``visible_gpus``, ``seed``, ``test_from``
        (checkpoint path) and ``bert_config_path``. The architecture
        attributes named in ``self.model_flags`` are overwritten in place
        with the values stored in the checkpoint's ``opt``.
    :param device_id: CUDA device index (e.g. 0), or a negative value such
        as -1 to skip ``torch.cuda.set_device``.
    """
    self.args = args
    self.device_id = device_id
    # Architecture hyper-parameters that must come from the checkpoint,
    # not from the command line, so the rebuilt model matches its weights.
    self.model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv', 'use_interval',
                        'rnn_size']

    # Configure logging before the first log call. Previously the device
    # messages below were emitted before init_logger ran. (Presumably
    # init_logger configures the same `logger` used here; confirm.)
    init_logger(args.log_file)

    self.device = "cpu" if self.args.visible_gpus == '-1' else "cuda"
    logger.info('Device ID %d', self.device_id)
    logger.info('Device %s', self.device)
    torch.manual_seed(self.args.seed)
    random.seed(self.args.seed)

    # NOTE(review): `self.device` is derived from visible_gpus while the CUDA
    # device is selected from device_id; callers must keep the two consistent.
    if self.device_id >= 0:
        torch.cuda.set_device(self.device_id)

    # Recover the training step from a name like ".../model_step_1000.pt".
    # Names that don't follow that pattern fall back to step 0. With no '.'
    # the lookup raises IndexError; with a non-numeric suffix, as in
    # "model.pt", int() raises ValueError.
    try:
        self.step = int(self.args.test_from.split('.')[-2].split('_')[-1])
    except (IndexError, ValueError):
        self.step = 0

    logger.info('Loading checkpoint from %s', self.args.test_from)
    # The map_location callback keeps every tensor on CPU while loading.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(self.args.test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    # Copy only the architecture flags from the saved training options.
    for flag in self.model_flags:
        if flag in opt:
            setattr(self.args, flag, opt[flag])

    config = BertConfig.from_json_file(self.args.bert_config_path)
    # The BERT weights are not fetched separately; the full model state is
    # restored from the checkpoint by load_cp below.
    self.model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=False, bert_config=config)
    self.model.load_cp(checkpoint)
    self.model.eval()
60
61 def predict(self):
62

Callers

nothing calls this directly

Calls 4

init_logger · Function · 0.90
load_cp · Method · 0.80
load · Method · 0.45
from_json_file · Method · 0.45

Tested by

no test coverage detected