(sd, transformer_config, dtype=torch.bfloat16)
| 418 | |
| 419 | |
def load_audio_embeddings_connector(sd, transformer_config, dtype=torch.bfloat16):
    """Load the audio embeddings connector described by ``transformer_config``.

    Assembles the connector hyper-parameters from ``transformer_config`` and
    delegates the actual loading to ``load_embeddings_connector``, using the
    ``audio_embeddings_connector.`` key prefix under ``_PREFIX_BASE``.

    For the head count, head dimension and layer count, an audio-specific key
    (``audio_connector_*``) takes priority; when it is absent the shared
    ``connector_*`` key is used, and when that is absent too a built-in
    default applies (30 heads, head dim 128, 2 layers).

    Args:
        sd: Passed through unchanged to ``load_embeddings_connector``
            (presumably the checkpoint state dict -- confirm against callers).
        transformer_config: Mapping holding the transformer configuration;
            only ``.get`` lookups are performed on it here, and it is also
            handed to ``LTXRopeType.from_dict`` and
            ``LTXFrequenciesPrecision.from_dict``.
        dtype: Forwarded to ``load_embeddings_connector``; defaults to
            ``torch.bfloat16``.

    Returns:
        Whatever ``load_embeddings_connector`` returns for the audio prefix.
    """
    lookup = transformer_config.get

    def audio_setting(audio_key, shared_key, default):
        # Audio-specific value wins; otherwise the shared connector value,
        # otherwise the hard-coded default.
        return lookup(audio_key, lookup(shared_key, default))

    rope_type = LTXRopeType.from_dict(transformer_config)
    frequencies_precision = LTXFrequenciesPrecision.from_dict(transformer_config)
    pe_max_pos = lookup("connector_positional_embedding_max_pos", [1])

    connector_config = {
        "num_attention_heads": audio_setting(
            "audio_connector_num_attention_heads",
            "connector_num_attention_heads",
            30,
        ),
        "attention_head_dim": audio_setting(
            "audio_connector_attention_head_dim",
            "connector_attention_head_dim",
            128,
        ),
        "num_layers": audio_setting(
            "audio_connector_num_layers",
            "connector_num_layers",
            2,
        ),
        # NOTE(review): unlike the settings above, this one has no
        # audio-specific override key -- confirm that is intentional.
        "apply_gated_attention": lookup("connector_apply_gated_attention", False),
    }

    audio_prefix = f"{_PREFIX_BASE}audio_embeddings_connector."
    return load_embeddings_connector(
        sd,
        audio_prefix,
        connector_config,
        dtype,
        rope_type,
        frequencies_precision,
        pe_max_pos,
    )
no test coverage detected