MCPcopy
hub / github.com/ddbourgin/numpy-ml / TorchLayerNormLayer

Class TorchLayerNormLayer

numpy_ml/tests/nn_torch_models.py:216–298  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

214
215
216class TorchLayerNormLayer(nn.Module):
217 def __init__(self, feat_dims, params, mode, epsilon=1e-5):
218 super(TorchLayerNormLayer, self).__init__()
219
220 self.layer1 = nn.LayerNorm(
221 normalized_shape=feat_dims, eps=epsilon, elementwise_affine=True
222 )
223
224 scaler = params["scaler"]
225 intercept = params["intercept"]
226
227 if mode == "2D":
228 scaler = np.moveaxis(scaler, [0, 1, 2], [-2, -1, -3])
229 intercept = np.moveaxis(intercept, [0, 1, 2], [-2, -1, -3])
230
231 assert scaler.shape == self.layer1.weight.shape
232 assert intercept.shape == self.layer1.bias.shape
233 self.layer1.weight = nn.Parameter(torch.FloatTensor(scaler))
234 self.layer1.bias = nn.Parameter(torch.FloatTensor(intercept))
235
236 def forward(self, X):
237 # (N, H, W, C) -> (N, C, H, W)
238 if X.ndim == 4:
239 X = np.moveaxis(X, [0, 1, 2, 3], [0, -2, -1, -3])
240
241 if not isinstance(X, torch.Tensor):
242 X = torchify(X)
243
244 self.X = X
245 self.Y = self.layer1(self.X)
246 self.Y.retain_grad()
247
248 def extract_grads(self, X, Y_true=None):
249 self.forward(X)
250
251 if isinstance(Y_true, np.ndarray):
252 Y_true = np.moveaxis(Y_true, [0, 1, 2, 3], [0, -2, -1, -3])
253 self.loss1 = (
254 0.5 * F.mse_loss(self.Y, torchify(Y_true), size_average=False).sum()
255 )
256 else:
257 self.loss1 = self.Y.sum()
258
259 self.loss1.backward()
260
261 X_np = self.X.detach().numpy()
262 Y_np = self.Y.detach().numpy()
263 dX_np = self.X.grad.numpy()
264 dY_np = self.Y.grad.numpy()
265 intercept_np = self.layer1.bias.detach().numpy()
266 scaler_np = self.layer1.weight.detach().numpy()
267 dIntercept_np = self.layer1.bias.grad.numpy()
268 dScaler_np = self.layer1.weight.grad.numpy()
269
270 if self.X.dim() == 4:
271 orig, X_swap = [0, 1, 2, 3], [0, -1, -3, -2]
272 orig_p, p_swap = [0, 1, 2], [-1, -3, -2]
273 if isinstance(Y_true, np.ndarray):

Callers 2

test_LayerNorm1DFunction · 0.85
test_LayerNorm2DFunction · 0.85

Calls

no outgoing calls

Tested by 2

test_LayerNorm1DFunction · 0.68
test_LayerNorm2DFunction · 0.68