(args)
| 24 | |
@paddle.no_grad()
def eval(args):
    """Evaluate a pre-trained ELMo checkpoint on the dev set and print loss / perplexity.

    Builds an ELMo model and loads ``<args.init_from_ckpt>.pdparams`` into it.
    It then runs the forward pass and ``ELMoLoss`` over every batch of the dev
    dataset. Running statistics are printed every ``args.log_freq`` steps, and
    a final average is printed at the end.

    Args:
        args: Parsed command-line namespace. Attributes read here: ``device``,
            ``init_from_ckpt``, ``vocab_file``, ``max_characters_per_token``,
            ``batch_size``, ``char_embed_dim``, ``projection_dim``, ``dropout``,
            ``num_layers``, ``num_highways``, ``dev_data_path``,
            ``unroll_steps``, ``seed`` and ``log_freq``.

    Raises:
        ValueError: If ``args.init_from_ckpt`` is not set, or if the dev
            dataloader yields no batches. Without the second check, the final
            average would fail with an opaque ``ZeroDivisionError``.
    """
    paddle.set_device(args.device)

    if not args.init_from_ckpt:
        raise ValueError("init_from_ckpt should be set when eval.")
    vocab = load_vocab(args.vocab_file, args.max_characters_per_token)

    elmo = ELMo(
        args.batch_size,
        args.char_embed_dim,
        args.projection_dim,
        vocab.size,
        dropout=args.dropout,
        num_layers=args.num_layers,
        num_highways=args.num_highways,
        char_vocab_size=vocab.char_size,
    )
    # Switch to inference mode (e.g. disables dropout) before evaluating.
    elmo.eval()

    elmo_loss = ELMoLoss()

    # Loads pre-trained parameters.
    weight_state_dict = paddle.load(args.init_from_ckpt + ".pdparams")
    elmo.set_state_dict(weight_state_dict)
    print("Loaded checkpoint from %s" % args.init_from_ckpt)

    dev_dataset = OneBillionWordDataset(
        args.dev_data_path, vocab, args.batch_size, args.unroll_steps, mode="test", shuffle=False, seed=args.seed
    )

    # batch_size=None: the dataset is given args.batch_size itself and presumably
    # yields ready-made batches, so the DataLoader must not batch again — verify.
    dev_dataloader = DataLoader(dev_dataset, return_list=True, batch_size=None)

    total_step = 0
    total_loss = 0.0
    total_time = 0.0
    batch_start_time = time.time()
    for step, inputs in enumerate(dev_dataloader, start=1):
        # Forward and backward (reversed) token ids plus their next-token targets.
        ids, next_ids, ids_reverse, next_ids_reverse = inputs
        outputs = elmo([ids, ids_reverse])
        loss = elmo_loss(outputs, [next_ids, next_ids_reverse])
        loss_value = float(loss)

        total_loss += loss_value
        total_step += 1

        # Measured from the end of the previous iteration, so this includes data-loading time.
        total_time += time.time() - batch_start_time
        if step % args.log_freq == 0:
            # Only needed for the log line, so computed only on logging steps.
            ppl = paddle.exp(loss)
            # NOTE(review): the loss is scaled by unroll_steps for display, which suggests
            # ELMoLoss averages over the unrolled time steps — confirm against ELMoLoss.
            print(
                "Eval step %d - loss: %.4f - Perplexity: %.4f - %.3fs/step"
                % (step, loss_value * args.unroll_steps, float(ppl), total_time / args.log_freq)
            )
            total_time = 0.0
        batch_start_time = time.time()

    if total_step == 0:
        # Guard against an empty dev set instead of dividing by zero below.
        raise ValueError("No evaluation batches were produced from %s" % args.dev_data_path)

    avg_loss = total_loss / total_step
    avg_ppl = math.exp(avg_loss)
    print("Eval - average loss: %.4f - average Perplexity: %.4f" % (avg_loss * args.unroll_steps, avg_ppl))
| 82 | |
| 83 |
no test coverage detected
searching dependent graphs…