(gts, res)
| 25 | |
| 26 | |
def compute_caption(gts, res):
    """Score predicted captions against references with COCO caption metrics.

    Both inputs are tokenized with ``CocoTokenizer`` and then scored by BLEU-1..4,
    METEOR, ROUGE-L, CIDEr and SPICE. SPIDEr is added as the mean of CIDEr and
    SPICE. Each metric is printed as it is computed.

    Args:
        gts: Reference captions, handed to ``CocoTokenizer`` as its second
            (references) argument. Presumably a dict keyed by sample id;
            TODO confirm against ``CocoTokenizer``.
        res: Predicted captions, handed to ``CocoTokenizer`` as its first
            (predictions) argument. Same assumption as ``gts``.

    Returns:
        dict mapping metric name ("Bleu_1".."Bleu_4", "METEOR", "ROUGE_L",
        "CIDEr", "SPICE", "SPIDEr") to the corpus-level score returned by
        each scorer.
    """
    tokenizer = CocoTokenizer(res, gts)
    # The tokenizer returns (predictions, references) in that order.
    res, gts = tokenizer.tokenize()

    # Each entry pairs a scorer with the key(s) its corpus-level score is
    # stored under. Bleu(4) yields one score per n-gram order, hence a list.
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Meteor(), "METEOR"),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
        (Spice(), "SPICE"),
    ]
    f_res = {}
    for scorer, method in scorers:
        print('computing %s score...' % (scorer.method()))

        # The second return value (per-sample scores) is not needed here.
        score, _ = scorer.compute_score(gts, res)
        if isinstance(method, list):
            for sc, m in zip(score, method):
                print("%s: %0.3f" % (m, sc))
                f_res[m] = sc
        else:
            print("%s: %0.3f" % (method, score))
            f_res[method] = score
    # SPIDEr is computed here as the arithmetic mean of CIDEr and SPICE.
    f_res["SPIDEr"] = (f_res['CIDEr'] + f_res['SPICE']) / 2.
    return f_res
| 54 | |
| 55 | |
| 56 | class AudioDataset(torch.utils.data.Dataset): |
no test coverage detected