github.com/OpenNMT/OpenNMT-py

Method forward

onmt/modules/sparse_losses.py:11–28

input (FloatTensor): ``(n, num_classes)``, the unnormalized scores. target (LongTensor): ``(n,)``, the indices of the target classes.
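Per row, the method appears to compute the sparsemax loss. Writing $z$ for one row of `input`, $k$ for its target index, and $\tau(z)$ for the threshold returned by `_threshold_and_support`, the returned value is:

$$
\ell(z, k) = \max\Big(0,\ \tfrac{1}{2}\sum_{j \in S(z)} \big(z_j^2 - \tau(z)^2\big) - z_k + \tfrac{1}{2}\Big),
\qquad S(z) = \{\, j : z_j > \tau(z) \,\}
$$

The sum is `x` in the source, $z_k$ is the gathered target score `z_k`, and the outer $\max(0, \cdot)$ is the `torch.clamp(..., min=0.0)` call.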

(ctx, input, target)


```python
@staticmethod
@custom_fwd
def forward(ctx, input, target):
    """
    input (FloatTensor): ``(n, num_classes)``.
    target (LongTensor): ``(n,)``, the indices of the target classes
    """
    input_batch, classes = input.size()

    z_k = input.gather(1, target.unsqueeze(1)).squeeze()
    tau_z, support_size = _threshold_and_support(input, dim=1)
    support = input > tau_z
    x = torch.where(
        support, input**2 - tau_z**2, torch.tensor(0.0, device=input.device)
    ).sum(dim=1)
    ctx.save_for_backward(input, target, tau_z)
    # clamping necessary because of numerical errors: loss should be lower
    # bounded by zero, but negative values near zero are possible without
    # the clamp
    return torch.clamp(x / 2 - z_k + 0.5, min=0.0)
```
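To make the tensor operations concrete, here is a minimal pure-Python sketch of the same per-row arithmetic, assuming plain lists instead of tensors. `sparsemax_threshold` and `sparsemax_loss_rows` are hypothetical names, not OpenNMT-py APIs. The threshold uses the standard sort-based sparsemax rule, which is an assumption about what `_threshold_and_support` computes. There is no autograd here, so nothing corresponds to `ctx.save_for_backward`.

```python
def sparsemax_threshold(z):
    """Return (tau, support_size) for one row of scores z.

    Assumed sort-based rule: with z sorted in descending order, the support
    size is the largest j such that j * z_(j) > (z_(1) + ... + z_(j)) - 1,
    and tau = (z_(1) + ... + z_(support_size) - 1) / support_size.
    """
    cumsum = 0.0
    support_size = 0
    support_cumsum = 0.0
    for j, z_j in enumerate(sorted(z, reverse=True), start=1):
        cumsum += z_j
        if j * z_j > cumsum - 1:
            support_size = j
            support_cumsum = cumsum
    return (support_cumsum - 1) / support_size, support_size


def sparsemax_loss_rows(scores, targets):
    """Pure-Python analogue of forward: one clamped loss value per row."""
    losses = []
    for z, k in zip(scores, targets):
        tau, _ = sparsemax_threshold(z)
        # Sum z_j^2 - tau^2 over the support {j : z_j > tau}, like torch.where + sum.
        x = sum(z_j * z_j - tau * tau for z_j in z if z_j > tau)
        # max(..., 0.0) plays the role of torch.clamp(..., min=0.0).
        losses.append(max(x / 2 - z[k] + 0.5, 0.0))
    return losses


print(sparsemax_loss_rows([[1.0, 0.0], [1.0, 0.0], [0.0, 0.0]], [0, 1, 0]))
# → [0.0, 1.0, 0.25]
```

A confident, correct row costs 0, the same scores with the other class as target cost 1, and uniform scores land in between.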

Callers

nothing calls this directly (autograd invokes it through `Function.apply`)

Calls 2

_threshold_and_support · Function · 0.90
squeeze · Method · 0.80

Tested by

no test coverage detected