
Class FeedForward

internlm/model/linear.py:140–207


Source from the content-addressed store, hash-verified

class FeedForward(nn.Module):
    """
    FeedForward.

    Args:
        in_features (int): size of each input sample
        hidden_features (int): size of hidden state of FFN
        out_features (int): size of each output sample
        process_group (Optional[torch.distributed.ProcessGroup]): The group of the current device for `parallel_mode`.
        bias (bool): Whether the bias is needed for linears. True by default. But it is typically set to False
                    in the config.
        device (Optional[Union[str, torch.device]]): The device will be used.
        dtype (Optional[torch.dtype]): The type of data.
        multiple_of (int): For efficient training. Reset the size of hidden feature. 256 by default.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int = None,
        process_group: Optional[torch.distributed.ProcessGroup] = None,
        bias: bool = True,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        multiple_of: int = 256,
    ):
        super().__init__()

        hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)

        self.w1 = ColumnParallelLinearTorch(
            in_features,
            hidden_features,
            process_group,
            bias,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            device=device,
            dtype=dtype,
        )
        self.w2 = ColumnParallelLinearTorch(
            in_features,
            hidden_features,
            process_group,
            bias,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            device=device,
            dtype=dtype,
        )
        self.w3 = RowParallelLinearTorch(
            hidden_features,
            out_features,
            process_group,
            bias=bias,
            sequence_parallel=gpc.config.parallel.sequence_parallel,
            device=device,
            dtype=dtype,
        )
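
The constructor rounds `hidden_features` up to the next multiple of `multiple_of` before building `w1`, `w2`, and `w3`. The following is a minimal sketch of that arithmetic in isolation, in plain Python. The helper name `round_up_to_multiple` and the sample sizes are illustrative assumptions and are not part of InternLM.

```python
def round_up_to_multiple(value: int, multiple_of: int = 256) -> int:
    """Round `value` up to the nearest multiple of `multiple_of`.

    Uses the same integer expression as FeedForward.__init__.
    The function name is hypothetical, chosen for this sketch only.
    """
    return multiple_of * ((value + multiple_of - 1) // multiple_of)


if __name__ == "__main__":
    # Illustrative sizes (assumed, not read from any InternLM config).
    for requested in (10922, 11008, 1, 256, 257):
        print(requested, "->", round_up_to_multiple(requested))
```

Sizes that are already aligned pass through unchanged, while anything else is bumped to the next boundary, so the hidden width the three projections actually use can be larger than the `hidden_features` value the caller requested.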

Callers 1

__init__ · Method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected