MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / share_embedding

Function share_embedding

tensorrt_llm/models/modeling_utils.py:1529–1568  ·  view source on GitHub ↗
(model: PretrainedModel)

Source from the content-addressed store, hash-verified

1527
1528
def share_embedding(model: PretrainedModel) -> PretrainedModel:
    """Tie ``lm_head.weight`` to ``vocab_embedding.weight`` when it is safe.

    Scans ``model.named_modules()`` for modules whose last dotted name
    component is ``lm_head`` and ``vocab_embedding``. The scan stops as soon
    as both have been seen. If both exist and their weights are compatible,
    ``lm_head.weight`` is rebound to the embedding's weight object, so the
    two modules share one Parameter. When both modules also carry a
    quantization scale, the embedding's scale is shared with ``lm_head`` too.

    The model is returned untouched (weights not tied) when:
      * either module is missing, e.g. with pipeline parallelism;
      * the weight shapes differ, e.g. tensor parallelism without
        embedding parallelism;
      * the weight dtypes differ, e.g. a quantized ``lm_head``;
      * the embedding weight is not initialized, e.g. dummy weights;
      * both weights are initialized but differ by more than 1e-6 in
        absolute value anywhere.

    Args:
        model: Model to modify in place.

    Returns:
        The same ``model`` object that was passed in.
    """
    lm_head = None
    vocab_embedding = None
    for name, layer in model.named_modules():
        # Compare only the last dotted component so that nested module
        # paths still match.
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name == "lm_head":
            lm_head = layer
        if layer_name == "vocab_embedding":
            vocab_embedding = layer
        if lm_head is not None and vocab_embedding is not None:
            break

    # Cannot find either lm_head or vocab_embedding, e.g., pipeline parallel.
    if lm_head is None or vocab_embedding is None:
        return model

    # Different shapes, e.g., tensor parallel without embedding parallel.
    if lm_head.weight.shape != vocab_embedding.weight.shape:
        return model

    # lm_head can have a different dtype if it is quantized.
    if lm_head.weight.dtype != vocab_embedding.weight.dtype:
        return model

    # Don't assume the weight can be shared if vocab_embedding is not
    # initialized, e.g., dummy weights.
    if not vocab_embedding.weight.is_inited():
        return model

    # If lm_head already holds real values, only tie when they match the
    # embedding's values; otherwise the two are genuinely different weights.
    if lm_head.weight.is_inited():
        lm_head_weight = numpy_to_torch(lm_head.weight.raw_value)
        vocab_embed_weight = numpy_to_torch(vocab_embedding.weight.raw_value)
        if (lm_head_weight - vocab_embed_weight).abs().max().item() > 1e-6:
            return model

    lm_head.weight = vocab_embedding.weight

    # Share the quantization scale as well when both modules have one. The
    # guard tests the same embedding attribute (`per_token_scale`) that is
    # read in the assignment, so the check and the read agree.
    lm_head_scale = getattr(lm_head, 'per_channel_scale', None)
    embed_scale = getattr(vocab_embedding, 'per_token_scale', None)
    if lm_head_scale is not None and embed_scale is not None:
        lm_head.per_channel_scale = embed_scale
    return model
1569
1570
1571def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel:

Callers 1

optimize_model (Function) · 0.85

Calls 5

numpy_to_torch (Function) · 0.85
named_modules (Method) · 0.80
is_inited (Method) · 0.80
abs (Method) · 0.80
max (Method) · 0.45

Tested by

no test coverage detected