(args)
| 257 | |
| 258 | |
| 259 | def main(args): |
| 260 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 261 | |
| 262 | model_name = args.model_path.split("/")[-1].split(".")[0] |
| 263 | if not os.path.isfile(args.model_path): |
| 264 | assert model_name == args.model_path, ( |
| 265 | f"Make sure to provide one of the official model names {MODELS_MAP.keys()}" |
| 266 | ) |
| 267 | args.model_path = download(model_name) |
| 268 | |
| 269 | sample_rate = MODELS_MAP[model_name]["sample_rate"] |
| 270 | sample_size = MODELS_MAP[model_name]["sample_size"] |
| 271 | |
| 272 | config = Object() |
| 273 | config.sample_size = sample_size |
| 274 | config.sample_rate = sample_rate |
| 275 | config.latent_dim = 0 |
| 276 | |
| 277 | diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate) |
| 278 | diffusers_state_dict = diffusers_model.state_dict() |
| 279 | |
| 280 | orig_model = DiffusionUncond(config) |
| 281 | orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"]) |
| 282 | orig_model = orig_model.diffusion_ema.eval() |
| 283 | orig_model_state_dict = orig_model.state_dict() |
| 284 | renamed_state_dict = rename_orig_weights(orig_model_state_dict) |
| 285 | |
| 286 | renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys()) |
| 287 | diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys()) |
| 288 | |
| 289 | assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}" |
| 290 | assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}" |
| 291 | |
| 292 | for key, value in renamed_state_dict.items(): |
| 293 | assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, ( |
| 294 | f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}" |
| 295 | ) |
| 296 | if key == "time_proj.weight": |
| 297 | value = value.squeeze() |
| 298 | |
| 299 | diffusers_state_dict[key] = value |
| 300 | |
| 301 | diffusers_model.load_state_dict(diffusers_state_dict) |
| 302 | |
| 303 | steps = 100 |
| 304 | seed = 33 |
| 305 | |
| 306 | diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps) |
| 307 | |
| 308 | generator = torch.manual_seed(seed) |
| 309 | noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device) |
| 310 | |
| 311 | t = torch.linspace(1, 0, steps + 1, device=device)[:-1] |
| 312 | step_list = get_crash_schedule(t) |
| 313 | |
| 314 | pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler) |
| 315 | |
| 316 | generator = torch.manual_seed(33) |
no test coverage detected
searching dependent graphs…