Initialize a torch.Tensor. Args: tensor: Tensor to be initialized. init_type: Init type; candidates can be found in InitType. low: The lower bound of the uniform distribution, used when init_type is uniform. high: The upper bound of the uniform distribution,
(tensor, init_type=InitType.XAVIER_UNIFORM, low=0, high=1,
mean=0, std=1, activation_type=ActivationType.NONE,
fan_mode=FAN_MODE.FAN_IN, negative_slope=0)
| 61 | |
| 62 | |
def init_tensor(tensor, init_type=InitType.XAVIER_UNIFORM, low=0, high=1,
                mean=0, std=1, activation_type=ActivationType.NONE,
                fan_mode=FAN_MODE.FAN_IN, negative_slope=0):
    """Initialize ``tensor`` with the scheme selected by ``init_type``.

    Args:
        tensor: Tensor to be initialized.
        init_type: Initialization scheme; candidates are listed in InitType.
        low: Lower bound of the uniform distribution; only used when
            init_type is InitType.UNIFORM.
        high: Upper bound of the uniform distribution; only used when
            init_type is InitType.UNIFORM.
        mean: Mean of the normal distribution; only used when init_type is
            InitType.NORMAL.
        std: Standard deviation of the normal distribution; only used when
            init_type is InitType.NORMAL.
        activation_type: For the Xavier and orthogonal schemes, passed to
            torch.nn.init.calculate_gain to obtain the gain. For the Kaiming
            schemes, passed directly as ``nonlinearity``.
        fan_mode: Passed as ``mode`` to the Kaiming schemes.
        negative_slope: Passed as ``a`` to the Kaiming schemes.

    Returns:
        The value returned by the selected torch.nn.init function. PyTorch
        documents these functions as filling the input tensor and returning
        it.

    Raises:
        TypeError: If init_type matches none of the supported InitType
            values.
    """
    init = torch.nn.init

    def _gain():
        # Computed on demand so only the schemes that take a gain call
        # calculate_gain.
        return init.calculate_gain(activation_type)

    # Ordered (candidate, initializer) pairs, checked with ``==`` in the
    # same order as the supported types are listed; each initializer is
    # a thunk so nothing runs until a match is found.
    initializers = (
        (InitType.UNIFORM,
         lambda: init.uniform_(tensor, a=low, b=high)),
        (InitType.NORMAL,
         lambda: init.normal_(tensor, mean=mean, std=std)),
        (InitType.XAVIER_UNIFORM,
         lambda: init.xavier_uniform_(tensor, gain=_gain())),
        (InitType.XAVIER_NORMAL,
         lambda: init.xavier_normal_(tensor, gain=_gain())),
        (InitType.KAIMING_UNIFORM,
         lambda: init.kaiming_uniform_(
             tensor, a=negative_slope, mode=fan_mode,
             nonlinearity=activation_type)),
        (InitType.KAIMING_NORMAL,
         lambda: init.kaiming_normal_(
             tensor, a=negative_slope, mode=fan_mode,
             nonlinearity=activation_type)),
        (InitType.ORTHOGONAL,
         lambda: init.orthogonal_(tensor, gain=_gain())),
    )

    for candidate, initializer in initializers:
        if init_type == candidate:
            return initializer()

    raise TypeError(
        "Unsupported tensor init type: %s. Supported init type is: %s" % (
            init_type, InitType.str()))
| 110 | |
| 111 | |
| 112 | class OptimizerType(Type): |