MCPcopy Index your code
hub / github.com/Standard-Intelligence/hertz-dev / Block

Class Block

transformer.py:266–302  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

264 return x + self.ffnn(self.ffnn_norm(x))
265
class Block(nn.Module):
    """One transformer layer: a ``PreNormAttn`` stage feeding a ``PreNormFFNN`` stage.

    The constructor builds both sub-layers, records a few sizes as plain
    attributes, and then re-initialises the projection weights through
    :meth:`reset_parameters`.

    Args:
        dim: Model width; passed to both sub-layers and used to derive the
            init scale for the input-side projections.
        layer_id: Stored on the instance as ``self.layer_id``; not otherwise
            used inside this class.
        n_head: Number of attention heads handed to ``PreNormAttn``.
        kv_heads: Forwarded to ``PreNormAttn`` unchanged.
        ff_dim: Forwarded to ``PreNormFFNN`` unchanged.
        eps: Forwarded as ``eps`` to both sub-layers.
        causal: Forwarded as ``causal`` to ``PreNormAttn``.
        shape_rotator: Forwarded to ``PreNormAttn`` unchanged.
    """

    def __init__(self,
                 dim: int,
                 layer_id: int = 0,
                 n_head: int = 16,
                 kv_heads: Optional[int] = None,
                 ff_dim: Optional[int] = None,
                 eps: float = 1e-5,
                 causal: bool = True,
                 shape_rotator: ShapeRotator = None):
        super().__init__()

        # Sub-layers are registered attention-first, then feed-forward.
        self.attn = PreNormAttn(
            dim, n_head, shape_rotator, kv_heads,
            eps=eps, causal=causal,
        )
        self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)

        # Plain bookkeeping attributes.
        self.dim = dim
        self.layer_id = layer_id
        self.head_dim = dim // n_head
        # The expanded width is whatever the inner feed-forward module reports.
        self.expand_dim = self.ffnn.ffnn.expand_dim

        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the projection weights from a truncated normal (cut at +/-3 std).

        The gate/up, QKV and attention-output projections use
        ``std = 1 / sqrt(dim)``; the feed-forward down projection uses
        ``std = 1 / sqrt(expand_dim)``.
        """

        def _trunc_init(weight, sigma):
            # Zero-mean truncated normal, clipped three standard deviations out.
            nn.init.trunc_normal_(weight, std=sigma, a=-3 * sigma, b=3 * sigma)

        # Call order is kept fixed so the RNG stream is consumed identically.
        model_sigma = 1.0 / math.sqrt(self.dim)
        _trunc_init(self.ffnn.ffnn.gateup_proj.weight, model_sigma)
        _trunc_init(self.attn.attn.proj_qkv.weight, model_sigma)
        _trunc_init(self.attn.attn.attn_out.weight, model_sigma)

        expand_sigma = 1.0 / math.sqrt(self.expand_dim)
        _trunc_init(self.ffnn.ffnn.down_proj.weight, expand_sigma)

    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
        """
        Run the attention stage on ``(x, kv)`` and feed its result to the
        feed-forward stage.

        x: (B, S, D)
        kv: (B, S, H, D)
        """
        return self.ffnn(self.attn(x, kv))
303
304
305

Callers 1

__init__Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected