MCPcopy
hub / github.com/Tencent/NeuralNLP-NeuralClassifier / train

Function train

train.py:202–253  ·  view source on GitHub ↗
(conf)

Source from the content-addressed store, hash-verified

200
201
def train(conf):
    """Train a classification model and evaluate its best checkpoint.

    Builds the data loaders, model, loss, optimizer, evaluator and trainer
    from ``conf``, then runs ``conf.train.num_epochs`` epochs starting at
    ``conf.train.start_epoch``. After each epoch the model is evaluated on
    the train, validate and test loaders; whenever the validation
    performance exceeds the best seen so far, a checkpoint is saved.

    When at least one checkpoint was saved, the best one is copied to
    ``<checkpoint_dir>/<model_name>_best``, reloaded, and evaluated on the
    test loader. When no epoch improved on the initial best performance of
    0 (or no epoch ran at all), there is no checkpoint to copy, so that
    final step is skipped instead of failing on a missing file.

    Args:
        conf: Configuration object; this function reads ``checkpoint_dir``,
            ``model_name``, ``train.loss_type``, ``train.start_epoch``,
            ``train.num_epochs`` and ``eval.dir`` from it and passes it on
            to the helpers it calls.
    """
    logger = util.Logger(conf)
    if not os.path.exists(conf.checkpoint_dir):
        os.makedirs(conf.checkpoint_dir)

    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"
    train_data_loader, validate_data_loader, test_data_loader = \
        get_data_loader(dataset_name, collate_name, conf)
    # A dataset built from an empty file list; only its label_map (and
    # whatever get_classification_model reads from it) is used here.
    empty_dataset = globals()[dataset_name](conf, [], mode="train")
    model = get_classification_model(model_name, empty_dataset, conf)
    loss_fn = globals()["ClassificationLoss"](
        label_size=len(empty_dataset.label_map),
        loss_type=conf.train.loss_type)
    optimizer = get_optimizer(conf, model)
    evaluator = cEvaluator(conf.eval.dir)
    trainer = globals()["ClassificationTrainer"](
        empty_dataset.label_map, logger, evaluator, conf, loss_fn)

    # None means "no checkpoint saved yet"; it cannot collide with any
    # epoch number produced by the range below.
    best_epoch = None
    best_performance = 0
    model_file_prefix = conf.checkpoint_dir + "/" + model_name
    for epoch in range(conf.train.start_epoch,
                       conf.train.start_epoch + conf.train.num_epochs):
        start_time = time.time()
        trainer.train(train_data_loader, model, optimizer, "Train", epoch)
        trainer.eval(train_data_loader, model, optimizer, "Train", epoch)
        performance = trainer.eval(
            validate_data_loader, model, optimizer, "Validate", epoch)
        trainer.eval(test_data_loader, model, optimizer, "test", epoch)
        if performance > best_performance:  # record the best model
            best_epoch = epoch
            best_performance = performance
            save_checkpoint({
                'epoch': epoch,
                'model_name': model_name,
                'state_dict': model.state_dict(),
                'best_performance': best_performance,
                'optimizer': optimizer.state_dict(),
            }, model_file_prefix)
        time_used = time.time() - start_time
        logger.info("Epoch %d cost time: %d second" % (epoch, time_used))

    if best_epoch is None:
        # No checkpoint was written, so there is nothing to copy or reload.
        logger.info("No epoch improved validation performance; "
                    "skip best model copy and best test evaluation")
        return

    # best model on validation set
    # NOTE(review): presumably save_checkpoint writes to
    # "<prefix>_<epoch>"; confirm against its definition.
    best_epoch_file_name = model_file_prefix + "_" + str(best_epoch)
    best_file_name = model_file_prefix + "_best"
    shutil.copyfile(best_epoch_file_name, best_file_name)

    load_checkpoint(best_epoch_file_name, conf, model, optimizer)
    trainer.eval(test_data_loader, model, optimizer, "Best test", best_epoch)
254
255
256if __name__ == '__main__':

Callers 1

train.pyFile · 0.85

Calls 8

infoMethod · 0.95
get_optimizerFunction · 0.90
get_data_loaderFunction · 0.85
save_checkpointFunction · 0.85
trainMethod · 0.80
evalMethod · 0.80
get_classification_modelFunction · 0.70
load_checkpointFunction · 0.70

Tested by

no test coverage detected