Repository: github.com/huggingface/diffusers

Function main(args)

scripts/convert_dance_diffusion_to_diffusers.py:259–333 (the listing below shows lines 259–316)


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resolve the checkpoint: either a local file, or an official model name that gets downloaded.
    model_name = args.model_path.split("/")[-1].split(".")[0]
    if not os.path.isfile(args.model_path):
        assert model_name == args.model_path, (
            f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
        )
        args.model_path = download(model_name)

    sample_rate = MODELS_MAP[model_name]["sample_rate"]
    sample_size = MODELS_MAP[model_name]["sample_size"]

    # Plain attribute container used as the config of the original model.
    config = Object()
    config.sample_size = sample_size
    config.sample_rate = sample_rate
    config.latent_dim = 0

    diffusers_model = UNet1DModel(sample_size=sample_size, sample_rate=sample_rate)
    diffusers_state_dict = diffusers_model.state_dict()

    # Load the original checkpoint, keep its EMA diffusion network, and rename its keys.
    orig_model = DiffusionUncond(config)
    orig_model.load_state_dict(torch.load(args.model_path, map_location=device)["state_dict"])
    orig_model = orig_model.diffusion_ema.eval()
    orig_model_state_dict = orig_model.state_dict()
    renamed_state_dict = rename_orig_weights(orig_model_state_dict)

    # Every renamed key must exist in the diffusers model; only keys ending in "kernel" may be missing.
    renamed_minus_diffusers = set(renamed_state_dict.keys()) - set(diffusers_state_dict.keys())
    diffusers_minus_renamed = set(diffusers_state_dict.keys()) - set(renamed_state_dict.keys())

    assert len(renamed_minus_diffusers) == 0, f"Problem with {renamed_minus_diffusers}"
    assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"

    # Shapes must match once size-1 dimensions are squeezed away; time_proj.weight is stored squeezed.
    for key, value in renamed_state_dict.items():
        assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
            f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
        )
        if key == "time_proj.weight":
            value = value.squeeze()

        diffusers_state_dict[key] = value

    # Keys not provided by the checkpoint keep the diffusers model's initial values.
    diffusers_model.load_state_dict(diffusers_state_dict)

    # Sampling setup: scheduler, seeded reference noise (batch 1, 2 channels), crash-schedule timesteps.
    steps = 100
    seed = 33

    diffusers_scheduler = IPNDMScheduler(num_train_timesteps=steps)

    generator = torch.manual_seed(seed)
    noise = torch.randn([1, 2, config.sample_size], generator=generator).to(device)

    t = torch.linspace(1, 0, steps + 1, device=device)[:-1]
    step_list = get_crash_schedule(t)

    pipe = DanceDiffusionPipeline(unet=diffusers_model, scheduler=diffusers_scheduler)

    generator = torch.manual_seed(33)
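The validation step in main() (the two key-set differences, the "kernel" exemption, and the squeezed shape comparison) can be read in isolation. The sketch below reproduces that logic on plain shape tuples instead of torch tensors, so it runs without torch. The helper names squeeze_shape, check_state_dict_shapes, and merged_shapes are hypothetical and not part of the script, and the keys in the usage example are toy values.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def squeeze_shape(shape: Shape) -> Shape:
    """Drop size-1 dimensions, as tensor.squeeze() does to a tensor's shape."""
    return tuple(d for d in shape if d != 1)


def check_state_dict_shapes(renamed: Dict[str, Shape], target: Dict[str, Shape]) -> None:
    """Raise AssertionError if the renamed checkpoint does not fit the target model."""
    extra = set(renamed) - set(target)    # checkpoint keys the diffusers model lacks
    missing = set(target) - set(renamed)  # model keys the checkpoint does not provide
    assert not extra, f"Problem with {extra}"
    # Only keys ending in "kernel" may be missing; in main() they keep their initial values.
    assert all(k.endswith("kernel") for k in missing), f"Problem with {missing}"
    for key, shape in renamed.items():
        assert squeeze_shape(target[key]) == squeeze_shape(shape), (
            f"Shape for {key} doesn't match. Diffusers: {target[key]} vs. {shape}"
        )


def merged_shapes(renamed: Dict[str, Shape], target: Dict[str, Shape]) -> Dict[str, Shape]:
    """Shapes stored after copying; main() squeezes only time_proj.weight before storing it."""
    merged = dict(target)
    for key, shape in renamed.items():
        merged[key] = squeeze_shape(shape) if key == "time_proj.weight" else shape
    return merged


if __name__ == "__main__":
    target = {"time_proj.weight": (8,), "down.kernel": (4,)}  # toy shapes, not real model keys
    renamed = {"time_proj.weight": (8, 1)}
    check_state_dict_shapes(renamed, target)
    print(merged_shapes(renamed, target))
```

Working on shape tuples keeps the example dependency-free; with real tensors the comparison is diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, as in main().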

Calls (15)

UNet1DModel (class) · 0.90
IPNDMScheduler (class) · 0.90
download (function) · 0.85
Object (class) · 0.85
DiffusionUncond (class) · 0.85
rename_orig_weights (function) · 0.85
get_crash_schedule (function) · 0.85
pipe (function) · 0.85
split (method) · 0.80
device (method) · 0.45
state_dict (method) · 0.45
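main() builds 100 linear timesteps with torch.linspace(1, 0, steps + 1, device=device)[:-1] and warps them with get_crash_schedule, listed among the calls above. That helper's source is not shown on this page, so the formula below is an assumption, modeled on the "crash" schedule from the original Dance Diffusion sampling code; check it against the script before relying on it. It is written in plain Python so it runs without torch, and the function names here are illustrative.

```python
import math
from typing import List


def alpha_sigma_to_t(alpha: float, sigma: float) -> float:
    """Map an (alpha, sigma) noise-level pair to a timestep in [0, 1]."""
    return math.atan2(sigma, alpha) / math.pi * 2


def crash_schedule(t: float) -> float:
    """Assumed crash-schedule formula: warp a linear timestep t in [0, 1]."""
    sigma = math.sin(t * math.pi / 2) ** 2
    alpha = math.sqrt(1 - sigma ** 2)
    return alpha_sigma_to_t(alpha, sigma)


def sampling_timesteps(steps: int) -> List[float]:
    """Plain-Python version of torch.linspace(1, 0, steps + 1)[:-1], then the warp."""
    linear = [1 - i / steps for i in range(steps)]  # 1.0 down to 1/steps; the final 0 is dropped
    return [crash_schedule(t) for t in linear]


if __name__ == "__main__":
    schedule = sampling_timesteps(100)
    print(schedule[0], schedule[50], schedule[-1])
```

Under this formula the warp is not linear: t = 0.5 maps to 1/3 (sigma = 0.5 and alpha = √0.75, so the angle is π/6), and small t values are compressed roughly quadratically, since the output behaves like πt²/2 near 0.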

Tested by

no test coverage detected
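No test coverage is detected for main(). A low-cost starting point is the model-name resolution at the top of the function, which is pure string logic. The sketch below factors it into a hypothetical resolve_model_name helper (not part of the script) with an injectable file-existence check, so it can be tested without touching the file system or downloading checkpoints. The model name "gwf-440k" in the tests is an illustrative value.

```python
import os
from typing import Callable, Iterable


def resolve_model_name(
    model_path: str,
    official_names: Iterable[str],
    is_file: Callable[[str], bool] = os.path.isfile,
) -> str:
    """Hypothetical helper mirroring the first lines of main()."""
    # Keep the last path component, then everything before its first ".".
    model_name = model_path.split("/")[-1].split(".")[0]
    if not is_file(model_path):
        # With no local file, the argument must already be a bare model name.
        assert model_name == model_path, (
            f"Make sure to provide one of the official model names {list(official_names)}"
        )
    return model_name


# pytest-style tests; is_file is injected so no real files are needed.
def test_local_checkpoint_keeps_stem():
    assert resolve_model_name("ckpts/gwf-440k.ckpt", [], is_file=lambda p: True) == "gwf-440k"


def test_official_name_without_file():
    assert resolve_model_name("gwf-440k", ["gwf-440k"], is_file=lambda p: False) == "gwf-440k"


def test_only_first_dot_segment_survives():
    assert resolve_model_name("a/b/model.v2.ckpt", [], is_file=lambda p: True) == "model"


def test_missing_file_path_is_rejected():
    rejected = False
    try:
        resolve_model_name("ckpts/gwf-440k.ckpt", ["gwf-440k"], is_file=lambda p: False)
    except AssertionError:
        rejected = True
    assert rejected
```

The third test pins down a quirk of the original expression: a file named model.v2.ckpt resolves to "model", so checkpoint names containing dots would not match a MODELS_MAP key. Note that main() itself calls os.path.isfile directly; the is_file parameter exists only to make the sketch testable.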
