| 264 | return x + self.ffnn(self.ffnn_norm(x)) |
| 265 | |
class Block(nn.Module):
    """Attention sublayer followed by a feed-forward sublayer.

    The sublayers are ``PreNormAttn`` and ``PreNormFFNN``. This class only
    chains them in sequence and (re)initialises their projection weights.
    Any normalisation or residual handling lives inside those sublayers,
    not here.

    Args:
        dim: Model width. Forwarded to both sublayers and used to scale the
            init std of the attention projections and ``gateup_proj``.
        layer_id: Index of this block. Stored as ``self.layer_id``; not
            otherwise used by this class.
        n_head: Number of attention heads; ``self.head_dim`` is
            ``dim // n_head``.
        kv_heads: Forwarded to ``PreNormAttn`` (presumably the number of
            key/value heads; ``None`` leaves the choice to that module).
        ff_dim: Forwarded to ``PreNormFFNN`` (presumably its hidden width;
            ``None`` leaves the choice to that module).
        eps: Forwarded to both sublayers (presumably a normalisation epsilon).
        causal: Forwarded to ``PreNormAttn``.
        shape_rotator: Forwarded to ``PreNormAttn`` as its third positional
            argument; may be ``None``.
    """

    def __init__(
        self,
        dim: int,
        layer_id: int = 0,
        n_head: int = 16,
        kv_heads: Optional[int] = None,
        ff_dim: Optional[int] = None,
        eps: float = 1e-5,
        causal: bool = True,
        shape_rotator: Optional[ShapeRotator] = None,
    ) -> None:
        super().__init__()
        self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
        self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)
        self.dim = dim
        self.layer_id = layer_id
        # NOTE(review): floor division silently truncates when dim is not a
        # multiple of n_head -- confirm PreNormAttn rejects that configuration.
        self.head_dim = dim // n_head
        # Inner width reported by the FFN module; it scales down_proj's init std.
        self.expand_dim = self.ffnn.ffnn.expand_dim

        self.reset_parameters()

    @staticmethod
    def _trunc_normal_3sigma_(weight: Tensor, std: float) -> None:
        """Fill ``weight`` in place from N(0, std^2) truncated to [-3*std, 3*std]."""
        nn.init.trunc_normal_(weight, std=std, a=-3 * std, b=3 * std)

    def reset_parameters(self) -> None:
        """Re-initialise the sublayers' projection weights in place.

        ``gateup_proj``, ``proj_qkv`` and ``attn_out`` use std ``1/sqrt(dim)``;
        ``down_proj`` uses std ``1/sqrt(expand_dim)``. Every draw is a
        truncated normal clipped at +/-3 std. Only these four weights are
        touched; biases and any norm parameters are left as they are.
        """
        std = 1.0 / math.sqrt(self.dim)
        # Keep the order gateup -> qkv -> attn_out -> down_proj: each call
        # draws from torch's RNG, so reordering would change seeded results.
        for weight in (
            self.ffnn.ffnn.gateup_proj.weight,
            self.attn.attn.proj_qkv.weight,
            self.attn.attn.attn_out.weight,
        ):
            self._trunc_normal_3sigma_(weight, std)

        xstd = 1.0 / math.sqrt(self.expand_dim)
        self._trunc_normal_3sigma_(self.ffnn.ffnn.down_proj.weight, xstd)

    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
        """Apply the attention sublayer, then the feed-forward sublayer.

        Args:
            x: Input activations, shape (B, S, D).
            kv: Optional tensor passed straight through to the attention
                sublayer; expected shape (B, S, H, D) -- TODO confirm
                against PreNormAttn.

        Returns:
            The feed-forward sublayer's output on the attention output.
        """
        return self.ffnn(self.attn(x, kv))
| 303 | |
| 304 | |
| 305 | |