github.com/appvision-ai/fast-bert

Method plot

fast_bert/learner_cls.py:982–1038

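Plots the learning rate range test: loss against learning rate, read from `self.history`. Arguments: skip_start (int, optional): number of batches to trim from the start. Default: 10. skip_end (int, optional): number of batches to trim from the end. Default: 5. log_lr (bool, optional): True to plot the learning rate on a logarithmic scale, False for a linear scale. Default: True. show_lr (float, optional): if set, draws a vertical line at that learning rate. Default: None. ax (matplotlib.axes.Axes, optional): axes to draw into; if None, a new figure is created and shown. Default: None. Returns the Axes object that contains the plot.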

(self, skip_start=10, skip_end=5, log_lr=True, show_lr=None, ax=None)
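`plot` only reads two parallel lists, `self.history["lr"]` and `self.history["loss"]`, which its caller fills in. The sketch below builds a history of that shape by hand. It assumes, which the source shown here does not confirm, that the learning rate is swept exponentially between two bounds, as learning rate range tests commonly do; the helper `exponential_lr_sweep` and the bounds are hypothetical, and the losses are placeholders rather than real results. With geometric spacing each step multiplies the learning rate by a constant factor, so the points are evenly spaced on a log axis, which is why `log_lr=True` is a sensible default.

```python
import math


def exponential_lr_sweep(start_lr, end_lr, num_iter):
    """Hypothetical helper: geometrically spaced learning rates from start_lr to end_lr."""
    if num_iter < 2:
        raise ValueError("num_iter must be at least 2")
    ratio = end_lr / start_lr
    # Step i gets start_lr * ratio**(i / (num_iter - 1)), so the last value is end_lr.
    return [start_lr * ratio ** (i / (num_iter - 1)) for i in range(num_iter)]


# A history dict shaped like the one plot() reads; the losses are placeholders only.
lrs = exponential_lr_sweep(1e-7, 10.0, 9)
history = {"lr": lrs, "loss": [0.0] * len(lrs)}

# On a log10 axis every step covers the same distance (here one decade per step).
log_steps = [math.log10(b) - math.log10(a) for a, b in zip(lrs, lrs[1:])]
print([round(s, 6) for s in log_steps])
```

On a linear axis the same nine points would crowd into the left edge, since all but the last two are below 0.1.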


```python
def plot(self, skip_start=10, skip_end=5, log_lr=True, show_lr=None, ax=None):
    """Plots the learning rate range test.
    Arguments:
        skip_start (int, optional): number of batches to trim from the start.
            Default: 10.
        skip_end (int, optional): number of batches to trim from the end.
            Default: 5.
        log_lr (bool, optional): True to plot the learning rate in a logarithmic
            scale; otherwise, plotted in a linear scale. Default: True.
        show_lr (float, optional): if set, adds a vertical line to visualize the
            specified learning rate. Default: None.
        ax (matplotlib.axes.Axes, optional): the plot is created in the specified
            matplotlib axes object and the figure is not shown. If `None`, then
            the figure and axes object are created in this method and the figure is
            shown. Default: None.
    Returns:
        The matplotlib.axes.Axes object that contains the plot.
    """

    if skip_start < 0:
        raise ValueError("skip_start cannot be negative")
    if skip_end < 0:
        raise ValueError("skip_end cannot be negative")
    if show_lr is not None and not isinstance(show_lr, float):
        raise ValueError("show_lr must be float")

    # Get the data to plot from the history dictionary. Handle skip_end=0
    # separately, because a slice such as lrs[skip_start:-0] would be empty
    lrs = self.history["lr"]
    losses = self.history["loss"]
    if skip_end == 0:
        lrs = lrs[skip_start:]
        losses = losses[skip_start:]
    else:
        lrs = lrs[skip_start:-skip_end]
        losses = losses[skip_start:-skip_end]

    # Create the figure and axes object if axes was not already given
    fig = None
    if ax is None:
        fig, ax = plt.subplots()

    # Plot loss as a function of the learning rate
    ax.plot(lrs, losses)
    if log_lr:
        ax.set_xscale("log")
    ax.set_xlabel("Learning rate")
    ax.set_ylabel("Loss")

    if show_lr is not None:
        ax.axvline(x=show_lr, color="red")

    # Show only if the figure was created internally
    if fig is not None:
        plt.show()

    return ax
```
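The `skip_end == 0` branch exists because Python treats `-0` as `0`: `lrs[skip_start:-0]` is the same slice as `lrs[skip_start:0]`, which is empty. The standalone function below reproduces the argument checks and the trimming from `plot` on plain lists, so the behaviour can be checked without matplotlib or a trained learner. `trim_history` is a hypothetical name used only for illustration; it is not part of fast-bert.

```python
def trim_history(lrs, losses, skip_start=10, skip_end=5, show_lr=None):
    """Hypothetical stand-alone copy of plot()'s validation and trimming steps."""
    if skip_start < 0:
        raise ValueError("skip_start cannot be negative")
    if skip_end < 0:
        raise ValueError("skip_end cannot be negative")
    # Like plot(), this rejects non-float values such as the int 1 (pass 1.0 instead).
    if show_lr is not None and not isinstance(show_lr, float):
        raise ValueError("show_lr must be float")
    if skip_end == 0:
        # A slice ending at -0 would end at index 0 and return nothing.
        return lrs[skip_start:], losses[skip_start:]
    return lrs[skip_start:-skip_end], losses[skip_start:-skip_end]


values = list(range(20))
print(values[10:-0])                           # prints []
print(trim_history(values, values, 10, 0)[0])  # prints [10, 11, ..., 19]
print(trim_history(values, values, 10, 5)[0])  # prints [10, 11, 12, 13, 14]
```

A consequence of the defaults: a history with 15 or fewer entries is trimmed to nothing, so `plot` draws empty axes rather than raising an error.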

Callers 1

lr_find (Method) · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected