(args)
| 104 | |
@torch.no_grad()
def populate_state_dict(args):
    """Build a diffusers ``AuraFlowTransformer2DModel`` from an original checkpoint.

    Despite its name, this returns the constructed model rather than a state
    dict.

    Args:
        args: Passed straight through to ``load_original_state_dict`` to locate
            and read the original weights.

    Returns:
        An ``AuraFlowTransformer2DModel`` whose layer counts are inferred from
        the original checkpoint's key names and whose weights are the converted
        original weights, loaded with ``strict=True``.
    """
    source_weights = load_original_state_dict(args)

    # Take a snapshot of the key names *before* conversion runs. Presumably
    # convert_transformer consumes entries of the original dict; TODO confirm.
    key_names = list(source_weights.keys())

    # Infer block counts from key prefixes. The lookups run in the fixed order
    # double first, then single.
    layer_counts = {
        prefix: calculate_layers(key_names, key_prefix=prefix)
        for prefix in ("double_layers", "single_layers")
    }

    remapped_weights = convert_transformer(source_weights)

    transformer = AuraFlowTransformer2DModel(
        num_mmdit_layers=layer_counts["double_layers"],
        num_single_dit_layers=layer_counts["single_layers"],
    )
    # strict=True: any key mismatch between the converted weights and the
    # model's parameters is an error rather than being silently skipped.
    transformer.load_state_dict(remapped_weights, strict=True)

    return transformer
| 121 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…