Plot training labels including class histograms and box statistics. Args: boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height]. cls (np.ndarray): Class indices. names (dict, optional): Dictionary mapping class indices to class names. save_dir (Path, optional): Directory to save the plot. on_plot (Callable, optional): Function to call after plot is saved.
(boxes, cls, names=(), save_dir=Path(""), on_plot=None)
@TryExcept()
@plt_settings()
def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
    """Plot training labels including class histograms and box statistics.

    Saves a 2x2 figure to ``save_dir / "labels.jpg"`` containing: a per-class instance histogram, a preview of up to
    500 box outlines drawn centred on a white 1000x1000 canvas, a 2D histogram of the box x/y columns, and a 2D
    histogram of the box width/height columns.

    Args:
        boxes (np.ndarray): Bounding box coordinates in format [x, y, width, height], one row per box. The outline
            preview draws ``0.5 ± width / 2`` scaled by 1000 pixels, so width/height are presumably normalized to
            [0, 1] (confirm against callers). Only the first 1,000,000 rows are used.
        cls (np.ndarray): Class indices, one per box. Both (N,) and (N, 1) shapes are accepted.
        names (dict | Sequence[str], optional): Mapping of class indices to class names, or a sequence of names.
            Tick labels are only drawn when there are between 1 and 29 names.
        save_dir (Path, optional): Directory to save the plot.
        on_plot (Callable, optional): Function called with the saved file path after the plot is saved.
    """
    from collections.abc import Mapping

    import matplotlib.pyplot as plt  # scope for faster 'import ultralytics'
    import polars
    from matplotlib.colors import LinearSegmentedColormap

    # Plot dataset labels
    fname = save_dir / "labels.jpg"
    LOGGER.info(f"Plotting labels to {fname}... ")

    # Flatten (N, 1) -> (N,) so each per-box class id below is a scalar rather than a 1-element array
    cls = np.asarray(cls).reshape(-1)
    if cls.size == 0:
        # cls.max() would raise on an empty array; skip explicitly instead of relying on the decorator
        LOGGER.warning("No labels found, skipping label plot.")
        return
    nc = int(cls.max() + 1)  # number of classes
    boxes = boxes[:1000000]  # limit to 1M boxes
    # orient="row" is explicit because polars otherwise infers orientation from the shape; a square (4, 4) input
    # (exactly 4 boxes) is ambiguous and, depending on the polars version, may be read column-wise (transposed).
    df = polars.DataFrame(boxes, schema=["x", "y", "width", "height"], orient="row")

    # Matplotlib labels
    cmap = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"])
    ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()

    # Subplot 0: instance count per class; bin edges at k - 0.5 centre one bar on each integer class id
    _, _, bars = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
    for i in range(nc):
        bars.patches[i].set_color([c / 255 for c in colors(i)])
    ax[0].set_ylabel("instances")
    if 0 < len(names) < 30:
        labels = list(names.values()) if isinstance(names, Mapping) else list(names)
        ax[0].set_xticks(range(len(labels)))
        ax[0].set_xticklabels(labels, rotation=90, fontsize=10)
        ax[0].bar_label(bars)
    else:
        ax[0].set_xlabel("classes")

    # Subplot 1: up to 500 box outlines, each centred on the canvas, to visualize relative box sizes
    corners = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
    img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
    draw = ImageDraw.Draw(img)  # one drawing context reused for every rectangle
    for class_id, box in zip(cls[:500], corners[:500]):
        draw.rectangle(box.tolist(), width=1, outline=colors(int(class_id)))
    ax[1].imshow(img)
    ax[1].axis("off")

    # Subplots 2-3: 2D histograms of the position and size columns
    ax[2].hist2d(df["x"], df["y"], bins=50, cmap=cmap)
    ax[2].set_xlabel("x")
    ax[2].set_ylabel("y")
    ax[3].hist2d(df["width"], df["height"], bins=50, cmap=cmap)
    ax[3].set_xlabel("width")
    ax[3].set_ylabel("height")
    for a in ax:
        for side in ("top", "right", "left", "bottom"):
            a.spines[side].set_visible(False)

    plt.savefig(fname, dpi=200)
    plt.close()
    if on_plot:
        on_plot(fname)
| 660 | |
| 661 | |
| 662 | def save_one_box( |