| 241 | self.act_fn = ACT2FN[config.hidden_act] |
| 242 | |
| 243 | def forward(self, x): |
| 244 | if self.config.pretraining_tp > 1: |
| 245 | slice = self.intermediate_size // self.config.pretraining_tp |
| 246 | gate_proj_slices = self.gate_proj.weight.split(slice, dim=0) |
| 247 | up_proj_slices = self.up_proj.weight.split(slice, dim=0) |
| 248 | down_proj_slices = self.down_proj.weight.split(slice, dim=1) |
| 249 | |
| 250 | gate_proj = torch.cat( |
| 251 | [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1 |
| 252 | ) |
| 253 | up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1) |
| 254 | |
| 255 | intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2) |
| 256 | down_proj = [ |
| 257 | F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp) |
| 258 | ] |
| 259 | down_proj = sum(down_proj) |
| 260 | else: |
| 261 | down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) |
| 262 | |
| 263 | return down_proj |
| 264 | |
| 265 | |
| 266 | def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: |