(rm1: Operation, rm2: Operation,
eps: Operation, scale: torch.Tensor,
bias: torch.Tensor, layernorm_input_var: Variable,
layernorm_output_var: Variable)
| 694 | """Fuse Layernormalization with pattern matching.""" |
| 695 | |
def _fuse(rm1: Operation, rm2: Operation,
          eps: Operation, scale: torch.Tensor,
          bias: torch.Tensor, layernorm_input_var: Variable,
          layernorm_output_var: Variable) -> Operation:
    """Replace a matched LayerNorm subgraph with one LayerNormalization op.

    Args:
        rm1, rm2: The two matched reduction ops. When both are 'ReduceMean',
            each must carry an 'axes' attribute, the two must be equal, and
            they must name exactly one axis. Otherwise the axis defaults to -1.
        eps: Op whose last input holds the epsilon constant. That input must
            be a parameter containing exactly one element.
        scale: Value for the new op's scale input (created as a parameter).
        bias: Value for an extra parameter input of the new op, or None to
            skip it.
        layernorm_input_var: Variable used as the new op's data input.
        layernorm_output_var: Variable used as the new op's output. Its
            current producer's outputs list is cleared first.

    Returns:
        The newly created LayerNormalization op, or None when the matched
        subgraph cannot be fused. Every None path returns before any graph
        mutation, so the graph is left untouched in that case.
    """
    if rm2.type == rm1.type == 'ReduceMean':
        if 'axes' not in rm1.attributes: return None
        if 'axes' not in rm2.attributes: return None
        if rm1.attributes['axes'] != rm2.attributes['axes']: return None
        layernorm_axis = rm1.attributes['axes']
        # 'axes' may be stored as a sequence. LayerNormalization takes a
        # single int axis, so only a one-element sequence can be unwrapped.
        if isinstance(layernorm_axis, (list, tuple)):
            if len(layernorm_axis) != 1: return None
            layernorm_axis = layernorm_axis[0]
        if not isinstance(layernorm_axis, int): return None
    else:
        # NOTE(review): when the reductions are not both ReduceMean, the
        # last axis is assumed. Confirm against the pattern that calls this.
        layernorm_axis = -1

    # Epsilon must be a single constant to become an attribute.
    if not eps.inputs or not eps.inputs[-1].is_parameter: return None
    value = convert_any_to_torch_tensor(eps.inputs[-1].value).cpu()
    if value.numel() != 1: return None
    # ONNX 'epsilon' is a float attribute. Cast in case the constant is
    # integer-typed.
    layernorm_eps = float(value.item())

    # All checks have passed. From here on the graph is mutated.
    # Clear the old producer's outputs so the variable can be reused as the
    # new op's output.
    layernorm_output_var.source_op.outputs.clear()
    layernorm = self.graph.create_operation(
        op_type='LayerNormalization',
        # NOTE(review): the ONNX spec's default stash_type is 1 (float32).
        # The value 0 is kept unchanged; confirm it is intended.
        attributes={'axis': layernorm_axis, 'epsilon': layernorm_eps, 'stash_type': 0},
        inputs=[layernorm_input_var, self.graph.create_variable(value=scale, is_parameter=True)],
        outputs=[layernorm_output_var])

    if bias is not None:
        # Link a new bias parameter variable to the new op. Presumably this
        # appends it as an input (A=None means no producer op).
        self.graph.create_link_with_op(
            variable=self.graph.create_variable(value=bias, is_parameter=True),
            A=None, B=layernorm)
    return layernorm
| 730 | |
| 731 | search_engine = SearchableGraph(graph=self.graph) |
| 732 | fused = False |
nothing calls this directly
no test coverage detected