(
self,
hidden_size: int = 768,
num_attention_heads: int = 12,
mlp_ratio: int = 4,
attn_drop_rate: float = 0,
drop_rate: float = 0.0,
dtype: torch.dtype = torch.float,
layer_norm_epsilon: float = 1e-6,
checkpoint: bool = False,
layer_idx: int = 0,
residual_in_fp32: bool = False,
device: Optional[torch.device] = None,
norm_type: str = "rmsnorm",
dropout_selective_checkpoint: bool = True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
)
| 53 | """ |
| 54 | |
| 55 | def __init__( |
| 56 | self, |
| 57 | hidden_size: int = 768, |
| 58 | num_attention_heads: int = 12, |
| 59 | mlp_ratio: int = 4, |
| 60 | attn_drop_rate: float = 0, |
| 61 | drop_rate: float = 0.0, |
| 62 | dtype: torch.dtype = torch.float, |
| 63 | layer_norm_epsilon: float = 1e-6, |
| 64 | checkpoint: bool = False, |
| 65 | layer_idx: int = 0, |
| 66 | residual_in_fp32: bool = False, |
| 67 | device: Optional[torch.device] = None, |
| 68 | norm_type: str = "rmsnorm", |
| 69 | dropout_selective_checkpoint: bool = True, |
| 70 | use_scaled_init: bool = True, |
| 71 | use_swiglu: bool = True, |
| 72 | use_flash_attn: bool = True, |
| 73 | ): |
| 74 | super().__init__() |
| 75 | self.checkpoint = checkpoint |
| 76 | # dropout selective checkpoint can only be enabled when checkpoint is disabled. |
| 77 | self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False |
| 78 | self.layer_idx = layer_idx |
| 79 | self.use_flash_attn = use_flash_attn |
| 80 | |
| 81 | head_dim = hidden_size // num_attention_heads |
| 82 | self.mixer = MHA( |
| 83 | embed_dim=hidden_size, |
| 84 | num_heads=num_attention_heads, |
| 85 | process_group=gpc.get_group(ParallelMode.TENSOR), |
| 86 | dropout=attn_drop_rate, |
| 87 | softmax_scale=1 / math.sqrt(head_dim), |
| 88 | causal=True, |
| 89 | layer_idx=layer_idx, |
| 90 | rotary_emb_dim=head_dim, |
| 91 | rotary_emb_scale_base=0, |
| 92 | use_flash_attn=use_flash_attn, |
| 93 | device=device, |
| 94 | dtype=dtype, |
| 95 | ) |
| 96 | |
| 97 | self.dropout1 = nn.Dropout(drop_rate) |
| 98 | if norm_type == "rmsnorm": |
| 99 | self.norm1 = RMSNorm(hidden_size, eps=layer_norm_epsilon) |
| 100 | self.norm2 = RMSNorm(hidden_size, eps=layer_norm_epsilon) |
| 101 | else: |
| 102 | self.norm1 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) |
| 103 | self.norm2 = nn.LayerNorm(hidden_size, eps=layer_norm_epsilon) |
| 104 | |
| 105 | if use_swiglu: |
| 106 | self.mlp = FeedForward( |
| 107 | hidden_size, |
| 108 | int(hidden_size * mlp_ratio), |
| 109 | out_features=hidden_size, |
| 110 | process_group=gpc.get_group(ParallelMode.TENSOR), |
| 111 | bias=False, |
| 112 | device=device, |
nothing calls this directly
no test coverage detected