MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / LoraRuntimeParams

Class LoraRuntimeParams

tensorrt_llm/layers/lora.py:25–46  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

23
24
class LoraRuntimeParams:
    """Plain holder for the LoRA parameters passed around at runtime.

    Each constructor argument is stored, unchanged, on the instance under
    the same name. No validation, copying or conversion is performed, and
    every argument is optional: all default to ``None`` except
    ``weight_index``, which defaults to ``0``.
    """

    def __init__(
        self,
        lora_ranks: List[Tensor] = None,
        lora_weights_pointers: List[Tensor] = None,
        host_request_types: Tensor = None,
        host_context_lengths: Tensor = None,
        max_encoder_context_length: Tensor = None,
        host_encoder_input_lengths: Tensor = None,
        weight_index: int = 0,
        partial_lora_mask: Tensor = None,
    ):
        # Per-adapter rank tensors and the matching weight pointer tensors.
        self.lora_ranks = lora_ranks
        self.lora_weights_pointers = lora_weights_pointers

        # Host-side request and length tensors.
        self.host_request_types = host_request_types
        self.host_context_lengths = host_context_lengths
        self.max_encoder_context_length = max_encoder_context_length
        self.host_encoder_input_lengths = host_encoder_input_lengths

        self.weight_index = weight_index

        # Partial LoRA for https://arxiv.org/abs/2401.16420
        self.partial_lora_mask = partial_lora_mask
47
48
49class Lora(Module):

Callers 6

forward (Method) · 0.85
forward (Method) · 0.85
fc_gate_lora (Function) · 0.85
fc_gate_dora (Function) · 0.85
get_runtime_params (Method) · 0.85

Calls

no outgoing calls

Tested by 1