Repository: github.com/huggingface/diffusers

Function save_new_embed

Defined in examples/custom_diffusion/train_custom_diffusion.py, lines 311–324.

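Saves the learned input embedding of each modifier token from the text encoder. It writes one file per token to output_dir: a .safetensors file by default, or a .bin file via torch.save when safe_serialization is False.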

save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True)


```python
import safetensors.torch  # provides safetensors.torch.save_file
import torch

# `logger` is a module-level logger defined elsewhere in train_custom_diffusion.py.


def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
    """Saves the new token embeddings from the text encoder."""
    logger.info("Saving embeddings")
    # Unwrap the Accelerate-prepared model and take its full input-embedding matrix.
    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    # Pair each modifier token ID with its token string and write one file per token.
    for x, y in zip(modifier_token_id, args.modifier_token):
        learned_embeds_dict = {}
        learned_embeds_dict[y] = learned_embeds[x]  # {token_string: embedding_row}

        if safe_serialization:
            filename = f"{output_dir}/{y}.safetensors"
            safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
        else:
            filename = f"{output_dir}/{y}.bin"
            torch.save(learned_embeds_dict, filename)
```

Callers (1): main (function, 0.85)

Calls (3): info (method, 0.80), save (method, 0.80), get_input_embeddings (method, 0.45)

Tested by: no test coverage detected.
