Plots the learning rate range test. Arguments: skip_start (int, optional): number of batches to trim from the start. Default: 10. skip_end (int, optional): number of batches to trim from the end. Default: 5. log_lr (bool,
(self, skip_start=10, skip_end=5, log_lr=True, show_lr=None, ax=None)
| 980 | raise RuntimeError("Optimizer already has a scheduler attached to it") |
| 981 | |
def plot(self, skip_start=10, skip_end=5, log_lr=True, show_lr=None, ax=None):
    """Plots the learning rate range test.

    Arguments:
        skip_start (int, optional): number of batches to trim from the start.
            Default: 10.
        skip_end (int, optional): number of batches to trim from the end.
            Default: 5.
        log_lr (bool, optional): True to plot the learning rate in a logarithmic
            scale; otherwise, plotted in a linear scale. Default: True.
        show_lr (float, optional): if set, adds a vertical red line to visualize
            the specified learning rate. Any real number (``float``, ``int``,
            NumPy scalar, ...) is accepted; ``bool`` is rejected. Default: None.
        ax (matplotlib.axes.Axes, optional): the plot is created in the specified
            matplotlib axes object and the figure is not shown. If `None`, then
            the figure and axes object are created in this method and the figure
            is shown. Default: None.

    Returns:
        The matplotlib.axes.Axes object that contains the plot.

    Raises:
        ValueError: if ``skip_start`` or ``skip_end`` is negative, or if
            ``show_lr`` is given but is not a real number.
    """
    import numbers  # local import keeps the block self-contained

    if skip_start < 0:
        raise ValueError("skip_start cannot be negative")
    if skip_end < 0:
        raise ValueError("skip_end cannot be negative")
    # bool is a subclass of int (and therefore of numbers.Real), but it is
    # never a meaningful learning rate, so keep rejecting it explicitly.
    if show_lr is not None and (
        isinstance(show_lr, bool) or not isinstance(show_lr, numbers.Real)
    ):
        raise ValueError("show_lr must be float")

    # Get the data to plot from the history dictionary. A slice end of -0 is
    # just 0 and would produce an empty list, so skip_end=0 must leave the
    # end of the slice open (None) instead.
    # NOTE(review): assumes history["lr"] and history["loss"] are parallel
    # sliceable sequences recorded by the range test -- confirm with the caller.
    end = -skip_end if skip_end > 0 else None
    lrs = self.history["lr"][skip_start:end]
    losses = self.history["loss"][skip_start:end]

    # Create the figure and axes object if axes was not already given
    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    # Plot loss as a function of the learning rate
    ax.plot(lrs, losses)
    if log_lr:
        ax.set_xscale("log")
    ax.set_xlabel("Learning rate")
    ax.set_ylabel("Loss")

    if show_lr is not None:
        ax.axvline(x=show_lr, color="red")

    # Show only if the figure was created internally; a caller-supplied axes
    # object is left for the caller to display or save.
    if fig is not None:
        plt.show()

    return ax
| 1039 |