()
| 57 | |
| 58 | |
def value_function(
    checkpoint_path="/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch",
    output_dir="hub/hopper-medium-v2/value_function",
):
    """Convert a value-function checkpoint into ``UNet1DModel`` weights.

    Loads the checkpoint at ``checkpoint_path``, renames its entries to the
    parameter names of a freshly built ``UNet1DModel``, verifies that the
    renamed weights load into that model, and writes both the weights
    (``diffusion_pytorch_model.bin``) and the model config (``config.json``)
    into ``output_dir``.

    The renaming is purely positional: the i-th entry of the checkpoint is
    assigned to the i-th parameter name of the model, so both state dicts
    must enumerate their entries in the same order.

    Args:
        checkpoint_path: Path to the source checkpoint. The loaded object is
            used directly as a state dict (a mapping of name -> value).
        output_dir: Directory that receives the converted files; it is
            created if it does not exist yet.

    Raises:
        ValueError: If the checkpoint and the model have a different number
            of state-dict entries, in which case a positional mapping would
            be meaningless.
    """
    # Local import so this function does not depend on the module's import list.
    import os

    config = {
        "in_channels": 14,
        "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"),
        "up_block_types": (),
        "out_block_type": "ValueFunction",
        "mid_block_type": "ValueFunctionMidBlock1D",
        "block_out_channels": (32, 64, 128, 256),
        "layers_per_block": 1,
        "downsample_each_block": True,
        "sample_size": 65536,
        "out_channels": 14,
        "extra_in_channels": 0,
        "time_embedding_type": "positional",
        "use_timestep_embedding": True,
        "flip_sin_to_cos": False,
        "freq_shift": 1,
        "norm_num_groups": 8,
        "act_fn": "mish",
    }

    # NOTE(security): torch.load unpickles the file; only run this on
    # checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path)
    hf_value_function = UNet1DModel(**config)
    hf_keys = list(hf_value_function.state_dict().keys())
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_keys)}")

    # zip() would silently truncate on a mismatch, so fail loudly instead.
    if len(state_dict) != len(hf_keys):
        raise ValueError(
            f"cannot map checkpoint onto model positionally: checkpoint has "
            f"{len(state_dict)} entries, model expects {len(hf_keys)}"
        )

    # Build a new dict rather than renaming in place: popping and re-inserting
    # keys in the same dict can clobber an entry whose old name happens to
    # equal another entry's new name.
    renamed_state_dict = dict(zip(hf_keys, state_dict.values()))

    # Loading validates that the renamed entries actually fit the model.
    hf_value_function.load_state_dict(renamed_state_dict)

    os.makedirs(output_dir, exist_ok=True)
    torch.save(hf_value_function.state_dict(), os.path.join(output_dir, "diffusion_pytorch_model.bin"))
    with open(os.path.join(output_dir, "config.json"), "w") as f:
        # Tuples in the config are serialized as JSON arrays.
        json.dump(config, f)
| 95 | |
| 96 | |
| 97 | if __name__ == "__main__": |
searching dependent graphs…