(self)
| 14 | |
class TestModuleMetrics(unittest.TestCase):
    """Unit tests for the metric classes exposed by ``metrics``."""

    def test_caption_bleu4(self):
        """Check ``CaptionBleu4Metric.calculate`` on two hand-built inputs.

        A ``CaptionProcessor`` is built from the COCO captioning config, with
        its vocab file pointed at the test ``data/vocab.txt``, and registered
        under ``"coco_caption_processor"``. The metric is then evaluated twice
        against reference answers filled with the value 4:

        * scores select index 4 at every position -> expected exactly 1.0
        * scores select index 4 only for positions 0-4 of 10 -> expected
          0.3928 to four decimal places
        """
        this_file = os.path.abspath(__file__)

        # Both paths are anchored on this file's own path; the leading ".."
        # segments climb out of the file name and its parent directories once
        # the joined path is normalised by abspath().
        config_path = os.path.abspath(
            os.path.join(
                this_file,
                "../../../pythia/common/defaults/configs/tasks/captioning/coco.yml",
            )
        )
        vocab_file = os.path.abspath(
            os.path.join(this_file, "..", "..", "data", "vocab.txt")
        )

        with open(config_path) as config_file:
            raw_config = yaml.load(config_file, Loader=yaml.FullLoader)

        coco_config = ConfigNode(
            raw_config
        ).task_attributes.captioning.dataset_attributes.coco
        processor_config = coco_config.processors.caption_processor
        processor_config.params.vocab.vocab_file = vocab_file
        registry.register(
            "coco_caption_processor", CaptionProcessor(processor_config.params)
        )

        metric = metrics.CaptionBleu4Metric()
        reference = Sample()
        prediction = {}

        # Complete match: index 4 carries the only non-zero score everywhere.
        reference.answers = torch.empty((5, 5, 10)).fill_(4)
        full_scores = torch.zeros((5, 10, 19))
        full_scores[:, :, 4] = 1.0
        prediction["scores"] = full_scores

        full_value = metric.calculate(reference, prediction).item()
        self.assertEqual(full_value, 1.0)

        # Partial match: index 4 is scored only for the first five positions
        # along the second axis; the remaining positions keep all-zero scores.
        reference.answers = torch.empty((5, 5, 10)).fill_(4)
        partial_scores = torch.zeros((5, 10, 19))
        partial_scores[:, :5, 4] = 1.0
        prediction["scores"] = partial_scores

        partial_value = metric.calculate(reference, prediction).item()
        self.assertAlmostEqual(partial_value, 0.3928, places=4)
nothing calls this directly
no test coverage detected