(module, weight_name, bias_name, lr_mul, equalized)
| 91 | |
@staticmethod
def apply(module, weight_name, bias_name, lr_mul, equalized):
    r"""Attach a scaled-learning-rate pre-hook to ``module`` and rewire it.

    The trainable weight/bias tensors are moved to ``<name>_ori``
    parameters. The original attribute names are left holding plain
    (non-parameter) tensors, and the scale factors are stored as the
    buffers ``bias_scale`` and ``weight_scale``. Presumably the
    ``ScaledLR`` hook recombines these on every forward pass (see its
    ``__call__``; not visible here, so confirm).

    Args:
        module (torch.nn.Module): Module whose parameters are rewired.
        weight_name (str): Name of the weight attribute; must be
            ``'weight'``.
        bias_name (str): Name of the bias attribute; must be ``'bias'``.
        lr_mul (float): Learning-rate multiplier. It is stored as the
            ``bias_scale`` buffer and folded into ``weight_scale``.
        equalized (bool): If True, ``weight_scale`` is
            ``lr_mul / sqrt(fan_in)`` (equalized learning rate).
            Otherwise it is just ``lr_mul``.

    Returns:
        ScaledLR: The forward pre-hook registered on ``module``.

    Raises:
        AssertionError: If ``weight_name`` / ``bias_name`` are not
            ``'weight'`` / ``'bias'``.
    """
    assert weight_name == 'weight', "ScaledLR expects weight_name == 'weight'"
    assert bias_name == 'bias', "ScaledLR expects bias_name == 'bias'"
    fn = ScaledLR(weight_name, bias_name)
    # Keep the handle so the hook can be repositioned by id below.
    handle = module.register_forward_pre_hook(fn)

    # --- Bias -------------------------------------------------------------
    if hasattr(module, bias_name):
        # module.bias is a registered parameter (which may be None).
        bias = getattr(module, bias_name)
        delattr(module, bias_name)
        module.register_parameter(bias_name + '_ori', bias)
    else:
        # module.bias does not exist at all.
        bias = None
        setattr(module, bias_name + '_ori', None)
    # The public name holds a detached view of the data (or None), not a
    # Parameter, so the optimizer only sees the '_ori' parameter.
    setattr(module, bias_name, None if bias is None else bias.data)
    module.register_buffer('bias_scale', torch.tensor(lr_mul))

    # --- Weight -----------------------------------------------------------
    if hasattr(module, weight_name + '_orig'):
        # The module has been wrapped with spectral normalization.
        # We only want to keep a single weight parameter.
        weight = getattr(module, weight_name + '_orig')
        delattr(module, weight_name + '_orig')
        module.register_parameter(weight_name + '_ori', weight)
        setattr(module, weight_name + '_orig', weight.data)
        # Put this hook before the spectral norm hook. Move only our hook
        # to the front, in place. Rebuilding the dict would reorder any
        # other pre-hooks and orphan existing RemovableHandles, which hold
        # a weak reference to the original dict object.
        # Assumes _forward_pre_hooks is an OrderedDict, as in nn.Module.
        module._forward_pre_hooks.move_to_end(handle.id, last=False)
        module.use_sn = True
    else:
        weight = getattr(module, weight_name)
        delattr(module, weight_name)
        module.register_parameter(weight_name + '_ori', weight)
        setattr(module, weight_name, weight.data)
        module.use_sn = False

    # --- Weight scale -------------------------------------------------------
    # assert weight.dim() == 4 or weight.dim() == 2
    if equalized:
        # Assumes a (out, in, *kernel) weight layout (PyTorch Linear/Conv
        # convention), so fan_in = in_channels * receptive-field size.
        fan_in = weight.data.size(1) * weight.data[0][0].numel()
        # Theoretically, the gain should be sqrt(2) instead of 1.
        # The official StyleGAN2 uses 1 for some reason.
        weight_scale = lr_mul * ((1 / fan_in) ** 0.5)
    else:
        weight_scale = lr_mul
    module.register_buffer('weight_scale', torch.tensor(weight_scale))

    # lr_mul aliases the weight_scale buffer. base_lr_mul keeps the raw
    # multiplier that was passed in.
    module.lr_mul = module.weight_scale
    module.base_lr_mul = lr_mul

    return fn
| 148 | |
| 149 | def remove(self, module): |
| 150 | with torch.no_grad(): |
no test coverage detected