(orig_path, unet_diffusers_config)
| 1117 | |
| 1118 | |
def verify_param_count(orig_path, unet_diffusers_config):
    """Check that a converted diffusers UNet matches the original DeepFloyd IF UNet block by block.

    The original stage II or stage III model is loaded on CPU from ``orig_path`` and a fresh
    ``UNet2DConditionModel`` is built from ``unet_diffusers_config``. Parameter counts are then
    compared with ``assert_param_count`` for: the time embedding, the input conv, every down block,
    the mid block, every up block, the output norm and conv, and finally the whole model.

    Args:
        orig_path: Directory or name of the original IF checkpoint. It must contain ``"-II-"``
            (stage II) or ``"-III-"`` (stage III); this selects both the loader class and the
            block layout used for the comparison. If both markers are present, stage II wins.
        unet_diffusers_config: Keyword arguments passed to ``UNet2DConditionModel``.

    Raises:
        ValueError: If ``orig_path`` names neither stage II nor stage III.
    """
    # Decide the stage exactly once so the loaded model and the block layout always agree.
    # Each list holds the slices of the original ``input_blocks`` / ``output_blocks`` that make up
    # the diffusers down/up block at the same position.
    if "-II-" in orig_path:
        down_slices = [slice(1, 4), slice(4, 7), slice(7, 11), slice(11, 17), slice(17, None)]
        up_slices = [slice(None, 6), slice(6, 12), slice(12, 16), slice(16, 19), slice(19, None)]

        from deepfloyd_if.modules import IFStageII

        if_model = IFStageII(device="cpu", dir_or_name=orig_path)
    elif "-III-" in orig_path:
        down_slices = [slice(1, 4), slice(4, 7), slice(7, 11), slice(11, 15), slice(15, 20), slice(20, None)]
        up_slices = [slice(None, 5), slice(5, 10), slice(10, 14), slice(14, 18), slice(18, 21), slice(21, 24)]

        from deepfloyd_if.modules import IFStageIII

        if_model = IFStageIII(device="cpu", dir_or_name=orig_path)
    else:
        # ``assert "<non-empty string>"`` can never fail, so the original check was a no-op and the
        # function later crashed on an unbound name. Fail loudly, before loading anything.
        raise ValueError(f"Weird name. Should have -II- or -III- in path: {orig_path}")

    unet = UNet2DConditionModel(**unet_diffusers_config)
    orig_unet = if_model.model

    # in params (input_blocks[0] corresponds to conv_in)
    assert_param_count(unet.time_embedding, orig_unet.time_embed)
    assert_param_count(unet.conv_in, orig_unet.input_blocks[:1])

    # down blocks
    for block_idx, block_slice in enumerate(down_slices):
        assert_param_count(unet.down_blocks[block_idx], orig_unet.input_blocks[block_slice])

    # mid block
    assert_param_count(unet.mid_block, orig_unet.middle_block)

    # up blocks
    for block_idx, block_slice in enumerate(up_slices):
        assert_param_count(unet.up_blocks[block_idx], orig_unet.output_blocks[block_slice])

    # out params: out[0] is compared with conv_norm_out and out[2] with conv_out
    assert_param_count(unet.conv_norm_out, orig_unet.out[0])
    assert_param_count(unet.conv_out, orig_unet.out[2])

    # make sure the whole model architecture has the same param count
    assert_param_count(unet, orig_unet)
| 1174 | |
| 1175 | |
| 1176 | def assert_param_count(model_1, model_2): |
no test coverage detected
searching dependent graphs…