(prompt_model, dataloader)
| 98 | from openprompt.utils.metrics import generation_metric |
| 99 | # Define evaluate function |
def evaluate(prompt_model, dataloader):
    """Generate text for each batch of ``dataloader`` and print a sentence-BLEU score.

    Switches ``prompt_model`` to eval mode, calls ``prompt_model.generate`` on
    every batch, and scores the collected outputs against each batch's
    ``'tgt_text'`` entries via ``generation_metric(..., "sentence_bleu")``.
    The score is printed with the label ``test_score``.

    Depends on the module-level names ``use_cuda`` and
    ``generation_arguments`` being defined before this function is called.

    Args:
        prompt_model: Object exposing ``eval()`` and ``generate(inputs, **kwargs)``.
        dataloader: Iterable of batches; each batch supports ``.cuda()`` and
            lookup by the key ``'tgt_text'``.

    Returns:
        The list of generated sentences accumulated over all batches.
    """
    prompt_model.eval()

    predictions = []
    references = []

    for batch in dataloader:
        if use_cuda:
            batch = batch.cuda()
        # generate() yields a pair; only the second element (the output
        # sentences) is kept.
        _unused, decoded = prompt_model.generate(batch, **generation_arguments)
        predictions += decoded
        references += batch['tgt_text']

    bleu = generation_metric(predictions, references, "sentence_bleu")
    print("test_score", bleu, flush=True)
    return predictions
| 114 | |
| 115 | |
| 116 |
no test coverage detected