MCP · copy
hub / github.com/jindongwang/transferlearning / train_maml_epoch

Function train_maml_epoch

code/ASR/Adapter/train.py:129–203  ·  view source on GitHub ↗
Signature: (dataloader, model, optimizer, epoch=None)

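Source retrieved from the content-addressed store (hash-verified). Note: the excerpt below is truncated. It ends at line 186, so the remainder of the function body (lines 187–203) is not shown.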

127 return dict_average(stats)
128
129def train_maml_epoch(dataloader, model, optimizer, epoch=None):
130 model.train()
131 stats = collections.defaultdict(list)
132
133 for batch_idx, total_batches in enumerate(dataloader):
134 i = batch_idx # current iteration in epoch
135 len_dataloader = len(dataloader) # total iteration in epoch
136 meta_iters = args.epochs * len_dataloader
137 current_iter = float(i + (epoch - 1) * len_dataloader)
138 frac_done = 1.0 * float(current_iter) / meta_iters
139 current_outerstepsize = args.meta_lr * (1. - frac_done)
140
141 weights_original = copy.deepcopy(model.state_dict())
142 new_weights = []
143 for total_batch in total_batches: # Iter by languages
144 in_batch_size = int(total_batch[0].shape[0] / 2) # In-language batch size
145 for meta_step in range(2): # Meta-train & meta-valid
146 if meta_step == 1:
147 last_backup = copy.deepcopy(model.state_dict())
148 else:
149 last_backup = None
150 batch = list(copy.deepcopy(total_batch))
151 for i_batch in range(len(batch)-1):
152 batch[i_batch] = batch[i_batch][meta_step*in_batch_size:(1+meta_step)*in_batch_size]
153 batch = tuple(batch)
154
155 fbank, seq_lens, tokens, language = batch
156 fbank, seq_lens, tokens = fbank.cuda(), seq_lens.cuda(), tokens.cuda()
157 optimizer.zero_grad()
158 model.zero_grad()
159 if args.ngpu <= 1 or args.dist_train:
160 loss = model(fbank, seq_lens, tokens, language).mean() # / self.accum_grad
161 else:
162 # apex does not support torch.nn.DataParallel
163 loss = (
164 data_parallel(model, (fbank, seq_lens, tokens, language), range(args.ngpu)).mean() # / self.accum_grad
165 )
166 # print(loss.item())
167 loss.backward()
168 grad_norm = clip_grad_norm_(model.parameters(), args.grad_clip)
169 if math.isnan(grad_norm):
170 logging.warning("grad norm is nan. Do not update model.")
171 else:
172 optimizer.step()
173
174 if meta_step == 1: # Record meta valid
175 if not hasattr(model, "module"):
176 if hasattr(model, "acc") and model.acc is not None:
177 stats["acc_lst"].append(model.acc)
178 model.acc = None
179 else:
180 if hasattr(model, "acc") and model.module.acc is not None:
181 stats["acc_lst"].append(model.module.acc)
182 model.module.acc = None
183 stats["loss_lst"].append(loss.item())
184 stats["meta_lr"] = current_outerstepsize
185 optimizer.zero_grad()
186

Callers

Nothing calls this function directly.

Calls (8)

dict_average — Function · 0.90
step — Method · 0.80
train — Method · 0.45
state_dict — Method · 0.45
mean — Method · 0.45
backward — Method · 0.45
parameters — Method · 0.45
load_state_dict — Method · 0.45

Tested by

No test coverage detected.