(model: PretrainedModel)
| 1527 | |
| 1528 | |
def share_embedding(model: PretrainedModel) -> PretrainedModel:
    """Make ``lm_head`` reuse the ``vocab_embedding`` weight when that is safe.

    The modules are located by walking ``model.named_modules()`` and matching
    the last dotted component of each name against ``"lm_head"`` and
    ``"vocab_embedding"``. The weight is shared only when both modules exist,
    their weights agree in shape and dtype, the embedding weight is
    initialized, and (if the lm_head weight is initialized too) the two
    weights differ by at most 1e-6 element-wise. In every other case the
    model is returned untouched.

    Args:
        model: The model to inspect; it is modified in place.

    Returns:
        The same ``model`` object that was passed in.
    """
    lm_head = None
    vocab_embedding = None
    for name, layer in model.named_modules():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name == "lm_head":
            lm_head = layer
        elif layer_name == "vocab_embedding":
            vocab_embedding = layer
        if lm_head is not None and vocab_embedding is not None:
            break

    # Cannot find either lm_head or vocab_embedding, e.g., pipeline parallel
    if lm_head is None or vocab_embedding is None:
        return model

    # lm_head and vocab_embedding have different shapes, e.g., tensor parallel
    # without embedding parallel
    if lm_head.weight.shape != vocab_embedding.weight.shape:
        return model

    # lm_head can have a different type if quantized
    if lm_head.weight.dtype != vocab_embedding.weight.dtype:
        return model

    # Don't assume weight can be shared if vocab_embedding is not initialized,
    # e.g., dummy weights
    if not vocab_embedding.weight.is_inited():
        return model

    if lm_head.weight.is_inited():
        lm_head_weight = numpy_to_torch(lm_head.weight.raw_value)
        vocab_embed_weight = numpy_to_torch(vocab_embedding.weight.raw_value)
        # The lm_head and vocab_embedding have different weights
        if (lm_head_weight - vocab_embed_weight).abs().max().item() > 1e-6:
            return model

    lm_head.weight = vocab_embedding.weight

    # Guard on the attribute that is actually read from vocab_embedding.
    # Checking `per_channel_scale` there (as before) either skipped the
    # sharing or raised AttributeError when only one of the two existed.
    # `is not None` avoids relying on the truthiness of the scale objects.
    head_scale = getattr(lm_head, 'per_channel_scale', None)
    embed_scale = getattr(vocab_embedding, 'per_token_scale', None)
    if head_scale is not None and embed_scale is not None:
        lm_head.per_channel_scale = embed_scale
    return model
| 1569 | |
| 1570 | |
| 1571 | def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel: |
no test coverage detected