MCPcopy
hub / github.com/huggingface/diffusers / unet

Function unet

scripts/convert_models_diffuser_to_diffusers.py:15–56  ·  view source on GitHub ↗
(hor)

Source from the content-addressed store, hash-verified

13
14
def unet(hor):
    """Convert a hopper-medium-v2 temporal U-Net checkpoint to a ``UNet1DModel``.

    Loads the pickled source model for the requested horizon, builds a
    ``UNet1DModel`` with the matching block layout, copies the weights over by
    pairing the two state dicts positionally (the i-th source key is renamed to
    the i-th target key), and writes the converted weights and the config to
    ``hub/hopper-medium-v2/unet/hor{hor}/``.

    Args:
        hor: Horizon of the checkpoint to convert. Only ``32`` and ``128``
            have a block layout defined here.

    Raises:
        ValueError: If ``hor`` is not a supported horizon, or if the source and
            target state dicts do not contain the same number of entries.
    """
    # Local import keeps this block self-contained regardless of the file header.
    import os

    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
    else:
        # Fail fast with a clear message instead of hitting a missing file or
        # unbound locals further down.
        raise ValueError(f"unsupported horizon {hor!r}: expected 32 or 128")

    # NOTE(security): this unpickles a whole module object (``.state_dict()`` is
    # called on the result), which executes arbitrary code from the file. Only
    # run on checkpoints you trust. Recent PyTorch releases may additionally
    # require ``weights_only=False`` for this kind of file - TODO confirm
    # against the installed version.
    model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch")
    state_dict = model.state_dict()
    config = {
        "down_block_types": down_block_types,
        "block_out_channels": block_out_channels,
        "up_block_types": up_block_types,
        "layers_per_block": 1,
        "use_timestep_embedding": True,
        "out_block_type": "OutConv1DBlock",
        "norm_num_groups": 8,
        "downsample_each_block": False,
        "in_channels": 14,
        "out_channels": 14,
        "extra_in_channels": 0,
        "time_embedding_type": "positional",
        "flip_sin_to_cos": False,
        "freq_shift": 1,
        "sample_size": 65536,
        "mid_block_type": "MidResTemporalBlock1D",
        "act_fn": "mish",
    }
    hf_value_function = UNet1DModel(**config)
    target_keys = list(hf_value_function.state_dict().keys())
    print(f"length of state dict: {len(state_dict)}")
    print(f"length of value function dict: {len(target_keys)}")

    # The renaming below is purely positional, so a count mismatch would
    # silently drop or misalign weights; refuse to continue in that case.
    if len(state_dict) != len(target_keys):
        raise ValueError(
            f"cannot map weights positionally: source has {len(state_dict)} entries, "
            f"target has {len(target_keys)}"
        )

    # Build a fresh dict rather than popping/re-inserting into ``state_dict``:
    # an in-place rename could overwrite a source entry that has not been moved
    # yet if a target name happened to equal a remaining source name.
    converted = {new_key: state_dict[old_key] for old_key, new_key in zip(state_dict, target_keys)}
    hf_value_function.load_state_dict(converted)

    output_dir = f"hub/hopper-medium-v2/unet/hor{hor}"
    # Make sure the destination exists so the conversion work is not lost to a
    # FileNotFoundError at save time.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(hf_value_function.state_dict(), f"{output_dir}/diffusion_pytorch_model.bin")
    with open(f"{output_dir}/config.json", "w") as f:
        json.dump(config, f)
57
58
59def value_function():

Callers 15

__call__Method · 0.85
__call__Method · 0.85
test_serializationMethod · 0.85
test_serializationMethod · 0.85
mainFunction · 0.85

Calls 6

UNet1DModelClass · 0.90
saveMethod · 0.80
loadMethod · 0.45
state_dictMethod · 0.45
popMethod · 0.45
load_state_dictMethod · 0.45

Used in the wild real call sites across dependent graphs

searching dependent graphs…