Repository: github.com/FareedKhan-dev/train-llm-from-scratch

Function masked_mean

src/post_training/utils.py, lines 116–121

Mean of ``values`` over positions where ``mask`` is truthy. If the mask selects no positions, the result is 0 rather than NaN.

(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor
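Written as a formula (a restatement of the listed source, with i running over every element of the tensors, across all dimensions): the mask is cast to the dtype of ``values``, so a boolean mask becomes 0/1 weights, and the denominator is clamped to at least 1. The result is a single scalar for the whole tensor, not one value per row. Non-binary masks act as weights, and if those weights sum to less than 1 the clamp also changes the denominator.

$$
\operatorname{masked\_mean}(v, m) = \frac{\sum_{i} v_i \, m_i}{\max\!\bigl(1,\ \sum_{i} m_i\bigr)}
$$

With an empty mask the numerator is 0 and the denominator is 1, which gives the 0 result described above.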


```python
import torch  # needed for the tensor type hints; the excerpt starts mid-file

# --- Masked reductions -------------------------------------------------------

def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Mean of ``values`` over positions where ``mask`` is truthy (safe if mask empty)."""
    mask = mask.to(values.dtype)        # bool/int mask -> 0/1 weights in the values' dtype
    total = (values * mask).sum()       # sum over every element, not per row
    count = mask.sum().clamp(min=1.0)   # avoid division by zero when nothing is selected
    return total / count


def masked_mean_per_row(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    ...  # body not shown in this excerpt
```
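A quick check of the behaviour, as a standalone sketch: the function is repeated from the listing so the snippet runs without the repository, and the example tensors are made up for illustration. It contrasts the masked mean with a naive `values.mean()` that also averages the padding positions, and shows the empty-mask case.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Repeated from the listing so this snippet is self-contained.
    mask = mask.to(values.dtype)
    total = (values * mask).sum()
    count = mask.sum().clamp(min=1.0)
    return total / count


# Two sequences of per-token values; positions where mask is False are padding.
values = torch.tensor([[1.0, 2.0, 3.0, 0.0],
                       [4.0, 5.0, 0.0, 0.0]])
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])

masked = masked_mean(values, mask)       # (1 + 2 + 3 + 4 + 5) / 5 = 3.0
naive = values.mean()                    # 15 / 8 = 1.875, pulled down by the padding zeros
empty = torch.zeros_like(mask)           # a mask that selects nothing
no_tokens = masked_mean(values, empty)   # 0 / clamp(0, min=1) = 0.0 instead of NaN
```

The clamp trades a NaN for a silent 0 when a batch contains no selected tokens, which keeps a training loss finite but can hide an all-padding batch.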

Callers 7

grpo_loss · Function · 0.90
whiten · Function · 0.90
ppo_policy_loss · Function · 0.90
ppo_value_loss · Function · 0.90
approx_kl · Function · 0.90
main · Function · 0.90
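The callers above (for example `whiten` and `ppo_policy_loss`) suggest the typical use: reducing per-token quantities over response tokens only. The repository's own bodies for those functions are not shown here, so the following is a hypothetical sketch, not their implementation. `whiten_sketch` (masked whitening with population variance) and `ppo_clip_loss_sketch` (a generic token-level PPO-clip objective with an assumed `clip_eps=0.2`) are illustrative names, and all tensors are assumed to have shape `(batch, seq_len)`.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Same reduction as the listing above, repeated so the sketch is self-contained.
    mask = mask.to(values.dtype)
    return (values * mask).sum() / mask.sum().clamp(min=1.0)


def whiten_sketch(x: torch.Tensor, mask: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # Hypothetical masked whitening: mean and (population) variance use only unmasked positions.
    mean = masked_mean(x, mask)
    var = masked_mean((x - mean) ** 2, mask)
    return (x - mean) * torch.rsqrt(var + eps)


def ppo_clip_loss_sketch(logp_new: torch.Tensor, logp_old: torch.Tensor,
                         advantages: torch.Tensor, mask: torch.Tensor,
                         clip_eps: float = 0.2) -> torch.Tensor:
    # Hypothetical token-level PPO-clip loss; padding tokens are excluded from the mean.
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * advantages
    per_token_loss = -torch.min(unclipped, clipped)
    return masked_mean(per_token_loss, mask)


torch.manual_seed(0)
mask = torch.tensor([[True, True, True, False],
                     [True, True, False, False]])
logp_old = torch.randn(2, 4)
logp_new = logp_old + 0.1 * torch.randn(2, 4)
adv = whiten_sketch(torch.randn(2, 4), mask)
loss = ppo_clip_loss_sketch(logp_new, logp_old, adv, mask)  # scalar over 5 response tokens
```

Because the batch-wide `masked_mean` divides by the total number of selected tokens, longer responses contribute more tokens to the average than shorter ones; a per-row reduction would weight sequences differently.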

Calls

no outgoing calls

Tested by 1