MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / multiply_collect

Method multiply_collect

tensorrt_llm/layers/linear.py:477–521  ·  view source on GitHub ↗
(
            self,
            x,
            weight,
            gemm_plugin: Optional[str] = None,
            low_latency_gemm_plugin: Optional[str] = None,
            use_fp8: bool = False,
            alpha: Optional[np.ndarray] = None,
            lora_runtime_params: Optional[LoraRuntimeParams] = None,
            lora_hidden_state: Optional[Tensor] = None,
            **kwargs)

Source from the content-addressed store, hash-verified

475 return 1
476
def multiply_collect(
        self,
        x,
        weight,
        gemm_plugin: Optional[str] = None,
        low_latency_gemm_plugin: Optional[str] = None,
        use_fp8: bool = False,
        alpha: Optional[np.ndarray] = None,
        lora_runtime_params: Optional[LoraRuntimeParams] = None,
        lora_hidden_state: Optional[Tensor] = None,
        **kwargs):
    """Multiply ``x`` by ``weight``, using the fused GEMM+all-reduce plugin if enabled.

    When ``default_net().plugin_config.gemm_allreduce_plugin`` is falsy the
    call is forwarded unchanged to the parent class implementation.
    Otherwise the product is computed with ``gemm_allreduce`` over
    ``self.tp_group`` and ``self.bias`` (if any) is added to the result.

    Args:
        x: Left-hand operand, passed to the GEMM as ``a`` (not transposed).
        weight: Right-hand operand, passed to the GEMM as ``b`` (transposed).
        gemm_plugin: Only forwarded to the parent implementation; not used
            by the fused path.
        low_latency_gemm_plugin: Only forwarded to the parent
            implementation; not used by the fused path.
        use_fp8: Passed to ``gemm_allreduce`` as ``fp8_inputs_override``.
        alpha: Passed through to ``gemm_allreduce`` / the parent.
        lora_runtime_params: Must be ``None`` on the fused path.
        lora_hidden_state: Must be ``None`` on the fused path.
        **kwargs: Only forwarded to the parent implementation.

    Returns:
        The result of the fused GEMM+all-reduce (plus bias when present),
        or whatever the parent implementation returns.

    Raises:
        RuntimeError: If the fused plugin is enabled and either LoRA
            argument is not ``None``.
    """
    gemm_allreduce_plugin = default_net().plugin_config.gemm_allreduce_plugin

    # Plugin disabled: defer entirely to the parent implementation.
    if not gemm_allreduce_plugin:
        return super().multiply_collect(x, weight, gemm_plugin,
                                        low_latency_gemm_plugin, use_fp8,
                                        alpha, lora_runtime_params,
                                        lora_hidden_state, **kwargs)

    # Identity checks on purpose: ``!= None`` would dispatch to any
    # comparison operator overloaded by the argument's type, whereas
    # ``is not None`` never does.
    if lora_runtime_params is not None or lora_hidden_state is not None:
        raise RuntimeError(
            "gemm_allreduce_plugin not supported with lora.")

    # self.dtype may be stored as a string name; the plugin call gets the
    # converted TensorRT dtype in that case.
    output_dtype = self.dtype
    if isinstance(output_dtype, str):
        output_dtype = str_dtype_to_trt(output_dtype)

    x = gemm_allreduce(
        a=x,
        b=weight,
        transa=False,  # row-major
        transb=True,  # col-major
        alpha=alpha,
        group=self.tp_group,  # ranks participating
        output_dtype=output_dtype,
        fp8_inputs_override=use_fp8)

    if self.bias is not None:
        bias = cast(self.bias.value, x.dtype)
        if self.is_expert:
            # NOTE(review): the bias is scaled by 1/tp_size for expert
            # layers, presumably because their outputs are summed across
            # ranks again later -- confirm against the MoE caller.
            x = x + bias / self.tp_size
        else:
            x = x + bias
    return x
522
523 def collect_and_bias(self, x, **kwargs):
524 all_reduce_params: Optional[AllReduceParams] = kwargs.get(

Callers 4

forward — Method · 0.45
forward — Method · 0.45
forward — Method · 0.45
forward — Method · 0.45

Calls 4

default_net — Function · 0.85
str_dtype_to_trt — Function · 0.85
gemm_allreduce — Function · 0.85
cast — Function · 0.85

Tested by

no test coverage detected