MCPcopy
hub / github.com/ddbourgin/numpy-ml / TorchBatchNormLayer

Class TorchBatchNormLayer

numpy_ml/tests/nn_torch_models.py:138–213  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

136
137
138class TorchBatchNormLayer(nn.Module):
def __init__(self, n_in, params, mode, momentum=0.9, epsilon=1e-5):
    """
    Build a PyTorch batch-norm layer whose affine parameters are preset.

    Parameters
    ----------
    n_in : int
        Number of features / channels normalized by the layer
        (passed as ``num_features``).
    params : dict
        Must contain ``"scaler"`` (installed as the layer's ``weight``)
        and ``"intercept"`` (installed as the layer's ``bias``).
    mode : {"1D", "2D"}
        ``"1D"`` builds ``nn.BatchNorm1d``; ``"2D"`` builds
        ``nn.BatchNorm2d``.
    momentum : float
        Weight kept on the previous running statistic. Converted to
        PyTorch's convention by passing ``1 - momentum``.
    epsilon : float
        Numerical-stability term passed to the layer as ``eps``.

    Raises
    ------
    KeyError
        If `params` lacks ``"scaler"`` or ``"intercept"``.
    ValueError
        If `mode` is neither ``"1D"`` nor ``"2D"``.
    """
    super(TorchBatchNormLayer, self).__init__()

    scaler = params["scaler"]
    intercept = params["intercept"]

    layer_types = {"1D": nn.BatchNorm1d, "2D": nn.BatchNorm2d}
    if mode not in layer_types:
        # Without this check self.layer1 would never be assigned and the
        # failure would surface later as an unhelpful AttributeError.
        raise ValueError("mode must be '1D' or '2D', got {!r}".format(mode))

    # PyTorch updates running stats as
    #     running = (1 - m) * running + m * batch_stat,
    # so m = 1 - momentum makes `momentum` the weight on the old value.
    self.layer1 = layer_types[mode](
        num_features=n_in, momentum=1 - momentum, eps=epsilon, affine=True
    )

    # Replace the default gamma/beta with the supplied values.
    self.layer1.weight = nn.Parameter(torch.FloatTensor(scaler))
    self.layer1.bias = nn.Parameter(torch.FloatTensor(intercept))
156
def forward(self, X):
    """
    Apply the wrapped batch-norm layer to `X` and cache input and output.

    A 4-D input is taken to be channels-last, (N, H, W, C), and is
    reordered to channels-first, (N, C, H, W), before the layer runs.
    NumPy arrays are converted with ``torchify``. Torch tensors are used
    directly: permuted when 4-D, unchanged otherwise.

    Side effects: stores the (possibly reordered) input tensor on
    ``self.X`` and the layer output on ``self.Y``. Also calls
    ``self.Y.retain_grad()`` so ``self.Y.grad`` is populated after a
    backward pass.

    Returns
    -------
    torch.Tensor
        The layer output (same object as ``self.Y``). It is returned so
        that calling the module directly behaves like a normal
        ``nn.Module``.
    """
    if isinstance(X, torch.Tensor):
        if X.dim() == 4:
            # (N, H, W, C) -> (N, C, H, W). np.moveaxis cannot be used
            # on a tensor: it duck-calls X.transpose(order), but torch's
            # transpose takes two dims, not a permutation.
            X = X.permute(0, 3, 1, 2)
            if X.requires_grad:
                # permute returns a non-leaf view; retain its grad so
                # self.X.grad is still populated after backward().
                X.retain_grad()
    else:
        if X.ndim == 4:
            # (N, H, W, C) -> (N, C, H, W)
            X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])
        # NOTE(review): torchify is defined elsewhere; presumably it
        # yields a tensor with requires_grad=True, since self.X.grad is
        # read after backward() — confirm.
        X = torchify(X)

    self.X = X
    self.Y = self.layer1(self.X)
    self.Y.retain_grad()
    return self.Y
168
169 def extract_grads(self, X, Y_true=None):
170 self.forward(X)
171
172 if isinstance(Y_true, np.ndarray):
173 Y_true = np.moveaxis(Y_true, [0, 1, 2, 3], [0, -2, -1, -3])
174 self.loss1 = (
175 0.5 * F.mse_loss(self.Y, torchify(Y_true), size_average=False).sum()
176 )
177 else:
178 self.loss1 = self.Y.sum()
179
180 self.loss1.backward()
181
182 X_np = self.X.detach().numpy()
183 Y_np = self.Y.detach().numpy()
184 dX_np = self.X.grad.numpy()
185 dY_np = self.Y.grad.numpy()
186
187 if self.X.dim() == 4:
188 orig, X_swap = [0, 1, 2, 3], [0, -1, -3, -2]
189 if isinstance(Y_true, np.ndarray):
190 Y_true = np.moveaxis(Y_true, orig, X_swap)
191 X_np = np.moveaxis(X_np, orig, X_swap)
192 Y_np = np.moveaxis(Y_np, orig, X_swap)
193 dX_np = np.moveaxis(dX_np, orig, X_swap)
194 dY_np = np.moveaxis(dY_np, orig, X_swap)
195

Callers 2

test_BatchNorm1DFunction · 0.85
test_BatchNorm2DFunction · 0.85

Calls

no outgoing calls

Tested by 2

test_BatchNorm1DFunction · 0.68
test_BatchNorm2DFunction · 0.68