(
self,
value_threshold: float = 2,
including_act: bool = False,
act_multiplier: float = 0.5,
including_bias: bool = False,
bias_multiplier: float = 0.5,
method: EqualizationMethod = EqualizationMethod.ABSOLUTE_MAX)
| 353 | EqualizationHelper.scale_to_downstream(op, scale) |
| 354 | |
def channel_split(
    self,
    value_threshold: float = 2,
    including_act: bool = False,
    act_multiplier: float = 0.5,
    including_bias: bool = False,
    bias_multiplier: float = 0.5,
    method: EqualizationMethod = EqualizationMethod.ABSOLUTE_MAX):
    """Split the channels whose key values reach ``value_threshold`` on both sides.

    A key value is taken from every upstream layer via
    ``EqualizationHelper.key_value_from_upstream`` (optionally including bias
    and activation terms) and from every downstream layer via
    ``EqualizationHelper.key_value_from_downstream``. Each side's list is
    reduced with ``self.reduce_by_axis`` using ``method``. A channel is
    selected when BOTH reduced values are ``>= value_threshold``. Every
    upstream and downstream layer is then passed to ``ChannelSplitHelper``
    together with that boolean mask and a scale factor of ``1 / sqrt(2)``.

    Args:
        value_threshold: Minimum reduced key value, required on both the
            upstream and the downstream side, for a channel to be selected.
        including_act: Forwarded to ``key_value_from_upstream``.
        act_multiplier: Forwarded to ``key_value_from_upstream``.
        including_bias: Forwarded to ``key_value_from_upstream``.
        bias_multiplier: Forwarded to ``key_value_from_upstream``.
        method: Reduction passed to ``self.reduce_by_axis`` for both sides.

    Returns:
        None. Layers are updated through the ``ChannelSplitHelper`` calls.
    """
    # Gather one key value per layer on each side of the pair.
    upstream_values = [
        EqualizationHelper.key_value_from_upstream(
            op=layer,
            including_bias=including_bias,
            including_act=including_act,
            bias_multiplier=bias_multiplier,
            act_multiplier=act_multiplier,
        )
        for layer in self.upstream_layers
    ]
    downstream_values = [
        EqualizationHelper.key_value_from_downstream(op=layer)
        for layer in self.downstream_layers
    ]

    upstream_reduced = self.reduce_by_axis(upstream_values, method=method)
    downstream_reduced = self.reduce_by_axis(downstream_values, method=method)

    # Only channels that are large on both sides of the pair get split.
    split_mask = torch.logical_and(
        upstream_reduced >= value_threshold,
        downstream_reduced >= value_threshold,
    )

    # Same scale factor for every layer; the helpers apply it under the mask.
    split_scale = 1 / sqrt(2)
    for layer in self.upstream_layers:
        ChannelSplitHelper.channel_split_upstream(
            op=layer, scale_factor=split_scale, mask=split_mask)
    for layer in self.downstream_layers:
        ChannelSplitHelper.channel_split_downstream(
            op=layer, scale_factor=split_scale, mask=split_mask)
| 389 | def calculate_scale( |
| 390 | self, upstream_key_values: torch.Tensor, |
no test coverage detected