hub / github.com/NVIDIA/TensorRT-LLM / embedding

Function embedding

tensorrt_llm/functional.py:2717–2851

Add an operation to perform embedding lookup. The 'input' tensor contains the identifiers of the rows of 'weight' to gather.

embedding(input: Tensor,
          weight: Tensor,
          tp_size=1,
          tp_group=None,
          sharding_dim=0,
          tp_rank=None,
          per_token_scale=None,
          padding=None)
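As a rough illustration of what "gathering rows of 'weight'" means, the sketch below reproduces the lookup semantics in plain Python. It does not call TensorRT-LLM; the helper name `embedding_lookup` and the toy table are hypothetical, chosen only for this example.

```python
from typing import List


def embedding_lookup(indices: List[int],
                     table: List[List[float]]) -> List[List[float]]:
    """Return one copy of a table row per index (reference semantics only)."""
    return [list(table[i]) for i in indices]


# Hypothetical 3-row table with an embedding width of 2.
table = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]
result = embedding_lookup([2, 0, 2], table)
```

Each entry of `indices` selects one row, so the output has one row per index and the same width as the table.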

Source

def embedding(input: Tensor,
              weight: Tensor,
              tp_size=1,
              tp_group=None,
              sharding_dim=0,
              tp_rank=None,
              per_token_scale=None,
              padding=None) -> Tensor:
    '''
    Add an operation to perform embedding lookup.

    That operation performs the embedding lookup. The 'input' tensor contains
    the identifiers of the rows of 'weight' to gather.

    1. Distribute the embedding lookup table over multiple GPUs
        When 'tp_size' is greater than 1 and the 'tp_group' is defined, this
        embedding lookup is distributed among multiple GPUs.

        When 'sharding_dim==0', each GPU stores a subset of the rows of the
        embedding table (the number of rows per GPU is given by weight.shape[0]
        and the offset to the 1st row stored on the GPU is given by
        rank * weight.shape[0]). Each parallel rank will query all the indices
        and set 0s for the weights that are not stored on the associated GPU.
        To compute the final result, a parallel all-reduce operation is added
        to the TensorRT graph. That lookup can be performed using either the
        plugin or the operators TensorRT supports.

        When 'sharding_dim==1', each GPU stores a subset of the embedding
        table's columns. Each rank can obtain a portion of the embedding
        results. Then the embedding is collected using the all-gather
        operation. Related transposition operations are also used to obtain
        the final results.

    2. Store embedding lookup table as a whole
        When 'tp_size' is not greater than 1, the embedding lookup table will
        not be divided. In this case, when the
        default_net().plugin_config.lookup_plugin is set, the operation is
        implemented using a plugin (without the all-reduce operation).
        Otherwise, this operation is implemented using the standard
        IGatherLayer in TensorRT.

    Parameters:
        input : Tensor
            The input tensor that contains the indices to perform the lookup.

        weight : Tensor
            The table to gather from.

        tp_size : int
            The number of GPUs collaborating to perform that embedding.

        tp_group : Optional[List[int]]
            The group of world ranks participating in the all-reduce when
            tp_size > 1.

        sharding_dim : int
            sharding_dim = 0 means that we shard the embedding table in vocab dim;
            sharding_dim = 1 means that we shard the embedding table in embedding dim.

        tp_rank : int
            The tensor parallelism rank. Used to calculate offset in TP on vocab dim.
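The two sharding modes described in the docstring can be simulated on a single machine to see why an all-reduce suffices for row sharding and an all-gather for column sharding. The following is a minimal sketch in plain Python, not the TensorRT-LLM implementation: the helper names (`lookup`, `row_sharded_lookup`, `column_sharded_lookup`) are hypothetical, a loop over ranks stands in for the GPUs, an element-wise sum stands in for the all-reduce, and list concatenation stands in for the all-gather and its transpositions. It assumes the vocabulary size and embedding width divide evenly by `tp_size`.

```python
from typing import List

Matrix = List[List[float]]


def lookup(indices: List[int], table: Matrix) -> Matrix:
    """Unsharded reference: gather one row per index."""
    return [list(table[i]) for i in indices]


def row_sharded_lookup(indices: List[int], table: Matrix, tp_size: int) -> Matrix:
    """Emulate sharding_dim == 0: mask rows a rank does not own, then sum."""
    rows_per_rank = len(table) // tp_size  # assumption: divides evenly
    dim = len(table[0])
    total = [[0.0] * dim for _ in indices]
    for rank in range(tp_size):
        offset = rank * rows_per_rank
        shard = table[offset:offset + rows_per_rank]
        for pos, idx in enumerate(indices):
            local = idx - offset
            # Rows stored on another rank contribute zeros.
            owned = 0 <= local < rows_per_rank
            partial = shard[local] if owned else [0.0] * dim
            # The element-wise sum over ranks plays the role of the all-reduce.
            total[pos] = [a + b for a, b in zip(total[pos], partial)]
    return total


def column_sharded_lookup(indices: List[int], table: Matrix, tp_size: int) -> Matrix:
    """Emulate sharding_dim == 1: each rank returns a slice of every row."""
    cols_per_rank = len(table[0]) // tp_size  # assumption: divides evenly
    out: Matrix = [[] for _ in indices]
    for rank in range(tp_size):
        start = rank * cols_per_rank
        shard = [row[start:start + cols_per_rank] for row in table]
        # Concatenating the slices in rank order plays the role of the all-gather.
        for pos, piece in enumerate(lookup(indices, shard)):
            out[pos].extend(piece)
    return out


# Hypothetical 4 x 4 table, split across two simulated ranks.
table = [[float(10 * r + c) for c in range(4)] for r in range(4)]
indices = [3, 0, 2]
expected = lookup(indices, table)
by_rows = row_sharded_lookup(indices, table, tp_size=2)
by_cols = column_sharded_lookup(indices, table, tp_size=2)
```

Because exactly one rank owns each requested row, the sum in the row-sharded case adds one real row to zeros from every other rank, so both sharded variants reproduce the unsharded result.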

Callers 11

test_triton_plugin · Function · 0.85
test_embedding · Method · 0.85
run_engine · Function · 0.85
forward · Method · 0.85
compute_relative_bias · Function · 0.85
forward · Method · 0.85
forward · Method · 0.85
_beam_search_candidates · Function · 0.85

Calls 13

concat · Function · 0.85
_lookup_plugin · Function · 0.85
slice · Function · 0.85
expand_dims · Function · 0.85
where · Function · 0.85
default_trtnet · Function · 0.85
_create_tensor · Function · 0.85
cast · Function · 0.85
allreduce · Function · 0.70
shape · Function · 0.70
allgather · Function · 0.70
ndim · Method · 0.45

Tested by 3

test_triton_plugin · Function · 0.68
test_embedding · Method · 0.68