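Function eval

github.com/PaddlePaddle/PaddleNLP · examples/language_model/elmo/run_eval.py:26–81
Signature: eval(args)

Evaluates a trained ELMo checkpoint on the One Billion Word dev set without gradient tracking. It logs loss, perplexity, and time per step every `log_freq` steps, then prints the average loss and perplexity over the whole dev set.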
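`eval` takes a single `args` object and reads every setting from its attributes. The example script presumably builds this object with its own argument parser, which is outside this excerpt. The sketch below lists the attributes the function body accesses and checks an `argparse.Namespace` for missing ones before `eval` is called. `REQUIRED_EVAL_ARGS`, `missing_eval_args`, and every value in `example_args` are hypothetical placeholders, not PaddleNLP defaults.

```python
from argparse import Namespace

# Attributes read by eval(args), grouped by where the function uses them.
# The names come from the function body; the grouping comments are descriptive only.
REQUIRED_EVAL_ARGS = (
    "device", "init_from_ckpt",                        # paddle.set_device + checkpoint prefix
    "vocab_file", "max_characters_per_token",          # load_vocab(...)
    "batch_size", "char_embed_dim", "projection_dim",  # ELMo(...) positional args (batch_size also feeds the dataset)
    "dropout", "num_layers", "num_highways",           # ELMo(...) keyword args
    "dev_data_path", "unroll_steps", "seed",           # OneBillionWordDataset(...)
    "log_freq",                                        # progress logging interval
)


def missing_eval_args(args):
    """Return the attribute names eval(args) needs but `args` does not define."""
    return [name for name in REQUIRED_EVAL_ARGS if not hasattr(args, name)]


# Hypothetical placeholder values, for illustration only.
example_args = Namespace(
    device="cpu",
    init_from_ckpt="./checkpoints/final",  # eval() appends ".pdparams" to this prefix
    vocab_file="./vocab.txt",
    max_characters_per_token=50,
    batch_size=128,
    char_embed_dim=16,
    projection_dim=512,
    dropout=0.1,
    num_layers=2,
    num_highways=2,
    dev_data_path="./dev/*",
    unroll_steps=20,
    seed=2020,
    log_freq=100,
)

print(missing_eval_args(example_args))  # []
print(missing_eval_args(Namespace(device="cpu")))  # every name except "device"
```

Checking up front turns a mid-run `AttributeError` (for example on `args.log_freq`, which is only read inside the loop) into an immediate, readable error.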

24
25@paddle.no_grad()
26def eval(args):
27 paddle.set_device(args.device)
28
29 if not args.init_from_ckpt:
30 raise ValueError("init_from_ckpt should be set when eval.")
31 vocab = load_vocab(args.vocab_file, args.max_characters_per_token)
32
33 elmo = ELMo(
34 args.batch_size,
35 args.char_embed_dim,
36 args.projection_dim,
37 vocab.size,
38 dropout=args.dropout,
39 num_layers=args.num_layers,
40 num_highways=args.num_highways,
41 char_vocab_size=vocab.char_size,
42 )
43 elmo.eval()
44
45 elmo_loss = ELMoLoss()
46
47 # Loads pre-trained parameters.
48 weight_state_dict = paddle.load(args.init_from_ckpt + ".pdparams")
49 elmo.set_state_dict(weight_state_dict)
50 print("Loaded checkpoint from %s" % args.init_from_ckpt)
51
52 dev_dataset = OneBillionWordDataset(
53 args.dev_data_path, vocab, args.batch_size, args.unroll_steps, mode="test", shuffle=False, seed=args.seed
54 )
55
56 dev_dataloader = DataLoader(dev_dataset, return_list=True, batch_size=None)
57
58 total_step = total_loss = 0
59 total_time = 0.0
60 batch_start_time = time.time()
61 for step, inputs in enumerate(dev_dataloader, start=1):
62 ids, next_ids, ids_reverse, next_ids_reverse = inputs
63 outputs = elmo([ids, ids_reverse])
64 loss = elmo_loss(outputs, [next_ids, next_ids_reverse])
65 ppl = paddle.exp(loss)
66
67 total_loss += float(loss)
68 total_step += 1
69
70 total_time += time.time() - batch_start_time
71 if step % args.log_freq == 0:
72 print(
73 "Eval step %d - loss: %.4f - Perplexity: %.4f - %.3fs/step"
74 % (step, float(loss) * args.unroll_steps, float(ppl), total_time / args.log_freq)
75 )
76 total_time = 0.0
77 batch_start_time = time.time()
78
79 avg_loss = total_loss / total_step
80 avg_ppl = math.exp(avg_loss)
81 print("Eval - average loss: %.4f - average Perplexity: %.4f" % (avg_loss * args.unroll_steps, avg_ppl))
82
83
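The final report averages the per-batch losses over batches and takes `exp` of that average as perplexity. Only the printed loss is multiplied by `unroll_steps`. The sketch below reproduces that arithmetic with plain floats standing in for Paddle tensors. It assumes, as the scaling suggests but this excerpt does not confirm, that `ELMoLoss` returns a per-token average negative log-likelihood. `summarize_eval` is a hypothetical helper, not part of PaddleNLP.

```python
import math


def summarize_eval(batch_losses, unroll_steps):
    """Mirror eval()'s final report from a list of per-batch losses.

    Assumes each entry is a per-token average loss. Returns
    (reported_loss, perplexity) exactly as eval() computes them.
    """
    if not batch_losses:
        raise ValueError("no batches to summarize")
    avg_loss = sum(batch_losses) / len(batch_losses)  # total_loss / total_step
    avg_ppl = math.exp(avg_loss)                      # perplexity uses the unscaled average
    return avg_loss * unroll_steps, avg_ppl           # printed loss is scaled by unroll_steps


# Two batches whose individual perplexities are 100 and 400.
losses = [math.log(100.0), math.log(400.0)]
reported_loss, ppl = summarize_eval(losses, unroll_steps=20)
print("average loss: %.4f - average Perplexity: %.4f" % (reported_loss, ppl))
```

Because the average is taken in log space, the reported perplexity is the geometric mean of the per-batch perplexities (200 here, not the arithmetic mean 250). Averaging over batches rather than tokens matches a token-level perplexity only when every batch holds the same number of target tokens.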

Callers (15)

run_eval.py (File) · 0.70
create_hcg (Function) · 0.50
run_system.py (File) · 0.50
rpc_client.py (File) · 0.50
recall_result (Function) · 0.50
rpc_client.py (File) · 0.50
Run (Method) · 0.50
Run (Method) · 0.50
preprocess (Method) · 0.50
run_system.py (File) · 0.50
rpc_client.py (File) · 0.50
preprocess (Method) · 0.50

Calls (8)

load_vocab (Function) · 0.90
ELMo (Class) · 0.90
ELMoLoss (Class) · 0.90
DataLoader (Class) · 0.85
eval (Method) · 0.45
load (Method) · 0.45
set_state_dict (Method) · 0.45

Tested by

no test coverage detected
