(x, brange, residual, residual_scale_factor, scaling_vector=None)
| 134 | |
| 135 | |
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Add ``residual`` (scaled by ``residual_scale_factor``) into the rows of ``x`` picked by ``brange``.

    Args:
        x: Destination tensor; ``brange`` indexes its first dimension.
        brange: Index tensor selecting positions along dim 0 of ``x``. Each entry
            receives the matching slice of ``residual``; repeated indices accumulate.
        residual: Values to add. They are cast to ``x.dtype`` before being added.
        residual_scale_factor: Multiplier for ``residual``, passed as ``alpha``.
        scaling_vector: Optional extra scaling. When not ``None``, the work is
            delegated to ``scaled_index_add``, which is defined outside this view.

    Returns:
        When ``scaling_vector`` is ``None``: a new tensor with the shape of
        ``x.flatten(1)``. It is NOT reshaped back to ``x.shape``, and ``x`` itself
        is left untouched. Otherwise: whatever ``scaled_index_add`` returns.
        NOTE(review): callers presumably ``.view_as(x)`` the result to get a
        uniform shape. Confirm at the call sites.
    """
    residual_in_x_dtype = residual.to(dtype=x.dtype)

    if scaling_vector is not None:
        # Scaled path: the helper applies both ``scaling_vector`` and ``alpha``.
        return scaled_index_add(
            x,
            brange,
            residual_in_x_dtype,
            scaling=scaling_vector,
            alpha=residual_scale_factor,
        )

    # Unscaled path: collapse everything after dim 0 so index_add adds whole rows.
    rows_of_x = x.flatten(1)
    rows_of_residual = residual_in_x_dtype.flatten(1)
    return torch.index_add(rows_of_x, 0, brange, rows_of_residual, alpha=residual_scale_factor)
| 146 | |
| 147 | |
# Module-level cache, annotated as tuple keys -> arbitrary values; it starts empty.
# NOTE(review): whatever fills or reads it lives outside this view, so its
# key schema and eviction policy (if any) are unconfirmed.
attn_bias_cache: Dict[Tuple, Any] = dict()
no test coverage detected