Repository: github.com/huggingface/diffusers

Function value_function()

scripts/convert_models_diffuser_to_diffusers.py, lines 59–94


```python
import json
import os

import torch
from diffusers import UNet1DModel


def value_function():
    # UNet1DModel configuration matching the original Diffuser value-function checkpoint.
    config = {
        "in_channels": 14,
        "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        "up_block_types": (),
        "out_block_type": "ValueFunction",
        "mid_block_type": "ValueFunctionMidBlock1D",
        "block_out_channels": (32, 64, 128, 256),
        "layers_per_block": 1,
        "downsample_each_block": True,
        "sample_size": 65536,
        "out_channels": 14,
        "extra_in_channels": 0,
        "time_embedding_type": "positional",
        "use_timestep_embedding": True,
        "flip_sin_to_cos": False,
        "freq_shift": 1,
        "norm_num_groups": 8,
        "act_fn": "mish",
    }

    # The original checkpoint is loaded as a plain state dict (parameter name -> tensor).
    state_dict = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch")
    hf_value_function = UNet1DModel(**config)
    hf_keys = list(hf_value_function.state_dict().keys())
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_keys)}")

    # Keys are matched purely by position, so both models must register their
    # parameters in the same order. zip() silently drops extra entries, so
    # fail early if the counts differ.
    if len(state_dict) != len(hf_keys):
        raise ValueError("source and target state dicts have different numbers of entries")

    # Build a new dict instead of renaming in place with pop(), so a target name
    # that equals a not-yet-renamed source name cannot overwrite that entry.
    remapped = {new_key: state_dict[old_key] for old_key, new_key in zip(state_dict.keys(), hf_keys)}

    hf_value_function.load_state_dict(remapped)

    # Write the weights and config in the layout diffusers expects for a model folder.
    out_dir = "hub/hopper-medium-v2/value_function"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(hf_value_function.state_dict(), os.path.join(out_dir, "diffusion_pytorch_model.bin"))
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(config, f)
```

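The core of the conversion is renaming checkpoint keys by position: the n-th key of the original state dict is assumed to correspond to the n-th key of the `UNet1DModel` state dict. The sketch below isolates that step on plain dictionaries so it runs without PyTorch. The helper names `remap_by_position` and `remap_in_place` and the toy keys are hypothetical; `remap_in_place` mirrors the pop-based loop from the original listing to show how it can lose an entry when a target name equals a source name that has not been renamed yet. Whether such a collision can actually occur for this checkpoint is not established here.

```python
from typing import Dict, Iterable, TypeVar

V = TypeVar("V")


def remap_by_position(source: Dict[str, V], target_keys: Iterable[str]) -> Dict[str, V]:
    """Hypothetical helper: rename source entries to target_keys, pairing keys by insertion order."""
    target_keys = list(target_keys)
    if len(source) != len(target_keys):
        raise ValueError(f"key count mismatch: {len(source)} source vs {len(target_keys)} target")
    # Build a fresh dict; the source dict is only read, never mutated.
    return {new: source[old] for old, new in zip(source, target_keys)}


def remap_in_place(source: Dict[str, V], target_keys: Iterable[str]) -> Dict[str, V]:
    """Hypothetical helper reproducing the pop-based loop, shown only for comparison."""
    mapping = dict(zip(list(source), target_keys))
    for old, new in mapping.items():
        # If `new` is also a source key not yet processed, this overwrites its value.
        source[new] = source.pop(old)
    return source


weights = {"a": 1, "b": 2}
print(remap_by_position(weights, ["b", "c"]))     # {'b': 1, 'c': 2}
print(remap_in_place(dict(weights), ["b", "c"]))  # {'c': 1}  (value 2 was lost)
```

The up-front length check matters because `zip` stops at the shorter input, so a count mismatch would otherwise only surface later, when `load_state_dict` (with its default `strict=True`) reports missing or unexpected keys.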
Calls (6):
UNet1DModel (class) · 0.90
save (method) · 0.80
load (method) · 0.45
state_dict (method) · 0.45
pop (method) · 0.45
load_state_dict (method) · 0.45
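The function writes `config` to `config.json` with `json.dump`. JSON has no tuple type, so tuple-valued entries such as `down_block_types`, `up_block_types` and `block_out_channels` are written as arrays and read back as Python lists. The standard-library check below uses a subset of the config to make this visible; the `normalized` step is only an illustration, not something the original script does.

```python
import json

# Subset of the config from value_function(), enough to show tuple handling.
config = {
    "up_block_types": (),
    "block_out_channels": (32, 64, 128, 256),
    "downsample_each_block": True,
    "act_fn": "mish",
}

text = json.dumps(config)   # tuples are serialized as JSON arrays
restored = json.loads(text) # arrays come back as Python lists

print(restored["block_out_channels"])  # [32, 64, 128, 256]
print(restored == config)              # False: () != [] and (32, ...) != [32, ...]

# Converting lists back to tuples recovers an equal dict.
normalized = {k: tuple(v) if isinstance(v, list) else v for k, v in restored.items()}
print(normalized == config)            # True
```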
