MCPcopy Index your code
hub / github.com/deepspeedai/DeepSpeedExamples / LinearActivation

Class LinearActivation

bing_bert/nvidia/modeling.py:151–196  ·  view source on GitHub ↗

r"""Fused Linear and activation Module.

Source from the content-addressed store, hash-verified

# Lookup table from activation name to the callable implementing it.
# LinearActivation indexes this with a string ``act`` whenever the fused
# bias+activation path is not selected.
ACT2FN = {
    "gelu": gelu,
    "relu": torch.nn.functional.relu,
    "swish": swish,
}
150
class LinearActivation(Module):
    r"""Fused Linear and activation Module.

    Computes ``F.linear(input, weight, bias)`` followed by an activation.
    When a bias is present and ``act`` is the string ``'gelu'`` or
    ``'tanh'``, the bias is withheld from ``F.linear`` and passed, together
    with the bias-free linear output, to the module-level ``bias_gelu`` /
    ``bias_tanh`` helper instead.

    Args:
        in_features: second dimension of the weight tensor.
        out_features: first dimension of the weight tensor (and the length
            of the bias, when present).
        act: activation name (``'gelu'``, ``'tanh'`` or a key of
            ``ACT2FN``) or a callable applied to the linear output.
        bias: whether to create a learnable bias parameter.

    Raises:
        KeyError: if ``act`` is a string that is neither ``'tanh'``, a
            fused ``'gelu'``, nor a key of ``ACT2FN``.
    """
    __constants__ = ['bias']

    def __init__(self, in_features, out_features, act='gelu', bias=True):
        super(LinearActivation, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Flags selecting the fused bias+activation path in forward().
        self.fused_gelu = False
        self.fused_tanh = False
        # The ``unicode`` check is only evaluated on Python 2, where the
        # name exists; on Python 3 the ``and`` short-circuits before it.
        if isinstance(act, str) or (sys.version_info[0] == 2 and isinstance(act, unicode)):  # noqa: F821
            if bias and act == 'gelu':
                self.fused_gelu = True
            elif bias and act == 'tanh':
                self.fused_tanh = True
            elif act == 'tanh':
                # 'tanh' is not a key of ACT2FN, so without a bias the
                # table lookup below would raise KeyError even though
                # 'tanh' is accepted when bias=True. Use the plain op.
                self.act_fn = torch.tanh
            else:
                self.act_fn = ACT2FN[act]
        else:
            # Anything that is not a string is used as the activation as-is.
            self.act_fn = act
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            # Register the name so ``self.bias`` exists and reads as None.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weight and, when present, the bias in place."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            # Bias is drawn uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)].
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply the linear map and the configured activation to ``input``.

        On the fused paths the bias is not given to ``F.linear``; it is
        forwarded to ``bias_gelu`` / ``bias_tanh`` instead (presumably those
        helpers add the bias before activating -- confirm against their
        definitions elsewhere in this file).
        """
        if self.fused_gelu:
            return bias_gelu(self.bias, F.linear(input, self.weight, None))
        elif self.fused_tanh:
            return bias_tanh(self.bias, F.linear(input, self.weight, None))
        else:
            return self.act_fn(F.linear(input, self.weight, self.bias))

    def extra_repr(self):
        """Return the feature sizes and bias presence for the module repr."""
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None
        )
197
198
199class BertConfig(object):

Callers 3

__init__Method · 0.70
__init__Method · 0.70
__init__Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected