MCPcopy Index your code
hub / github.com/Turing-Project/WriteGPT / Running

Class Running

LanguageNetwork/BERT/train.py:115–285  ·  view source on GitHub ↗

Run Model

Source retrieved from the content-addressed store (hash-verified)

113
114
class Running(object):
    """Run Model.

    Orchestrates training and lead/oracle baseline evaluation of the BERT
    summarizer via the project's ``data_loader``, ``model_builder`` and
    ``build_trainer`` helpers.
    """

    def __init__(self, args, device_id):
        """
        :param args: parser.parse_args(); this constructor reads ``visible_gpus``,
            ``seed`` and ``log_file`` (other methods read further fields).
        :param device_id: GPU index to bind with ``torch.cuda.set_device``,
            or -1 to skip binding.
        """
        self.args = args
        self.device_id = device_id
        # Architecture-defining options: when resuming from a checkpoint these
        # are taken from the checkpoint's saved options instead of the CLI.
        self.model_flags = ['hidden_size', 'ff_size', 'heads', 'inter_layers', 'encoder', 'ff_actv',
                            'use_interval', 'rnn_size']

        # Configure logging before emitting anything, so the device messages
        # below are not lost / do reach args.log_file.
        # NOTE(review): assumes init_logger attaches the handlers — confirm.
        init_logger(args.log_file)

        self.device = "cpu" if self.args.visible_gpus == '-1' else "cuda"
        logger.info('Device ID %d' % self.device_id)
        logger.info('Device %s' % self.device)
        torch.manual_seed(self.args.seed)
        random.seed(self.args.seed)

        if self.device_id >= 0:
            torch.cuda.set_device(self.device_id)

    def baseline(self, cal_lead=False, cal_oracle=False):
        """Evaluate a non-neural baseline on the (unshuffled) test split.

        :param cal_lead: run the lead baseline; takes precedence over ``cal_oracle``.
        :param cal_oracle: run the oracle baseline.
        If neither flag is set, the iterator and trainer are built but nothing is tested.
        """
        test_iter = data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'test', shuffle=False),
                                           self.args.batch_size, self.device, shuffle=False, is_test=True)

        # No model/optimizer: baselines are computed without a network.
        trainer = build_trainer(self.args, self.device_id, None, None)

        if cal_lead:
            trainer.test(test_iter, 0, cal_lead=True)
        elif cal_oracle:
            trainer.test(test_iter, 0, cal_oracle=True)

    def train_iter(self):
        """Return a fresh, shuffled DataLoader over the 'train' split.

        Passed uncalled to ``trainer.train`` as an iterator factory, so a new
        loader can be created whenever one is needed.
        """
        return data_loader.DataLoader(self.args, data_loader.load_dataset(self.args, 'train', shuffle=True),
                                      self.args.batch_size, self.device, shuffle=True, is_test=False)

    def train(self):
        """Build (or resume from ``args.train_from``) the summarizer and train it
        for ``args.train_steps`` steps."""
        checkpoint = None
        if self.args.train_from:
            logger.info('Loading checkpoint from %s' % self.args.train_from)
            # NOTE: torch.load unpickles the file — only resume from trusted checkpoints.
            checkpoint = torch.load(self.args.train_from, map_location=lambda storage, loc: storage)
            # Apply the checkpoint's architecture options BEFORE building the
            # model; otherwise the network is constructed from the CLI values
            # and may not match the weights restored by load_cp below.
            opt = vars(checkpoint['opt'])
            for k in self.model_flags:
                if k in opt:
                    setattr(self.args, k, opt[k])

        model = model_builder.Summarizer(self.args, self.device, load_pretrained_bert=True)
        if checkpoint is not None:
            model.load_cp(checkpoint)
        # checkpoint is None for a fresh run, matching build_optim(args, model, None).
        optimizer = model_builder.build_optim(self.args, model, checkpoint)

        logger.info(model)
        trainer = build_trainer(self.args, self.device_id, model, optimizer)
        trainer.train(self.train_iter, self.args.train_steps)
172

Callers 2

multi_card_train — Method · 0.70
train.py — File · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected