(
self,
params: List[torch.Tensor],
method: EqualizationMethod,
axis: int=1,
)
| 396 | return scale |
| 397 | |
def reduce_by_axis(
    self,
    params: List[torch.Tensor],
    method: EqualizationMethod,
    axis: int = 1,
) -> torch.Tensor:
    """Concatenate ``params`` along ``axis`` and reduce along that same axis.

    Args:
        params: Tensors passed to ``torch.cat``; they must agree in every
            dimension except ``axis``.
        method: Statistic to compute over ``axis``:
            ``ABSOLUTE_MAX`` -> max of ``|x|``,
            ``ABSOLUTE_MEAN`` -> mean of ``|x|``,
            ``SQUARE_MAX`` -> max of ``x ** 2``,
            ``SQUARE_MEAN`` -> mean of ``x ** 2``.
        axis: Dimension used both for the concatenation and the reduction.

    Returns:
        The concatenated tensor reduced over ``axis`` (that dimension is
        removed from the result).

    Raises:
        NotImplementedError: If ``method`` is not one of the values above.
    """
    # Resolve the element-wise transform and the reduction kind before doing
    # any tensor work, so an unsupported method fails fast.
    # Identity (`is`) checks are kept deliberately to match enum-member dispatch.
    if method is EqualizationMethod.ABSOLUTE_MAX:
        transform, use_max = torch.abs, True
    elif method is EqualizationMethod.ABSOLUTE_MEAN:
        transform, use_max = torch.abs, False
    elif method is EqualizationMethod.SQUARE_MAX:
        transform, use_max = torch.square, True
    elif method is EqualizationMethod.SQUARE_MEAN:
        transform, use_max = torch.square, False
    else:
        raise NotImplementedError(
            'Equalization method %s is not supported.' % str(method))

    values = transform(torch.cat(params, dim=axis))
    if use_max:
        # torch.max with a `dim` returns a (values, indices) named tuple;
        # only the values are needed here.
        return torch.max(values, dim=axis).values
    return torch.mean(values, dim=axis)
no outgoing calls
no test coverage detected