Saves the new token embeddings from the text encoder.
(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True)
| 309 | |
| 310 | |
def save_new_embed(text_encoder, modifier_token_id, accelerator, args, output_dir, safe_serialization=True):
    """Save the learned embedding of each modifier token to its own file.

    Ids from ``modifier_token_id`` are paired by position with the token strings
    in ``args.modifier_token``. For each pair, the matching row of the text
    encoder's input-embedding weight is written as a single-entry dict
    ``{token: embedding}`` to ``<output_dir>/<token>.safetensors``, or to
    ``<output_dir>/<token>.bin`` when ``safe_serialization`` is False.

    Args:
        text_encoder: Text encoder (possibly wrapped by ``accelerator``) whose
            input embeddings hold the learned token vectors.
        modifier_token_id: Vocabulary ids of the new tokens, in the same order
            as ``args.modifier_token``.
        accelerator: Object used to unwrap ``text_encoder`` via ``unwrap_model``.
        args: Parsed arguments; only ``args.modifier_token`` is read.
        output_dir: Directory to write the files into. It is not created here.
        safe_serialization: If True, write ``.safetensors`` files with
            ``metadata={"format": "pt"}``. Otherwise write ``torch.save``
            pickles with a ``.bin`` extension.

    Raises:
        ValueError: If ``modifier_token_id`` and ``args.modifier_token`` do not
            have the same number of entries.
    """
    import os

    logger.info("Saving embeddings")

    # Materialize both sequences so a length mismatch is reported instead of
    # zip() silently dropping the unmatched tokens.
    token_ids = list(modifier_token_id)
    tokens = list(args.modifier_token)
    if len(token_ids) != len(tokens):
        raise ValueError(
            f"Got {len(token_ids)} modifier token ids but {len(tokens)} modifier tokens; "
            "they must be paired one-to-one."
        )

    learned_embeds = accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
    for token_id, token in zip(token_ids, tokens):
        # Indexing the weight yields a view that still tracks gradients and shares
        # storage with the whole embedding table. torch.save serializes a view's
        # entire underlying storage, so detach and copy the row into its own
        # compact CPU tensor. Otherwise each file would contain the full table.
        embedding = learned_embeds[token_id].detach().cpu().clone()
        learned_embeds_dict = {token: embedding}

        if safe_serialization:
            filename = os.path.join(output_dir, f"{token}.safetensors")
            safetensors.torch.save_file(learned_embeds_dict, filename, metadata={"format": "pt"})
        else:
            filename = os.path.join(output_dir, f"{token}.bin")
            torch.save(learned_embeds_dict, filename)
| 326 | |
| 327 | def parse_args(input_args=None): |
no test coverage detected
searching dependent graphs…