(temp_dir, cand, ref)
| 52 | |
| 53 | |
def test_rouge(temp_dir, cand, ref):
    """Score candidate summaries against reference summaries with pyrouge.

    Line *i* of ``cand`` is the candidate for line *i* of ``ref``; both are
    stripped of surrounding whitespace. Pairs whose stripped reference is
    empty are skipped. The remaining pairs are written one file per summary
    into a fresh scratch directory under ``temp_dir`` and scored with
    ``pyrouge.Rouge155``. The scratch directory is always removed afterwards,
    including when scoring raises.

    Args:
        temp_dir: Existing directory used for scratch files; it is also passed
            to ``pyrouge.Rouge155`` as that object's ``temp_dir``.
        cand: Path to a UTF-8 text file of candidate summaries, one per line.
        ref: Path to a UTF-8 text file of reference summaries, one per line.

    Returns:
        Whatever ``Rouge155.output_to_dict`` returns for the raw ROUGE output.

    Raises:
        AssertionError: If the two files contain a different number of lines.
    """
    import tempfile  # local import: only this function needs it

    # Read with context managers so the file handles are closed promptly.
    with open(cand, encoding='utf-8') as f:
        candidates = [line.strip() for line in f]
    with open(ref, encoding='utf-8') as f:
        references = [line.strip() for line in f]
    print(len(candidates))
    print(len(references))
    assert len(candidates) == len(references), (
        "candidate/reference line counts differ: {} vs {}".format(
            len(candidates), len(references)))

    current_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
    # mkdtemp always creates a brand-new, uniquely named directory. A
    # timestamp-only name collides when called twice in the same second, and
    # then creating the sub-directories raises FileExistsError.
    tmp_dir = tempfile.mkdtemp(
        prefix="rouge-tmp-{}-".format(current_time), dir=temp_dir)
    try:
        cand_dir = os.path.join(tmp_dir, "candidate")
        ref_dir = os.path.join(tmp_dir, "reference")
        os.mkdir(cand_dir)
        os.mkdir(ref_dir)

        for i, (candidate, reference) in enumerate(zip(candidates, references)):
            if not reference:
                # Pairs with an empty reference are skipped, as before.
                continue
            with open(os.path.join(cand_dir, "cand.{}.txt".format(i)), "w",
                      encoding="utf-8") as f:
                f.write(candidate)
            with open(os.path.join(ref_dir, "ref.{}.txt".format(i)), "w",
                      encoding="utf-8") as f:
                f.write(reference)

        r = pyrouge.Rouge155(temp_dir=temp_dir)
        # Trailing separator kept to match the original directory strings.
        r.model_dir = os.path.join(ref_dir, "")
        r.system_dir = os.path.join(cand_dir, "")
        # The "#ID#" in the model pattern pairs each reference with the
        # candidate whose captured (\d+) group has the same index.
        r.model_filename_pattern = 'ref.#ID#.txt'
        r.system_filename_pattern = r'cand.(\d+).txt'
        rouge_results = r.convert_and_evaluate()
        print(rouge_results)
        results_dict = r.output_to_dict(rouge_results)
    finally:
        # Cleanup belongs in `finally` so a failure above does not leak the
        # scratch directory. It is best-effort so that an rmtree error cannot
        # mask the original exception.
        if os.path.isdir(tmp_dir):
            shutil.rmtree(tmp_dir, ignore_errors=True)
    return results_dict
| 92 | |
| 93 | |
| 94 | def rouge_results_to_str(results_dict): |
no test coverage detected