(old_vae)
| 344 | |
| 345 | |
| 346 | def make_vqvae(old_vae): |
| 347 | new_vae = VQModel( |
| 348 | act_fn="silu", |
| 349 | block_out_channels=[128, 256, 256, 512, 768], |
| 350 | down_block_types=[ |
| 351 | "DownEncoderBlock2D", |
| 352 | "DownEncoderBlock2D", |
| 353 | "DownEncoderBlock2D", |
| 354 | "DownEncoderBlock2D", |
| 355 | "DownEncoderBlock2D", |
| 356 | ], |
| 357 | in_channels=3, |
| 358 | latent_channels=64, |
| 359 | layers_per_block=2, |
| 360 | norm_num_groups=32, |
| 361 | num_vq_embeddings=8192, |
| 362 | out_channels=3, |
| 363 | sample_size=32, |
| 364 | up_block_types=[ |
| 365 | "UpDecoderBlock2D", |
| 366 | "UpDecoderBlock2D", |
| 367 | "UpDecoderBlock2D", |
| 368 | "UpDecoderBlock2D", |
| 369 | "UpDecoderBlock2D", |
| 370 | ], |
| 371 | mid_block_add_attention=False, |
| 372 | lookup_from_codebook=True, |
| 373 | ) |
| 374 | new_vae.to(device) |
| 375 | |
| 376 | # fmt: off |
| 377 | |
| 378 | new_state_dict = {} |
| 379 | |
| 380 | old_state_dict = old_vae.state_dict() |
| 381 | |
| 382 | new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight") |
| 383 | new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias") |
| 384 | |
| 385 | convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0") |
| 386 | convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1") |
| 387 | convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2") |
| 388 | convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3") |
| 389 | convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4") |
| 390 | |
| 391 | new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight") |
| 392 | new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias") |
| 393 | new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight") |
| 394 | new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias") |
| 395 | new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight") |
| 396 | new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias") |
| 397 | new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight") |
| 398 | new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias") |
| 399 | new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight") |
| 400 | new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias") |
| 401 | new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight") |
| 402 | new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias") |
| 403 | new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight") |
no test coverage detected
searching dependent graphs…