FeedForward. Args: in_features (int): size of each input sample; hidden_features (int): size of hidden state of FFN; out_features (int): size of each output sample; process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
| 138 | |
| 139 | |
class FeedForward(nn.Module):
    """
    FeedForward.

    Constructs three linear projections: ``w1`` and ``w2`` are
    ``ColumnParallelLinearTorch`` layers mapping ``in_features`` to the (rounded)
    hidden size, and ``w3`` is a ``RowParallelLinearTorch`` layer mapping the hidden
    size to ``out_features``.

    Args:
        in_features (int): size of each input sample
        hidden_features (int): size of hidden state of FFN. It is rounded up to the nearest
                                multiple of ``multiple_of`` before the layers are created.
        out_features (Optional[int]): size of each output sample. Falls back to ``in_features``
                                when it is None.
        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
                    in the config.
        device (Optional[Union[str, torch.device]]): The device to be used.
        dtype (Optional[torch.dtype]): The type of data.
        multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: Optional[int] = None,
        process_group: Optional[torch.distributed.ProcessGroup] = None,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
        super().__init__()

        # The declared default for `out_features` is None, which cannot be used as a
        # layer size; fall back to the input width so the default is actually usable.
        if out_features is None:
            out_features = in_features

        # Round the hidden size up to the next multiple of `multiple_of`.
        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

        # Every projection uses the same sequence-parallel setting from the global config.
        sequence_parallel = gpc.config.parallel.sequence_parallel

        self.w1 = ColumnParallelLinearTorch(
            in_features,
            hidden_features,
            process_group,
            bias,
            sequence_parallel=sequence_parallel,
            device=device,
            dtype=dtype,
        )
        self.w2 = ColumnParallelLinearTorch(
            in_features,
            hidden_features,
            process_group,
            bias,
            sequence_parallel=sequence_parallel,
            device=device,
            dtype=dtype,
        )
        self.w3 = RowParallelLinearTorch(
            hidden_features,
            out_features,
            process_group,
            bias=bias,
            sequence_parallel=sequence_parallel,
            device=device,
            dtype=dtype,
        )