(hor)
| 13 | |
| 14 | |
def unet(hor, model_path=None, output_root="hub/hopper-medium-v2/unet"):
    """Convert a temporal U-Net checkpoint into a diffusers ``UNet1DModel``.

    Builds the ``UNet1DModel`` config for the given horizon and loads the
    original checkpoint. It copies the checkpoint's parameters into the new
    model by position, then writes the converted weights and ``config.json``
    under ``<output_root>/hor<hor>/``.

    Args:
        hor: Planning horizon selecting the architecture; only 32 and 128
            are supported.
        model_path: Path of the original ``torch.save``-d model. Defaults to
            the historical location used by this script.
        output_root: Directory under which ``hor<hor>/`` is created.

    Raises:
        ValueError: If ``hor`` is unsupported, or if the two models do not
            have the same number of state-dict entries.
    """
    import os

    # The up path always has one block fewer than the down path.
    if hor == 128:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D")
    elif hor == 32:
        down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D")
        block_out_channels = (32, 64, 128, 256)
        up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D")
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(f"Unsupported horizon {hor!r}; expected 32 or 128.")

    if model_path is None:
        model_path = f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch"
    # NOTE: torch.load unpickles arbitrary Python objects; only load trusted
    # checkpoints. Recent torch versions may need weights_only=False to load a
    # whole pickled module (TODO confirm for the torch version in use).
    model = torch.load(model_path)
    state_dict = model.state_dict()

    config = {
        "down_block_types": down_block_types,
        "block_out_channels": block_out_channels,
        "up_block_types": up_block_types,
        "layers_per_block": 1,
        "use_timestep_embedding": True,
        "out_block_type": "OutConv1DBlock",
        "norm_num_groups": 8,
        "downsample_each_block": False,
        "in_channels": 14,
        "out_channels": 14,
        "extra_in_channels": 0,
        "time_embedding_type": "positional",
        "flip_sin_to_cos": False,
        "freq_shift": 1,
        "sample_size": 65536,
        "mid_block_type": "MidResTemporalBlock1D",
        "act_fn": "mish",
    }
    hf_value_function = UNet1DModel(**config)
    hf_state_dict = hf_value_function.state_dict()
    print(f"length of state dict: {len(state_dict.keys())}")
    print(f"length of value function dict: {len(hf_state_dict.keys())}")

    # Parameters are matched purely by position, so both models must expose
    # the same number of entries; zip would otherwise silently truncate.
    if len(state_dict) != len(hf_state_dict):
        raise ValueError(
            f"State dict size mismatch: checkpoint has {len(state_dict)} entries, "
            f"UNet1DModel expects {len(hf_state_dict)}."
        )
    # Build a fresh dict instead of renaming in place. In-place renaming
    # clobbers a tensor whenever a new key equals a source key not yet moved.
    renamed = {
        new_key: state_dict[old_key]
        for old_key, new_key in zip(state_dict.keys(), hf_state_dict.keys())
    }
    hf_value_function.load_state_dict(renamed)

    out_dir = f"{output_root}/hor{hor}"
    os.makedirs(out_dir, exist_ok=True)
    torch.save(hf_value_function.state_dict(), f"{out_dir}/diffusion_pytorch_model.bin")
    with open(f"{out_dir}/config.json", "w") as f:
        json.dump(config, f)
| 57 | |
| 58 | |
| 59 | def value_function(): |
searching dependent graphs…