MCPcopy
hub / github.com/ddbourgin/numpy-ml / WeightInitializer

Class WeightInitializer

numpy_ml/neural_nets/initializers/initializers.py:242–308  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

240
241
242class WeightInitializer(object):
243 def __init__(self, act_fn_str, mode="glorot_uniform"):
244 """
245 A factory for weight initializers.
246
247 Parameters
248 ----------
249 act_fn_str : str
250 The string representation for the layer activation function
251 mode : str (default: 'glorot_uniform')
252 The weight initialization strategy. Valid entries are {"he_normal",
253 "he_uniform", "glorot_normal", glorot_uniform", "std_normal",
254 "trunc_normal"}
255 """
256 if mode not in [
257 "he_normal",
258 "he_uniform",
259 "glorot_normal",
260 "glorot_uniform",
261 "std_normal",
262 "trunc_normal",
263 ]:
264 raise ValueError("Unrecognize initialization mode: {}".format(mode))
265
266 self.mode = mode
267 self.act_fn = act_fn_str
268
269 if mode == "glorot_uniform":
270 self._fn = glorot_uniform
271 elif mode == "glorot_normal":
272 self._fn = glorot_normal
273 elif mode == "he_uniform":
274 self._fn = he_uniform
275 elif mode == "he_normal":
276 self._fn = he_normal
277 elif mode == "std_normal":
278 self._fn = np.random.randn
279 elif mode == "trunc_normal":
280 self._fn = partial(truncated_normal, mean=0, std=1)
281
282 def __call__(self, weight_shape):
283 """Initialize weights according to the specified strategy"""
284 if "glorot" in self.mode:
285 gain = self._calc_glorot_gain()
286 W = self._fn(weight_shape, gain)
287 elif self.mode == "std_normal":
288 W = self._fn(*weight_shape)
289 else:
290 W = self._fn(weight_shape)
291 return W
292
293 def _calc_glorot_gain(self):
294 """
295 Values from:
296 https://pytorch.org/docs/stable/nn.html?#torch.nn.init.calculate_gain
297 """
298 gain = 1.0
299 act_str = self.act_fn.lower()

Callers 10

_init_params (method) · 0.85
_init_params (method) · 0.85
_init_params (method) · 0.85
_init_params (method) · 0.85
_init_params (method) · 0.85
_init_params (method) · 0.85
_init_params (method) · 0.85
_init_params (method) · 0.85
_init_params (method) · 0.85
_init_params (method) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected