(self, normalize=True, save_dir='', names=())
| 188 | |
@TryExcept('WARNING ⚠️ ConfusionMatrix plot failure')
def plot(self, normalize=True, save_dir='', names=()):
    """Render the confusion matrix as a seaborn heatmap and save it to disk.

    Args:
        normalize: If True, divide each column by its column sum (plus 1e-9 to
            avoid division by zero) before plotting.
        save_dir: Directory the image is written to, as
            ``<save_dir>/confusion_matrix.png`` at 250 dpi.
        names: Class names used as tick labels. They are applied only when
            0 < len(names) < 99 and len(names) == self.nc; a trailing
            'background' label is appended. Any sequence (list or tuple) works.

    Notes:
        NOTE(review): TryExcept is defined elsewhere; it presumably catches
        exceptions and logs the given warning. Confirm against its definition.
        NOTE(review): column normalization plus the axis labels below assume
        self.matrix is indexed [predicted, true]. Confirm against the code
        that fills the matrix.
    """
    import seaborn as sn

    array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1)  # normalize columns
    array[array < 0.005] = np.nan  # don't annotate (would appear as 0.00)

    fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
    try:
        nc, nn = self.nc, len(names)  # number of classes, names
        sn.set(font_scale=1.0 if nc < 50 else 0.8)  # for label size (note: alters seaborn's global theme)
        labels = (0 < nn < 99) and (nn == nc)  # apply names to ticklabels
        # list(...) so a tuple of names (the parameter's default type) does not raise TypeError
        ticklabels = (list(names) + ['background']) if labels else "auto"
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')  # suppress empty matrix RuntimeWarning: All-NaN slice encountered
            sn.heatmap(array,
                       ax=ax,
                       annot=nc < 30,  # per-cell numbers only while the grid stays readable
                       annot_kws={
                           "size": 8},
                       cmap='Blues',
                       fmt='.2f',
                       square=True,
                       vmin=0.0,
                       xticklabels=ticklabels,
                       yticklabels=ticklabels).set_facecolor((1, 1, 1))
        # Columns (x) are the normalized reference class; rows (y) are predictions.
        # Previously both calls were set_ylabel, so 'True' was overwritten and never shown.
        ax.set_xlabel('True')
        ax.set_ylabel('Predicted')
        ax.set_title('Confusion Matrix')
        fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
    finally:
        # Always release the figure: the decorator swallows errors, so a failure
        # here would otherwise leak an open pyplot figure on every call.
        plt.close(fig)
| 219 | |
| 220 | def print(self): |
| 221 | for i in range(self.nc + 1): |
no outgoing calls
no test coverage detected