MCPcopy
hub / github.com/THUDM/LongWriter / forward

Method forward

train/patch/modeling_llama.py:243–263  ·  view source on GitHub ↗
(self, x)

Source from the content-addressed store, hash-verified

241 self.act_fn = ACT2FN[config.hidden_act]
242
243 def forward(self, x):
244 if self.config.pretraining_tp > 1:
245 slice = self.intermediate_size // self.config.pretraining_tp
246 gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
247 up_proj_slices = self.up_proj.weight.split(slice, dim=0)
248 down_proj_slices = self.down_proj.weight.split(slice, dim=1)
249
250 gate_proj = torch.cat(
251 [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
252 )
253 up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
254
255 intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
256 down_proj = [
257 F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
258 ]
259 down_proj = sum(down_proj)
260 else:
261 down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
262
263 return down_proj
264
265
266def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:

Callers 1

forwardMethod · 0.45

Calls

no outgoing calls

Tested by

no test coverage detected