Mean of ``values`` over positions where ``mask`` is truthy. An empty mask does not divide by zero: the count is clamped to 1, so the result is 0 (for finite ``values``).
(values: torch.Tensor, mask: torch.Tensor)

    # --- Masked reductions -------------------------------------------------------

    def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean of ``values`` over positions where ``mask`` is truthy (safe if mask empty)."""
        mask = mask.to(values.dtype)
        total = (values * mask).sum()
        count = mask.sum().clamp(min=1.0)
        return total / count


    def masked_mean_per_row(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
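
A minimal usage sketch, assuming a standard PyTorch install. It repeats the body of `masked_mean` so the snippet runs on its own, computes sum(values * mask) / max(1, sum(mask)) on a small example, and checks the empty-mask case. It also defines `masked_mean_where`, a hypothetical variant that is not part of this codebase, to illustrate one limitation of the multiply-by-mask approach: a NaN at a masked-out position still poisons the sum, because NaN * 0 is NaN.

```python
import torch


def masked_mean(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Same body as the function above: mask is cast to values' dtype and multiplied in.
    mask = mask.to(values.dtype)
    total = (values * mask).sum()
    count = mask.sum().clamp(min=1.0)  # avoid dividing by zero when nothing is selected
    return total / count


def masked_mean_where(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical variant (assumption, not the author's code): select with torch.where
    # so masked-out NaN/inf never enter the sum; any nonzero mask entry counts once.
    selected = mask.bool()
    total = torch.where(selected, values, torch.zeros_like(values)).sum()
    count = selected.sum().to(values.dtype).clamp(min=1.0)
    return total / count


values = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = torch.tensor([True, False, True, False])
empty = torch.zeros(4, dtype=torch.bool)
values_nan = torch.tensor([1.0, float("nan"), 3.0, 4.0])  # NaN sits at a masked-out slot

print(masked_mean(values, mask).item())             # 2.0, i.e. (1 + 3) / 2
print(masked_mean(values, empty).item())            # 0.0, count clamped to 1
print(masked_mean(values_nan, mask).item())         # nan, since nan * 0 is nan
print(masked_mean_where(values_nan, mask).item())   # 2.0
```

Because `masked_mean` casts the mask to `values.dtype` rather than to booleans, a float mask such as `[2.0, 0.0, 1.0, 0.0]` acts as per-element weights instead of a strict "truthy" selection; the hypothetical `torch.where` variant treats every nonzero entry as weight 1.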