github.com/OpenNMT/OpenNMT-py

Method SparsemaxFunction.backward

onmt/modules/sparse_activations.py:67–76
(ctx, grad_output)


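```python
import torch
from torch.autograd import Function
from torch.cuda.amp import custom_bwd


class SparsemaxFunction(Function):
    # forward() (not shown) sets ctx.dim and saves (supp_size, output).

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        supp_size, output = ctx.saved_tensors
        dim = ctx.dim
        # Sparsemax is flat outside its support: zero those gradients.
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0

        # v_hat: mean incoming gradient over the support along `dim`.
        # squeeze(dim) drops only the kept size-1 axis of supp_size.
        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze(dim)
        v_hat = v_hat.unsqueeze(dim)
        # Subtract v_hat on the support; off-support entries stay 0.
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
        # One gradient per forward input; None for the non-tensor `dim`.
        return grad_input, None


sparsemax = SparsemaxFunction.apply
```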
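The method is a vector-Jacobian product for sparsemax. As a sketch of the math it appears to implement, write p = sparsemax(z) along `dim`, S for the support (the entries where `output != 0`), s for its 0/1 indicator, |S| for `supp_size`, and g for `grad_output`:

```latex
\begin{aligned}
p &= \operatorname{sparsemax}(z), \qquad S = \{\, j : p_j > 0 \,\}, \qquad s_j = \mathbb{1}[\, j \in S \,] \\
J &= \frac{\partial p}{\partial z} = \operatorname{diag}(s) - \frac{s\, s^{\top}}{|S|} \\
\frac{\partial L}{\partial z} &= J^{\top} g = s \odot g - \frac{s^{\top} g}{|S|}\, s
  = s \odot \bigl(g - \hat{v}\,\mathbf{1}\bigr),
\qquad \hat{v} = \frac{1}{|S|} \sum_{j \in S} g_j
\end{aligned}
```

Since J is symmetric, J^T = J. In the code, `grad_input[output == 0] = 0` forms s ⊙ g, `v_hat` is \hat{v}, and `torch.where` subtracts \hat{v} only where s_j = 1, so off-support gradients stay zero.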

Callers

nothing calls this directly
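That is expected for a `torch.autograd.Function`: user code calls `.apply` (here bound to `sparsemax`), and autograd runs `backward` during `loss.backward()`. The sketch below uses a hypothetical `Scale` function (not part of OpenNMT-py) that, like `SparsemaxFunction`, stores a non-tensor argument on `ctx` and returns `None` as its gradient.

```python
import torch
from torch.autograd import Function

calls = []  # records each time autograd runs backward


class Scale(Function):
    """Hypothetical example: y = k * x with a non-tensor factor k."""

    @staticmethod
    def forward(ctx, x, k):
        ctx.k = k  # non-tensor state lives on ctx, like ctx.dim in sparsemax
        return x * k

    @staticmethod
    def backward(ctx, grad_output):
        calls.append(tuple(grad_output.shape))
        # One gradient per forward input; None for the non-tensor k.
        return grad_output * ctx.k, None


x = torch.ones(3, requires_grad=True)
y = Scale.apply(x, 2.0)  # user code calls apply, never backward
y.sum().backward()       # autograd calls Scale.backward here
print(len(calls), x.grad.tolist())  # 1 [2.0, 2.0, 2.0]
```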

Calls (1)

squeeze · method · 0.80
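The one resolved call, `squeeze`, needs care with shapes. Assuming, as the division implies, that `supp_size` keeps a size-1 axis at `dim`, a bare `.squeeze()` also drops any other size-1 axis, while `.squeeze(dim)` drops only the reduced one. A minimal sketch with hypothetical shapes (a batch of 2, a singleton middle axis, 5 classes on `dim=-1`):

```python
import torch

dim = -1
grad_input = torch.randn(2, 1, 5)
# Assumed layout of supp_size: reduced over `dim`, kept as a size-1 axis.
supp_size = torch.tensor([[[3]], [[2]]])  # shape (2, 1, 1)

summed = grad_input.sum(dim=dim)  # shape (2, 1)

# Bare squeeze() removes every size-1 axis, leaving shape (2,);
# broadcasting that against (2, 1) silently gives a (2, 2) result.
bad = summed / supp_size.to(grad_input.dtype).squeeze()
# squeeze(dim) removes only the reduced axis, matching `summed`.
good = summed / supp_size.to(grad_input.dtype).squeeze(dim)

print(tuple(bad.shape), tuple(good.shape))  # (2, 2) (2, 1)
```

For 2-D inputs the two calls produce the same shape; the difference can matter when another axis has size 1, as in the (2, 1, 5) case above.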

Tested by

no test coverage detected
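A minimal test sketch one could add. To stay self-contained it re-implements the forward pass with the usual sort-and-cumsum sparsemax threshold (an assumption: the OpenNMT-py forward is not shown here), reuses the backward steps from the listing with `squeeze(dim)`, omits `@custom_bwd`, and checks the analytic gradient with `torch.autograd.gradcheck` in double precision. `SparsemaxSketch` is a hypothetical name.

```python
import torch
from torch.autograd import Function, gradcheck


class SparsemaxSketch(Function):
    """Hypothetical stand-in for SparsemaxFunction (forward re-derived)."""

    @staticmethod
    def forward(ctx, z, dim=-1):
        ctx.dim = dim
        # Shift for numerical stability; sparsemax is shift-invariant.
        z = z - z.max(dim=dim, keepdim=True).values
        z_srt, _ = torch.sort(z, dim=dim, descending=True)
        cumsum = z_srt.cumsum(dim) - 1
        # rho = 1..d laid out along `dim` so it broadcasts against z_srt.
        d = z.size(dim)
        shape = [1] * z.dim()
        shape[dim] = d
        rho = torch.arange(1, d + 1, dtype=z.dtype, device=z.device).view(shape)
        support = rho * z_srt > cumsum
        supp_size = support.sum(dim=dim, keepdim=True)  # |S|, size 1 at dim
        tau = cumsum.gather(dim, supp_size - 1) / supp_size.to(z.dtype)
        output = torch.clamp(z - tau, min=0)
        ctx.save_for_backward(supp_size, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # Same steps as the backward listing, with squeeze(dim).
        supp_size, output = ctx.saved_tensors
        dim = ctx.dim
        grad_input = grad_output.clone()
        grad_input[output == 0] = 0
        v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze(dim)
        v_hat = v_hat.unsqueeze(dim)
        grad_input = torch.where(output != 0, grad_input - v_hat, grad_input)
        return grad_input, None


torch.manual_seed(0)
z = torch.randn(2, 1, 5, dtype=torch.double, requires_grad=True)
p = SparsemaxSketch.apply(z, -1)
print(p.sum(dim=-1))  # each slice along dim=-1 should sum to 1
ok = gradcheck(lambda t: SparsemaxSketch.apply(t, -1), (z,))
print("gradcheck passed:", ok)
```

The (2, 1, 5) input deliberately includes a size-1 axis besides `dim`. Because sparsemax is piecewise linear, finite differences match the analytic gradient unless an input lies within about `eps` of the threshold, which random inputs make very unlikely.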