(self, batch)
| 23 | return StereoSetDataset(join(self.config.path, relative_path), self.config) |
| 24 | |
def predict_single_batch(self, batch) -> List[int]:
    """Return, for each sample, the index of its best length-normalized choice.

    ``self.model.cond_log_prob(batch)`` is called once. Its per-sample score
    lists are paired (via ``zip``) with ``batch.get("choices")``. Each score is
    divided by ``len()`` of the matching choice. The argmax of each normalized
    row is then returned as a plain Python int.

    Args:
        batch: mapping with a ``"choices"`` entry holding, per sample, a
            sequence of choices that support ``len()``. It is also forwarded
            unchanged to ``cond_log_prob``.

    Returns:
        One int per sample: the position of the highest normalized score.
        On ties, the first maximal position is returned (``np.argmax``).
    """
    scores = self.model.cond_log_prob(batch)
    # NOTE(review): presumably cond_log_prob returns a summed log-probability
    # per choice, so dividing by the choice length keeps longer choices from
    # being favoured/penalised just for their length -- confirm with the model.
    # Every row is normalized before any argmax is taken (same order as before).
    normalized = [
        [score / len(choice) for choice, score in zip(choices, choice_scores)]
        for choices, choice_scores in zip(batch.get("choices"), scores)
    ]
    return [np.argmax(row).item() for row in normalized]
| 34 | |
| 35 | def report_group_metrics(self, group_name, result_dict_group: Dict[str, Tuple[Dict[str, float], int]], level=1): |
| 36 | for tmp1 in result_dict_group.values(): |
nothing calls this directly
no test coverage detected