()
| 23 | |
| 24 | |
def plot_activations():
    """Plot every activation with its first and second derivative.

    Draws a 2 x 5 grid of subplots (shared x and y axes), one per activation
    object in ``fns``. Each panel shows ``fn(X)``, ``fn.grad(X)`` and
    ``fn.grad2(X)`` over 100 evenly spaced points in ``[-3, 3]``, together
    with dashed reference lines through the origin. The figure is written to
    ``img/plot.png`` at 300 dpi (the ``img`` directory is created if it does
    not exist yet), after which all open matplotlib figures are closed.
    """
    # Local import so the block does not depend on a file-level import.
    from pathlib import Path

    fig, axes = plt.subplots(2, 5, sharex=True, sharey=True)
    fns = [
        Affine(),
        Tanh(),
        Sigmoid(),
        ReLU(),
        LeakyReLU(),
        ELU(),
        Exponential(),
        SELU(),
        HardSigmoid(),
        SoftPlus(),
    ]

    # The input grid is identical for every panel, so build it once.
    # NOTE(review): assumes the activations do not modify their input in
    # place; the per-panel code already relies on that, since the same X is
    # passed to fn, fn.grad and fn.grad2 and then plotted -- confirm.
    X = np.linspace(-3, 3, 100).reshape(100, 1)

    for ax, fn in zip(axes.flat, fns):
        ax.plot(X, fn(X), label=r"$y$", alpha=1.0)
        ax.plot(X, fn.grad(X), label=r"$\frac{dy}{dx}$", alpha=1.0)
        ax.plot(X, fn.grad2(X), label=r"$\frac{d^2 y}{dx^2}$", alpha=1.0)

        # Dashed reference lines marking y = 0 and x = 0.
        ax.hlines(0, -3, 3, lw=1, linestyles="dashed", color="k")
        ax.vlines(0, -1.2, 1.2, lw=1, linestyles="dashed", color="k")

        ax.set_ylim(-1.1, 1.1)
        ax.set_xlim(-3, 3)
        ax.set_xticks([])
        ax.set_yticks([-1, 0, 1])
        ax.xaxis.set_visible(False)
        ax.set_title(f"{fn}")
        ax.legend(frameon=False)

    # despine acts on every axes of the figure, so one call after the loop
    # yields the same final state as calling it once per panel.
    sns.despine(fig=fig, left=True, bottom=True)

    fig.set_size_inches(10, 5)
    fig.tight_layout()

    # savefig does not create missing parent directories; make sure the
    # target directory exists instead of failing with FileNotFoundError.
    Path("img").mkdir(parents=True, exist_ok=True)
    fig.savefig("img/plot.png", dpi=300)
    plt.close("all")
| 61 | |
| 62 | |
| 63 | if __name__ == "__main__": |
no test coverage detected