| 240 | |
| 241 | |
| 242 | class WeightInitializer(object): |
| 243 | def __init__(self, act_fn_str, mode="glorot_uniform"): |
| 244 | """ |
| 245 | A factory for weight initializers. |
| 246 | |
| 247 | Parameters |
| 248 | ---------- |
| 249 | act_fn_str : str |
| 250 | The string representation for the layer activation function |
| 251 | mode : str (default: 'glorot_uniform') |
| 252 | The weight initialization strategy. Valid entries are {"he_normal", |
| 253 | "he_uniform", "glorot_normal", glorot_uniform", "std_normal", |
| 254 | "trunc_normal"} |
| 255 | """ |
| 256 | if mode not in [ |
| 257 | "he_normal", |
| 258 | "he_uniform", |
| 259 | "glorot_normal", |
| 260 | "glorot_uniform", |
| 261 | "std_normal", |
| 262 | "trunc_normal", |
| 263 | ]: |
| 264 | raise ValueError("Unrecognize initialization mode: {}".format(mode)) |
| 265 | |
| 266 | self.mode = mode |
| 267 | self.act_fn = act_fn_str |
| 268 | |
| 269 | if mode == "glorot_uniform": |
| 270 | self._fn = glorot_uniform |
| 271 | elif mode == "glorot_normal": |
| 272 | self._fn = glorot_normal |
| 273 | elif mode == "he_uniform": |
| 274 | self._fn = he_uniform |
| 275 | elif mode == "he_normal": |
| 276 | self._fn = he_normal |
| 277 | elif mode == "std_normal": |
| 278 | self._fn = np.random.randn |
| 279 | elif mode == "trunc_normal": |
| 280 | self._fn = partial(truncated_normal, mean=0, std=1) |
| 281 | |
| 282 | def __call__(self, weight_shape): |
| 283 | """Initialize weights according to the specified strategy""" |
| 284 | if "glorot" in self.mode: |
| 285 | gain = self._calc_glorot_gain() |
| 286 | W = self._fn(weight_shape, gain) |
| 287 | elif self.mode == "std_normal": |
| 288 | W = self._fn(*weight_shape) |
| 289 | else: |
| 290 | W = self._fn(weight_shape) |
| 291 | return W |
| 292 | |
| 293 | def _calc_glorot_gain(self): |
| 294 | """ |
| 295 | Values from: |
| 296 | https://pytorch.org/docs/stable/nn.html?#torch.nn.init.calculate_gain |
| 297 | """ |
| 298 | gain = 1.0 |
| 299 | act_str = self.act_fn.lower() |
no outgoing calls
no test coverage detected