MCPcopy Index your code
hub / github.com/huggingface/diffusers / make_vqvae

Function make_vqvae

scripts/convert_amused.py:346–473  ·  view source on GitHub ↗
(old_vae)

Source from the content-addressed store, hash-verified

344
345
346def make_vqvae(old_vae):
347 new_vae = VQModel(
348 act_fn="silu",
349 block_out_channels=[128, 256, 256, 512, 768],
350 down_block_types=[
351 "DownEncoderBlock2D",
352 "DownEncoderBlock2D",
353 "DownEncoderBlock2D",
354 "DownEncoderBlock2D",
355 "DownEncoderBlock2D",
356 ],
357 in_channels=3,
358 latent_channels=64,
359 layers_per_block=2,
360 norm_num_groups=32,
361 num_vq_embeddings=8192,
362 out_channels=3,
363 sample_size=32,
364 up_block_types=[
365 "UpDecoderBlock2D",
366 "UpDecoderBlock2D",
367 "UpDecoderBlock2D",
368 "UpDecoderBlock2D",
369 "UpDecoderBlock2D",
370 ],
371 mid_block_add_attention=False,
372 lookup_from_codebook=True,
373 )
374 new_vae.to(device)
375
376 # fmt: off
377
378 new_state_dict = {}
379
380 old_state_dict = old_vae.state_dict()
381
382 new_state_dict["encoder.conv_in.weight"] = old_state_dict.pop("encoder.conv_in.weight")
383 new_state_dict["encoder.conv_in.bias"] = old_state_dict.pop("encoder.conv_in.bias")
384
385 convert_vae_block_state_dict(old_state_dict, "encoder.down.0", new_state_dict, "encoder.down_blocks.0")
386 convert_vae_block_state_dict(old_state_dict, "encoder.down.1", new_state_dict, "encoder.down_blocks.1")
387 convert_vae_block_state_dict(old_state_dict, "encoder.down.2", new_state_dict, "encoder.down_blocks.2")
388 convert_vae_block_state_dict(old_state_dict, "encoder.down.3", new_state_dict, "encoder.down_blocks.3")
389 convert_vae_block_state_dict(old_state_dict, "encoder.down.4", new_state_dict, "encoder.down_blocks.4")
390
391 new_state_dict["encoder.mid_block.resnets.0.norm1.weight"] = old_state_dict.pop("encoder.mid.block_1.norm1.weight")
392 new_state_dict["encoder.mid_block.resnets.0.norm1.bias"] = old_state_dict.pop("encoder.mid.block_1.norm1.bias")
393 new_state_dict["encoder.mid_block.resnets.0.conv1.weight"] = old_state_dict.pop("encoder.mid.block_1.conv1.weight")
394 new_state_dict["encoder.mid_block.resnets.0.conv1.bias"] = old_state_dict.pop("encoder.mid.block_1.conv1.bias")
395 new_state_dict["encoder.mid_block.resnets.0.norm2.weight"] = old_state_dict.pop("encoder.mid.block_1.norm2.weight")
396 new_state_dict["encoder.mid_block.resnets.0.norm2.bias"] = old_state_dict.pop("encoder.mid.block_1.norm2.bias")
397 new_state_dict["encoder.mid_block.resnets.0.conv2.weight"] = old_state_dict.pop("encoder.mid.block_1.conv2.weight")
398 new_state_dict["encoder.mid_block.resnets.0.conv2.bias"] = old_state_dict.pop("encoder.mid.block_1.conv2.bias")
399 new_state_dict["encoder.mid_block.resnets.1.norm1.weight"] = old_state_dict.pop("encoder.mid.block_2.norm1.weight")
400 new_state_dict["encoder.mid_block.resnets.1.norm1.bias"] = old_state_dict.pop("encoder.mid.block_2.norm1.bias")
401 new_state_dict["encoder.mid_block.resnets.1.conv1.weight"] = old_state_dict.pop("encoder.mid.block_2.conv1.weight")
402 new_state_dict["encoder.mid_block.resnets.1.conv1.bias"] = old_state_dict.pop("encoder.mid.block_2.conv1.bias")
403 new_state_dict["encoder.mid_block.resnets.1.norm2.weight"] = old_state_dict.pop("encoder.mid.block_2.norm2.weight")

Callers 1

main · Function · 0.85

Calls 7

VQModel · Class · 0.90
float · Method · 0.80
to · Method · 0.45
state_dict · Method · 0.45
pop · Method · 0.45
load_state_dict · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…