MCPcopy Index your code
hub / github.com/TheAlgorithms/Python / DenseLayer

Class DenseLayer

neural_network/back_propagation_neural_network.py:29–96  ·  view source on GitHub ↗

A layer of a backpropagation (BP) neural network.

Source from the content-addressed store, hash-verified

27
28
29class DenseLayer:
30 """
31 Layers of BP neural network
32 """
33
34 def __init__(
35 self, units, activation=None, learning_rate=None, is_input_layer=False
36 ):
37 """
38 common connected layer of bp network
39 :param units: numbers of neural units
40 :param activation: activation function
41 :param learning_rate: learning rate for paras
42 :param is_input_layer: whether it is input layer or not
43 """
44 self.units = units
45 self.weight = None
46 self.bias = None
47 self.activation = activation
48 if learning_rate is None:
49 learning_rate = 0.3
50 self.learn_rate = learning_rate
51 self.is_input_layer = is_input_layer
52
def initializer(self, back_units, rng=None):
    """
    Randomly initialise this layer's weight matrix and bias column.

    ``weight`` gets shape ``(units, back_units)`` and ``bias`` gets shape
    ``(units, 1)``. Both are drawn from a normal distribution with mean 0
    and standard deviation 0.5 and stored as ``np.matrix``.

    :param back_units: number of units feeding into this layer (the column
        count of ``weight``, i.e. the row count of the input it multiplies)
    :param rng: optional randomness source, forwarded to
        ``np.random.default_rng``. It may be ``None`` (a fresh unseeded
        generator, the original behaviour), an ``int`` seed, or an existing
        ``np.random.Generator``, which is used as-is. Pass a seed or a
        generator to make initialisation reproducible.
    """
    # default_rng(None) -> fresh generator; default_rng(Generator) -> same one.
    rng = np.random.default_rng(rng)
    self.weight = np.asmatrix(rng.normal(0, 0.5, (self.units, back_units)))
    # Draw a 1-D vector and transpose its (1, units) matrix into a column.
    self.bias = np.asmatrix(rng.normal(0, 0.5, self.units)).T
    # Layers constructed without an explicit activation fall back to sigmoid.
    if self.activation is None:
        self.activation = sigmoid
59
def cal_gradient(self):
    """
    Return the derivative of this layer's activation with respect to its input.

    The activation is expected to be either ``sigmoid`` or linear.

    * sigmoid: a square ndarray of shape ``(units, units)`` whose diagonal
      entry ``i`` is ``sum_j output[i, j] * (1 - output[i, j])``. This equals
      the diagonal of ``output @ (1 - output).T``; for a single-column
      ``output`` it is the usual ``s * (1 - s)``.
    * anything else: the scalar 1 (the derivative of a linear activation).

    Reads ``self.output``, so ``forward_propagation`` must have run first.
    """
    if self.activation is sigmoid:
        out = np.asarray(self.output)
        # Only the diagonal of output @ (1 - output).T is needed, so sum the
        # element-wise products per row: O(units * batch) instead of
        # materialising the full units x units outer product.
        derivative = (out * (1 - out)).sum(axis=1)
        return np.diag(derivative)
    return 1
68
def forward_propagation(self, xdata):
    """
    Propagate ``xdata`` through this layer and return the layer output.

    The input layer is the identity: ``xdata`` is returned unchanged. Every
    other layer computes ``weight · xdata - bias`` and applies its activation.

    Side effects: the input (``xdata``), pre-activation (``wx_plus_b``) and
    output (``output``) are cached on the instance. ``xdata`` and ``output``
    are read later by the gradient code.
    """
    self.xdata = xdata
    if not self.is_input_layer:
        pre_activation = np.dot(self.weight, xdata) - self.bias
        self.wx_plus_b = pre_activation
        self.output = self.activation(pre_activation)
        return self.output
    # Input layer: pass the data straight through.
    self.wx_plus_b = self.output = xdata
    return xdata
80
81 def back_propagation(self, gradient):
82 gradient_activation = self.cal_gradient() # i * i 维
83 gradient = np.asmatrix(np.dot(gradient.T, gradient_activation))
84
85 self._gradient_weight = np.asmatrix(self.xdata)
86 self._gradient_bias = -1

Callers 1

exampleFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected