MCP · copy · Index your code
hub / github.com/NVIDIA/TensorRT-LLM / forward

Method forward

tensorrt_llm/models/falcon/model.py:104–152  ·  view source on GitHub ↗
(self,
                hidden_states: Tensor,
                attention_mask=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None)

Source from the content-addressed store, hash-verified

102 return self.new_decoder_architecture or self.parallel_attn
103
def forward(self,
            hidden_states: Tensor,
            attention_mask=None,
            use_cache=False,
            kv_cache_params=None,
            attention_params=None):
    """Apply one Falcon decoder layer to ``hidden_states``.

    The attention branch always consumes ``input_layernorm(hidden_states)``.
    The MLP input depends on the layer configuration:

    * ``new_decoder_architecture`` with ``num_ln_in_parallel_attn == 2``:
      ``mlp_layernorm(hidden_states)`` (computed before ``input_layernorm``).
    * otherwise, if ``new_decoder_architecture`` or ``parallel_attn``: the
      same ``input_layernorm`` output the attention branch consumed.
    * otherwise (sequential): ``post_layernorm(hidden_states + attention)``.
      That sum also replaces the residual used in the final addition.

    When ``self.is_parallel_attention`` is true, the attention output is
    added to the MLP output. For ``config.mapping.tp_size > 1`` the sum is
    then all-reduced over ``config.mapping.tp_group``. Finally the residual
    is added.

    Args:
        hidden_states: Layer input; must be a ``Tensor``.
        attention_mask: Forwarded unchanged to ``self.attention``.
        use_cache: Forwarded to ``self.attention``. When true, the attention
            result is unpacked as ``(output, presents)`` and ``presents`` is
            returned alongside the layer output.
        kv_cache_params: Forwarded unchanged to ``self.attention``.
        attention_params: Forwarded unchanged to ``self.attention``.

    Returns:
        The layer output, or ``(layer_output, presents)`` when ``use_cache``
        is true.

    Raises:
        AssertionError: If ``hidden_states`` is not a ``Tensor`` (the check
            is skipped when Python runs with ``-O``).
    """
    assert isinstance(hidden_states, Tensor)

    layer_input = hidden_states
    new_arch = self.new_decoder_architecture
    two_norms = new_arch and self.num_ln_in_parallel_attn == 2

    # In the two-norm layout the MLP branch gets its own norm of the raw input.
    if two_norms:
        mlp_norm_out = self.mlp_layernorm(layer_input)
    attn_norm_out = self.input_layernorm(layer_input)

    attn_out = self.attention(attn_norm_out,
                              attention_mask=attention_mask,
                              use_cache=use_cache,
                              kv_cache_params=kv_cache_params,
                              attention_params=attention_params)
    if use_cache:
        attn_out, presents = attn_out

    residual = layer_input
    if two_norms:
        mlp_in = mlp_norm_out
    elif new_arch or self.parallel_attn:
        # Parallel layout: the MLP reads the same normalised input as attention.
        mlp_in = attn_norm_out
    else:
        # Sequential layout: apply the attention residual first, then the
        # post-attention norm; the summed tensor becomes the new residual.
        residual = layer_input + attn_out
        mlp_in = self.post_layernorm(residual)

    mlp_out = self.mlp(mlp_in)

    if self.is_parallel_attention:
        # NOTE(review): presumably the attention and MLP submodules skip their
        # own tensor-parallel reduction in this layout so that one all-reduce
        # of their sum suffices -- confirm against the layer's __init__.
        mlp_out = mlp_out + attn_out
        if self.config.mapping.tp_size > 1:
            mlp_out = allreduce(mlp_out, self.config.mapping.tp_group)

    layer_out = residual + mlp_out
    if use_cache:
        return layer_out, presents
    return layer_out
153
154
155class FalconModel(Module):

Callers

nothing calls this directly

Calls 1

allreduce · Function · 0.50

Tested by

no test coverage detected