MCPcopy Index your code
hub / github.com/huggingface/diffusers / verify_param_count

Function verify_param_count

scripts/convert_if.py:1119–1173  ·  view source on GitHub ↗
(orig_path, unet_diffusers_config)

Source from the content-addressed store, hash-verified

1117
1118
def verify_param_count(orig_path, unet_diffusers_config):
    """Compare parameter counts of a converted diffusers UNet against the original IF model.

    Loads the original DeepFloyd IF super-resolution model (Stage II or Stage III,
    chosen from the marker in ``orig_path``) on CPU and builds a
    ``UNet2DConditionModel`` from ``unet_diffusers_config``. It then calls
    ``assert_param_count`` (defined elsewhere in this file) on each pair of
    corresponding sub-modules, and finally on the whole models.

    Args:
        orig_path: Directory or name of the original checkpoint. It must contain
            "-II-" (Stage II) or "-III-" (Stage III).
        unet_diffusers_config: Keyword arguments for ``UNet2DConditionModel``.

    Raises:
        ValueError: If ``orig_path`` contains neither "-II-" nor "-III-".
        Presumably ``AssertionError`` from ``assert_param_count`` on a mismatch.
        TODO confirm against that helper's body.
    """
    # Decide the stage once, so the per-block checks below can never mix
    # Stage II and Stage III layouts.
    if "-II-" in orig_path:
        from deepfloyd_if.modules import IFStageII

        stage = "II"
        orig_model = IFStageII(device="cpu", dir_or_name=orig_path).model
    elif "-III-" in orig_path:
        from deepfloyd_if.modules import IFStageIII

        stage = "III"
        orig_model = IFStageIII(device="cpu", dir_or_name=orig_path).model
    else:
        # The original `assert f"..."` could never fail (a non-empty string is
        # truthy); raise explicitly instead of crashing later on an unbound name.
        raise ValueError(f"Weird name. Should have -II- or -III- in path: {orig_path}")

    # Slices of the original model's input_blocks / output_blocks that map onto
    # each diffusers down_blocks[i] / up_blocks[i], per stage.
    down_slices = {
        "II": (slice(1, 4), slice(4, 7), slice(7, 11), slice(11, 17), slice(17, None)),
        "III": (
            slice(1, 4),
            slice(4, 7),
            slice(7, 11),
            slice(11, 15),
            slice(15, 20),
            slice(20, None),
        ),
    }[stage]
    up_slices = {
        "II": (slice(None, 6), slice(6, 12), slice(12, 16), slice(16, 19), slice(19, None)),
        "III": (
            slice(None, 5),
            slice(5, 10),
            slice(10, 14),
            slice(14, 18),
            slice(18, 21),
            slice(21, 24),
        ),
    }[stage]

    unet = UNet2DConditionModel(**unet_diffusers_config)

    # in params
    assert_param_count(unet.time_embedding, orig_model.time_embed)
    assert_param_count(unet.conv_in, orig_model.input_blocks[:1])

    # downblocks (index rather than zip so a UNet with too few blocks still fails)
    for i, block_slice in enumerate(down_slices):
        assert_param_count(unet.down_blocks[i], orig_model.input_blocks[block_slice])

    # mid block
    assert_param_count(unet.mid_block, orig_model.middle_block)

    # up blocks
    for i, block_slice in enumerate(up_slices):
        assert_param_count(unet.up_blocks[i], orig_model.output_blocks[block_slice])

    # out params
    assert_param_count(unet.conv_norm_out, orig_model.out[0])
    assert_param_count(unet.conv_out, orig_model.out[2])

    # make sure all model architecture has same param count
    assert_param_count(unet, orig_model)
1174
1175
1176def assert_param_count(model_1, model_2):

Callers 1

get_super_res_unet · Function · 0.85

Calls 2

assert_param_count · Function · 0.85

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…