MCP · copy
hub / github.com/NVlabs/imaginaire / apply

Method apply

imaginaire/layers/weight_norm.py:93–147  ·  view source on GitHub ↗
(module, weight_name, bias_name, lr_mul, equalized)

Source from the content-addressed store, hash-verified

91
@staticmethod
def apply(module, weight_name, bias_name, lr_mul, equalized):
    """Register a ScaledLR forward pre-hook on ``module`` and re-wire its
    weight and bias parameters so they can be rescaled at forward time.

    The trainable tensors are re-registered as ``weight_ori`` / ``bias_ori``
    parameters, and the attribute the module reads during forward
    (``weight`` and ``bias``, or ``weight_orig`` when spectral normalization
    is already applied) is replaced by a plain tensor (``.data``) that is no
    longer a parameter. Two buffers are registered:

    * ``bias_scale``   = ``lr_mul``
    * ``weight_scale`` = ``lr_mul``, times ``1 / sqrt(fan_in)`` if
      ``equalized`` is true.

    The module also gets the attributes ``use_sn``, ``lr_mul`` and
    ``base_lr_mul``.

    Args:
        module (torch.nn.Module): Module to modify in place.
        weight_name (str): Name of the weight attribute; must be 'weight'.
        bias_name (str): Name of the bias attribute; must be 'bias'.
        lr_mul (float): Learning-rate multiplier stored in the scale buffers.
        equalized (bool): If True, fold the equalized-learning-rate factor
            ``1 / sqrt(fan_in)`` into ``weight_scale``.

    Returns:
        ScaledLR: The hook object that was registered on ``module``.

    Raises:
        AssertionError: If ``weight_name`` is not 'weight' or ``bias_name``
            is not 'bias'.
    """
    assert weight_name == 'weight', \
        "ScaledLR only supports weight_name='weight', got %r" % (weight_name,)
    assert bias_name == 'bias', \
        "ScaledLR only supports bias_name='bias', got %r" % (bias_name,)
    fn = ScaledLR(weight_name, bias_name)
    # Keep the handle: its id is the key of this hook in _forward_pre_hooks.
    handle = module.register_forward_pre_hook(fn)

    if hasattr(module, bias_name):
        # module.bias is a registered parameter (it may be None).
        bias = getattr(module, bias_name)
        delattr(module, bias_name)
        module.register_parameter(bias_name + '_ori', bias)
    else:
        # The module has no bias attribute at all.
        bias = None
        setattr(module, bias_name + '_ori', bias)
    # Expose a plain tensor (not a Parameter) under the original name.
    # Presumably the ScaledLR hook recomputes it from bias_ori each forward;
    # confirm against ScaledLR.__call__.
    setattr(module, bias_name, None if bias is None else bias.data)
    module.register_buffer('bias_scale', torch.tensor(lr_mul))

    if hasattr(module, weight_name + '_orig'):
        # The module has been wrapped with spectral normalization.
        # We only want to keep a single weight parameter.
        weight = getattr(module, weight_name + '_orig')
        delattr(module, weight_name + '_orig')
        module.register_parameter(weight_name + '_ori', weight)
        setattr(module, weight_name + '_orig', weight.data)
        # Put this hook before the spectral norm hook so the scaled weight
        # is in place when spectral norm reads weight_orig. Move only this
        # hook to the front, mutating the existing OrderedDict in place:
        # replacing the dict would detach previously issued hook handles
        # from it and would also reverse the order of any other pre-hooks.
        module._forward_pre_hooks.move_to_end(handle.id, last=False)
        module.use_sn = True
    else:
        weight = getattr(module, weight_name)
        delattr(module, weight_name)
        module.register_parameter(weight_name + '_ori', weight)
        setattr(module, weight_name, weight.data)
        module.use_sn = False

    # The fan_in computation below assumes weight.dim() >= 2 (e.g. a 4-D conv
    # or 2-D linear weight); size(1) raises for a 1-D weight.
    if equalized:
        # size(1) * numel of one kernel = product of all dims except dim 0.
        fan_in = weight.data.size(1) * weight.data[0][0].numel()
        # Theoretically, the gain should be sqrt(2) instead of 1.
        # The official StyleGAN2 uses 1 for some reason.
        module.register_buffer(
            'weight_scale', torch.tensor(lr_mul * ((1 / fan_in) ** 0.5))
        )
    else:
        module.register_buffer('weight_scale', torch.tensor(lr_mul))

    # NOTE(review): lr_mul is a plain attribute aliasing the weight_scale
    # buffer tensor, not a buffer itself. Module.to()/cuda() replace buffer
    # tensors, so lr_mul may keep referring to the old tensor afterwards.
    # Confirm no caller relies on its device or dtype.
    module.lr_mul = module.weight_scale
    module.base_lr_mul = lr_mul

    return fn
148
149 def remove(self, module):
150 with torch.no_grad():

Callers 15

__init__Method · 0.80
positional_encodingFunction · 0.80
scaled_lrFunction · 0.80
init_temporal_networkMethod · 0.80
init_temporal_networkMethod · 0.80
custom_initMethod · 0.80
convert_weightsFunction · 0.80

Calls 1

ScaledLRClass · 0.85

Tested by

no test coverage detected