(
self,
x,
weight,
gemm_plugin: Optional[str] = None,
low_latency_gemm_plugin: Optional[str] = None,
use_fp8: bool = False,
alpha: Optional[np.ndarray] = None,
lora_runtime_params: Optional[LoraRuntimeParams] = None,
lora_hidden_state: Optional[Tensor] = None,
**kwargs)
| 475 | return 1 |
| 476 | |
def multiply_collect(
        self,
        x,
        weight,
        gemm_plugin: Optional[str] = None,
        low_latency_gemm_plugin: Optional[str] = None,
        use_fp8: bool = False,
        alpha: Optional[np.ndarray] = None,
        lora_runtime_params: Optional[LoraRuntimeParams] = None,
        lora_hidden_state: Optional[Tensor] = None,
        **kwargs):
    """Multiply ``x`` by ``weight`` and combine the result across ranks.

    When ``plugin_config.gemm_allreduce_plugin`` of the current default
    network is set, the product is computed with a single
    ``gemm_allreduce`` call over ``self.tp_group`` and the bias (if any)
    is added here. Otherwise the call is forwarded unchanged to the
    parent class implementation.

    Args:
        x: Left-hand operand, passed to ``gemm_allreduce`` as ``a``.
        weight: Right-hand operand, passed as ``b`` with ``transb=True``.
        gemm_plugin: Only forwarded to the parent implementation; unused
            on the gemm_allreduce path.
        low_latency_gemm_plugin: Only forwarded to the parent
            implementation; unused on the gemm_allreduce path.
        use_fp8: Passed to ``gemm_allreduce`` as ``fp8_inputs_override``.
        alpha: Optional scaling passed through to ``gemm_allreduce``.
        lora_runtime_params: Must be ``None`` on the gemm_allreduce path.
        lora_hidden_state: Must be ``None`` on the gemm_allreduce path.
        **kwargs: Only forwarded to the parent implementation.

    Returns:
        The result of ``gemm_allreduce`` with the bias applied, or
        whatever the parent ``multiply_collect`` returns.

    Raises:
        RuntimeError: If the gemm_allreduce plugin is enabled and either
            LoRA argument is provided.
    """
    gemm_allreduce_plugin = default_net(
    ).plugin_config.gemm_allreduce_plugin

    # Guard clause: without the fused plugin, defer to the parent class.
    if not gemm_allreduce_plugin:
        return super().multiply_collect(x, weight, gemm_plugin,
                                        low_latency_gemm_plugin, use_fp8,
                                        alpha, lora_runtime_params,
                                        lora_hidden_state, **kwargs)

    # Identity checks on purpose: `!= None` would dispatch to whatever rich
    # comparison the argument type defines instead of testing for None.
    if lora_runtime_params is not None or lora_hidden_state is not None:
        raise RuntimeError(
            "gemm_allreduce_plugin not supported with lora.")

    # self.dtype may be given as a string name; convert it in that case.
    output_dtype = self.dtype
    if isinstance(output_dtype, str):
        output_dtype = str_dtype_to_trt(output_dtype)

    x = gemm_allreduce(
        a=x,
        b=weight,
        transa=False,  # row-major
        transb=True,  # col-major
        alpha=alpha,
        group=self.tp_group,  # ranks participating
        output_dtype=output_dtype,
        fp8_inputs_override=use_fp8)

    if self.bias is not None:
        bias = cast(self.bias.value, x.dtype)
        if self.is_expert:
            # NOTE(review): the bias is scaled by 1/tp_size for experts,
            # presumably because expert outputs are summed across ranks
            # again by the caller -- confirm against the call site.
            x = x + bias / self.tp_size
        else:
            x = x + bias
    return x
| 522 | |
| 523 | def collect_and_bias(self, x, **kwargs): |
| 524 | all_reduce_params: Optional[AllReduceParams] = kwargs.get( |
no test coverage detected