(y_pred: Tensor, y_true: Tensor, **kwargs)
| 132 | |
| 133 | |
def confusion_matrix(y_pred: Tensor, y_true: Tensor, **kwargs):
    """Compute a confusion matrix from prediction scores and true labels.

    The predicted label for each sample is the index of the largest value of
    ``y_pred`` along axis 1. Both tensors are detached, copied to host
    memory, and converted to NumPy arrays before being handed to
    ``sklearn_confusion_matrix``.

    Args:
        y_pred: Tensor of scores; reduced with ``argmax`` over axis 1.
            NOTE(review): presumably shaped (samples, classes) -- confirm
            against callers.
        y_true: Tensor of reference labels, passed through unchanged apart
            from the NumPy conversion.
        **kwargs: Only the ``"labels"`` key is read; its value (``None`` when
            the key is absent) is forwarded as the ``labels`` argument.

    Returns:
        The result of ``sklearn_confusion_matrix``, or ``None`` when any
        exception is raised along the way (the exception is logged through
        ``logger.error`` and not re-raised).
    """
    try:
        scores = y_pred.detach().cpu().numpy()
        predicted = np.argmax(scores, axis=1)
        label_order = kwargs.get("labels")
        return sklearn_confusion_matrix(
            y_true.detach().cpu().numpy(),
            predicted,
            labels=label_order,
        )
    except Exception as err:
        # Best-effort metric: report the failure and fall through to None
        # instead of propagating the error to the caller.
        logger.error(err)
    return None
| 142 |