(
self, upstream_key_values: torch.Tensor,
downstream_key_values: torch.Tensor,
value_threshold: float, scale_clip_value: float = 10)
| 387 | op = op, scale_factor = 1 / sqrt(2), mask = mask) |
| 388 | |
def calculate_scale(
        self, upstream_key_values: torch.Tensor,
        downstream_key_values: torch.Tensor,
        value_threshold: float, scale_clip_value: float = 10):
    """Return a per-element scale of ``sqrt(downstream / upstream)``, bounded.

    The raw scale is ``1 / sqrt(upstream / downstream)``, clamped into
    ``[1 / scale_clip_value, scale_clip_value]``. Wherever
    ``upstream + downstream`` falls below ``value_threshold``, the scale
    is forced to exactly 1.

    Args:
        upstream_key_values: Numerator of the ratio under the square root.
        downstream_key_values: Denominator of the ratio. Must broadcast
            against ``upstream_key_values``.
        value_threshold: Elements whose summed inputs are strictly below
            this value get a scale of 1.
        scale_clip_value: Upper clamp bound; its reciprocal is the lower
            bound. Presumably expected to be >= 1 so that the bounds are
            ordered. NOTE(review): confirm with callers.

    Returns:
        A new floating-point tensor with the broadcast shape of the two
        inputs.

    Notes:
        ``self`` is not read. A zero upstream value yields an infinite raw
        scale and a zero downstream value yields a zero raw scale. Both
        are then pulled back into the clamp range. A NaN ratio (e.g. 0/0)
        is only replaced when the threshold mask covers that element.
    """
    ratio = upstream_key_values / downstream_key_values
    # reciprocal() is exactly what ``1 / tensor`` dispatches to.
    raw_scale = torch.sqrt(ratio).reciprocal()
    bounded_scale = raw_scale.clamp(min=1 / scale_clip_value,
                                    max=scale_clip_value)
    # Too little signal on either side to trust the ratio: leave untouched.
    below_threshold = (upstream_key_values + downstream_key_values) < value_threshold
    return bounded_scale.masked_fill(below_threshold, 1)
| 397 | |
| 398 | def reduce_by_axis( |
| 399 | self, |