MCP · Copy · Index your code
hub / github.com/THUDM/GLM / evaluate

Method evaluate

tasks/seq2seq/evaluate.py:267–390  ·  view source on GitHub ↗

Calculate the ratio of correct answers to total answers, and return the predictions if `output_predictions` is true.

(self, model, dataloader, example_dict, args)

Source from the content-addressed store, hash-verified

265 self.processors.append(processor)
266
267 def evaluate(self, model, dataloader, example_dict, args):
268 """Calculate correct over total answers and return prediction if the
269 `output_predictions` is true."""
270 model.eval()
271 local_predictions = {}
272 print_rank_0("Distributed store created")
273 with torch.no_grad():
274 # For all the batches in the dataset.
275 for idx, data in enumerate(dataloader):
276 tokens, attention_mask, position_ids = process_batch(data, args)
277 batch_size = tokens.size(0)
278 beam_scorer = BeamSearchScorer(
279 batch_size=batch_size,
280 max_length=args.out_seq_length,
281 num_beams=args.num_beams,
282 device=tokens.device,
283 length_penalty=args.length_penalty,
284 do_early_stopping=False,
285 )
286 beam_scores = torch.zeros((batch_size, args.num_beams), dtype=torch.float, device=tokens.device)
287 beam_scores[:, 1:] = -1e9
288 beam_scores = beam_scores.view((batch_size * args.num_beams,))
289 # Run the model forward.
290 counter = 0
291 context_length = tokens.size(1)
292 while counter < args.tgt_seq_length:
293 if counter == 0:
294 next_token_logits, *mems = model(tokens, position_ids, attention_mask, return_memory=True)
295 seq_length = next_token_logits.size(1)
296 next_token_logits = next_token_logits[:, -1]
297 next_token_logits = next_token_logits.unsqueeze(1).repeat(1, args.num_beams, 1).view(
298 batch_size * args.num_beams, -1)
299 mems = [mem.unsqueeze(1).repeat(1, args.num_beams, 1, 1).view(batch_size * args.num_beams,
300 seq_length, -1) for mem in mems]
301 position_ids = tokens.new_ones(batch_size, args.num_beams, 2, 1)
302 for i, text in enumerate(tokens.tolist()):
303 mask_pos = text.index(self.mask_token)
304 position_ids[i, :, 0] = mask_pos
305 position_ids = position_ids.reshape(batch_size * args.num_beams, 2, 1)
306 tokens = tokens.new_zeros(batch_size * args.num_beams, 0)
307 else:
308 if not args.no_block_position:
309 position_ids[:, 1] = counter + 1
310 last_token = tokens[:, -1:]
311 if self.mask_pad_token:
312 cur_attention_mask = attention_mask[:, :, -1:, :].unsqueeze(1).expand(-1, args.num_beams, -1,
313 -1, -1).reshape(
314 batch_size * args.num_beams, 1, 1, context_length)
315 cur_attention_mask = torch.cat(
316 (cur_attention_mask, attention_mask.new_ones((batch_size * args.num_beams, 1, 1, counter))),
317 dim=-1)
318 else:
319 cur_attention_mask = tokens.new_zeros([batch_size * args.num_beams])
320 next_token_logits, *mems = model(last_token, position_ids, cur_attention_mask, *mems,
321 return_memory=True)
322 next_token_logits = next_token_logits[:, -1]
323 next_token_logits = top_k_logits(next_token_logits, top_k=args.top_k, top_p=args.top_p)
324 next_token_scores = F.log_softmax(next_token_logits, dim=-1)

Callers

nothing calls this directly

Calls 9

process · Method · 0.95
finalize · Method · 0.95
print_rank_0 · Function · 0.90
BeamSearchScorer · Class · 0.90
top_k_logits · Function · 0.90
squad_decode · Function · 0.85
append · Method · 0.80
process_batch · Function · 0.70
DecodeIds · Method · 0.45

Tested by

no test coverage detected