(char_center, post_result, preds)
| 819 | |
| 820 | |
def update_center(char_center, post_result, preds):
    """Fold per-timestep features of correctly recognised samples into running class centers.

    For every sample whose post-processed result matches its label, each
    timestep's feature vector is averaged into the center of the class index
    predicted (via ``argmax`` of the logits) at that timestep.

    Args:
        char_center: dict mapping a predicted class index to a two-element list
            ``[mean_feature, count]``. It is updated in place and also returned.
        post_result: pair ``(result, label)``. Sample ``i`` contributes only
            when ``result[i][0] == label[i][0]``. Presumably element 0 is the
            decoded text of the prediction / ground truth; confirm against the
            post-processor.
        preds: pair ``(feats, logits)`` of paddle tensors. ``logits`` is reduced
            with ``argmax`` over its last axis, and ``feats[i][t]`` is folded
            into the center of class ``argmax(logits)[i][t]``. Presumably
            feats is (batch, time, dim) and logits is (batch, time, classes);
            TODO confirm.

    Returns:
        The same ``char_center`` dict, updated.
    """
    result, label = post_result
    feats, logits = preds
    # Class index predicted at every timestep.
    logits = paddle.argmax(logits, axis=-1)
    feats = feats.numpy()
    logits = logits.numpy()

    for idx_sample in range(len(label)):
        # Only samples whose prediction matches the label contribute.
        if result[idx_sample][0] != label[idx_sample][0]:
            continue
        feat = feats[idx_sample]
        logit = logits[idx_sample]
        for idx_time, index in enumerate(logit):
            if index in char_center:
                entry = char_center[index]
                # Incremental mean: new_mean = (mean * n + x) / (n + 1).
                entry[0] = (entry[0] * entry[1] + feat[idx_time]) / (entry[1] + 1)
                entry[1] += 1
            else:
                # feat[idx_time] is a view into the whole ``feats`` batch array.
                # Copy it so the long-lived dict does not keep that buffer alive.
                char_center[index] = [feat[idx_time].copy(), 1]
    return char_center
| 842 | |
| 843 | |
| 844 | def get_center(model, eval_dataloader, post_process_class): |
no outgoing calls
no test coverage detected
searching dependent graphs…