(conf)
| 200 | |
| 201 | |
def train(conf):
    """Train a classification model, checkpoint it, and test the best epoch.

    Builds the data loaders, model, loss, optimizer, evaluator and trainer
    from ``conf``, then runs ``conf.train.num_epochs`` epochs starting at
    ``conf.train.start_epoch``. After each epoch the model is evaluated on
    the train, validate and test loaders; the value returned by the
    validate evaluation selects the best epoch. When training finishes,
    the best epoch's checkpoint is copied to ``<prefix>_best``, reloaded,
    and evaluated once more on the test loader.

    Args:
        conf: Configuration object. This function reads
            ``checkpoint_dir``, ``model_name``, ``train.loss_type``,
            ``train.start_epoch``, ``train.num_epochs`` and ``eval.dir``
            from it, and passes it on to the helpers it calls.
    """
    logger = util.Logger(conf)
    if not os.path.exists(conf.checkpoint_dir):
        os.makedirs(conf.checkpoint_dir)

    model_name = conf.model_name
    dataset_name = "ClassificationDataset"
    collate_name = "FastTextCollator" if model_name == "FastText" \
        else "ClassificationCollator"
    train_data_loader, validate_data_loader, test_data_loader = \
        get_data_loader(dataset_name, collate_name, conf)
    # A dataset built over an empty file list; it is used here only for its
    # label_map and as an argument to the model factory.
    empty_dataset = globals()[dataset_name](conf, [], mode="train")
    model = get_classification_model(model_name, empty_dataset, conf)
    loss_fn = globals()["ClassificationLoss"](
        label_size=len(empty_dataset.label_map),
        loss_type=conf.train.loss_type)
    optimizer = get_optimizer(conf, model)
    evaluator = cEvaluator(conf.eval.dir)
    trainer = globals()["ClassificationTrainer"](
        empty_dataset.label_map, logger, evaluator, conf, loss_fn)

    best_epoch = -1
    best_performance = 0
    model_file_prefix = conf.checkpoint_dir + "/" + model_name
    for epoch in range(conf.train.start_epoch,
                       conf.train.start_epoch + conf.train.num_epochs):
        start_time = time.time()
        trainer.train(train_data_loader, model, optimizer, "Train", epoch)
        trainer.eval(train_data_loader, model, optimizer, "Train", epoch)
        performance = trainer.eval(
            validate_data_loader, model, optimizer, "Validate", epoch)
        trainer.eval(test_data_loader, model, optimizer, "test", epoch)
        if performance > best_performance:  # record the best model
            best_epoch = epoch
            best_performance = performance
        # NOTE(review): a checkpoint is written for every epoch, carrying the
        # running best_performance -- confirm this matches the intended
        # nesting (per-epoch save vs. save only on improvement).
        save_checkpoint({
            'epoch': epoch,
            'model_name': model_name,
            'state_dict': model.state_dict(),
            'best_performance': best_performance,
            'optimizer': optimizer.state_dict(),
        }, model_file_prefix)
        time_used = time.time() - start_time
        logger.info("Epoch %d cost time: %d second" % (epoch, time_used))

    # best_epoch stays at -1 when no epoch ran (num_epochs == 0) or when the
    # validate performance never rose above the initial 0. There is then no
    # "<prefix>_-1" checkpoint to copy or load, so stop here instead of
    # failing on a missing file.
    if best_epoch < 0:
        logger.info(
            "No epoch improved on the initial best performance (%s); "
            "skipping best-model copy and final test evaluation."
            % best_performance)
        return

    # best model on validation set
    best_epoch_file_name = model_file_prefix + "_" + str(best_epoch)
    best_file_name = model_file_prefix + "_best"
    shutil.copyfile(best_epoch_file_name, best_file_name)

    load_checkpoint(best_epoch_file_name, conf, model, optimizer)
    trainer.eval(test_data_loader, model, optimizer, "Best test", best_epoch)
| 254 | |
| 255 | |
| 256 | if __name__ == '__main__': |
no test coverage detected