(weight)
| 185 | |
# Helper for the final_layer.adaLN_modulation.1 -> norm_out.linear mapping.
def swap_scale_shift(weight):
    """Reorder the two stacked halves of ``weight`` along dim 0.

    The input is treated as ``[shift; scale]`` stacked along the first
    dimension, and a new tensor laid out as ``[scale; shift]`` is returned.
    Presumably this is because the target ``norm_out.linear`` layer expects
    scale first; confirm against the model definition.
    Only dim 0 is split, so trailing dimensions are preserved. The function is
    applied to both the ``.weight`` and ``.bias`` entries of the projection.

    Args:
        weight: Tensor whose first dimension holds the two equally sized halves.

    Returns:
        A new tensor with the same shape and dtype, with the halves swapped.

    Raises:
        ValueError: If ``weight`` is 0-dimensional or its first dimension is
            odd. In that case the shift/scale split is ambiguous, and
            ``chunk(2)`` would silently return unequal halves.
    """
    if weight.ndim == 0 or weight.shape[0] % 2 != 0:
        raise ValueError(
            f"swap_scale_shift expects an even-sized first dimension, got shape {tuple(weight.shape)}"
        )
    half = weight.shape[0] // 2
    shift, scale = weight[:half], weight[half:]
    return torch.cat([scale, shift], dim=0)
| 191 | |
# final_layer.adaLN_modulation.1 -> norm_out.linear: copy the projection's
# weight and bias across, reordering their halves via swap_scale_shift.
# NOTE(review): the source "final_layer.adaLN_modulation.1.*" entries are read,
# not popped, so they stay in state_dict; confirm they are dropped or ignored
# downstream.
state_dict.update(
    {
        "norm_out.linear.weight": swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"]),
        "norm_out.linear.bias": swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"]),
    }
)
no outgoing calls
no test coverage detected
searching dependent graphs…