def build_trainer(args, device_id, model, optim):
    """
    Simplify `Trainer` creation based on the user's `args`.

    Args:
        args (:obj:`Namespace`): user options (usually from argument parsing);
            this function reads `visible_gpus`, `accum_count`, `world_size`,
            `gpu_ranks`, `model_path` and `report_every`
        device_id (int): index into `args.gpu_ranks`; a negative value
            means training on CPU
        model (:obj:`onmt.models.NMTModel`): the model to train
        optim (:obj:`onmt.utils.Optimizer`): optimizer used during training

    Returns:
        Trainer: a trainer whose report manager logs to TensorBoard
        under `args.model_path`
    """
    # Note: `device` is computed here but not used further in this function.
    device = "cpu" if args.visible_gpus == '-1' else "cuda"

    grad_accum_count = args.accum_count
    n_gpu = args.world_size

    if device_id >= 0:
        # GPU training: look up this process's rank in `args.gpu_ranks`.
        gpu_rank = int(args.gpu_ranks[device_id])
    else:
        # CPU training: a single process with no GPUs.
        gpu_rank = 0
        n_gpu = 0

    logger.info('gpu_rank %d', gpu_rank)

    # TensorBoard event files are written to `args.model_path`.
    tensorboard_log_dir = args.model_path
    writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")

    report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer)

    trainer = Trainer(args, model, optim, grad_accum_count, n_gpu, gpu_rank, report_manager)

    if model:
        n_params = _tally_parameters(model)
        logger.info('* number of parameters: %d', n_params)

    return trainer


class Trainer(object):