Repository: github.com/Turing-Project/WriteGPT

Function build_trainer

LanguageNetwork/BERT/models/trainer.py:21–60

Simplify `Trainer` creation based on user options. Args: args (`Namespace`): user options (usually from argument parsing); device_id (int): index into `args.gpu_ranks`, negative for a CPU run; model: the model to train; optim: the optimizer used during training.

(args, device_id, model, optim)
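The function body reads six attributes directly from `args`: `visible_gpus`, `accum_count`, `world_size`, `gpu_ranks`, `model_path` and `report_every`. The sketch below checks a `Namespace` for those names before calling `build_trainer`. The helper `missing_trainer_args` and the example values are hypothetical, and `Trainer(args, ...)` may read further attributes that this list does not cover.

```python
from argparse import Namespace

# Attributes read directly in build_trainer's body.
# Trainer(args, ...) may require more; this list is not exhaustive.
BUILD_TRAINER_ARGS = (
    "visible_gpus",  # '-1' selects CPU, anything else CUDA
    "accum_count",   # gradient accumulation count passed to Trainer
    "world_size",    # number of GPUs used when device_id >= 0
    "gpu_ranks",     # indexed by device_id to obtain gpu_rank
    "model_path",    # used as the TensorBoard log directory
    "report_every",  # reporting interval passed to ReportMgr
)


def missing_trainer_args(args: Namespace) -> list:
    """Hypothetical helper: names build_trainer would fail to find on args."""
    return [name for name in BUILD_TRAINER_ARGS if not hasattr(args, name)]


# Illustrative options for a single-GPU run (values are made up).
example_args = Namespace(
    visible_gpus="0",
    accum_count=1,
    world_size=1,
    gpu_ranks=[0],
    model_path="models/",
    report_every=50,
)
```

Running the check before construction turns a late `AttributeError` inside `build_trainer` into an explicit list of missing options.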


```python
# SummaryWriter, ReportMgr, Trainer, logger and _tally_parameters are
# defined or imported at module level in trainer.py (not shown here).
def build_trainer(args, device_id, model, optim):
    """
    Simplify `Trainer` creation based on user options.

    Args:
        args (:obj:`Namespace`): user options (usually from argument parsing)
        device_id (int): index into `args.gpu_ranks`; a negative value means CPU
        model: the model to train
        optim: optimizer used during training
    """
    # NOTE: `device` is computed here but not used later in this function.
    device = "cpu" if args.visible_gpus == '-1' else "cuda"

    grad_accum_count = args.accum_count
    n_gpu = args.world_size

    if device_id >= 0:
        gpu_rank = int(args.gpu_ranks[device_id])
    else:
        # CPU run: rank 0 and no GPUs.
        gpu_rank = 0
        n_gpu = 0

    print('gpu_rank %d' % gpu_rank)

    # TensorBoard logs are written to args.model_path.
    tensorboard_log_dir = args.model_path

    # `comment` only suffixes the default log dir; it has no effect when
    # a log dir is passed explicitly, as it is here.
    writer = SummaryWriter(tensorboard_log_dir, comment="Unmt")

    report_manager = ReportMgr(args.report_every, start_time=-1, tensorboard_writer=writer)

    trainer = Trainer(args, model, optim, grad_accum_count, n_gpu, gpu_rank, report_manager)

    if model:
        n_params = _tally_parameters(model)
        logger.info('* number of parameters: %d' % n_params)

    return trainer
```
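The device and rank selection at the top of `build_trainer` can be isolated into a small pure function, which makes the CPU and multi-GPU cases easy to check. `resolve_rank` is a hypothetical helper that mirrors the logic shown above; it is not part of the repository.

```python
from argparse import Namespace
from typing import Tuple


def resolve_rank(args: Namespace, device_id: int) -> Tuple[str, int, int]:
    """Hypothetical mirror of build_trainer's device/rank selection.

    Returns (device, gpu_rank, n_gpu).
    """
    # visible_gpus == '-1' means "no GPUs"; any other value selects CUDA.
    device = "cpu" if args.visible_gpus == "-1" else "cuda"
    n_gpu = args.world_size
    if device_id >= 0:
        # device_id indexes gpu_ranks; entries may be strings, hence int().
        gpu_rank = int(args.gpu_ranks[device_id])
    else:
        # A negative device_id is a CPU run: rank 0 and zero GPUs.
        gpu_rank = 0
        n_gpu = 0
    return device, gpu_rank, n_gpu


cpu_args = Namespace(visible_gpus="-1", world_size=1, gpu_ranks=[0])
print(resolve_rank(cpu_args, -1))  # ('cpu', 0, 0)

gpu_args = Namespace(visible_gpus="0,1", world_size=2, gpu_ranks=["0", "1"])
print(resolve_rank(gpu_args, 1))  # ('cuda', 1, 2)
```

In the original function only `gpu_rank` and `n_gpu` reach `Trainer`; `device` is computed but never used.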

Callers (6)

predict (method) · 0.90
baseline (method) · 0.90
train (method) · 0.90
validate (method) · 0.90
test (method) · 0.90
gen_features_vector (method) · 0.90

Calls (3)

ReportMgr (class) · 0.90
Trainer (class) · 0.85
_tally_parameters (function) · 0.85
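The body of `_tally_parameters` is not shown on this page. A common way to count a PyTorch model's parameters is to sum `numel()` over `model.parameters()`; the sketch below assumes that is what it does. So that it runs without PyTorch, it uses duck typing, and `FakeParam` and `FakeModel` are hypothetical stand-ins for a tensor and an `nn.Module`.

```python
class FakeParam:
    """Stand-in for a tensor: exposes numel() like torch.Tensor.numel()."""

    def __init__(self, *shape: int) -> None:
        self.shape = shape

    def numel(self) -> int:
        n = 1
        for dim in self.shape:
            n *= dim
        return n


class FakeModel:
    """Stand-in for an nn.Module: parameters() yields parameter objects."""

    def __init__(self, params) -> None:
        self._params = list(params)

    def parameters(self):
        return iter(self._params)


def tally_parameters(model) -> int:
    # Assumed behaviour: total count of scalar values across all parameters.
    return sum(p.numel() for p in model.parameters())


# A 768x768 weight plus a 768-element bias: 589824 + 768 values.
print(tally_parameters(FakeModel([FakeParam(768, 768), FakeParam(768)])))  # 590592
```

With a real `torch.nn.Module`, the same expression `sum(p.numel() for p in model.parameters())` works unchanged.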

Tested by

no test coverage detected