Add an operation to perform an embedding lookup. The 'input' tensor contains the identifiers of the rows of 'weight' to gather. When 'tp_size' is greater than 1 and 'tp_group' is defined, the lookup is distributed among multiple GPUs.
(input: Tensor,
weight: Tensor,
tp_size=1,
tp_group=None,
sharding_dim=0,
tp_rank=None,
per_token_scale=None,
padding=None)
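The two sharding schemes this function supports can be illustrated outside TensorRT. The following is a minimal NumPy sketch, not the library's implementation: the helper names `lookup_row_sharded` and `lookup_column_sharded` are hypothetical, a sum over per-rank results stands in for the all-reduce, a concatenation stands in for the all-gather (with its transpositions), and the table is assumed to split evenly across ranks.

```python
import numpy as np


def lookup_row_sharded(ids, table, tp_size):
    """Emulate sharding_dim == 0: shard rows (vocab dim), then all-reduce."""
    rows_per_rank = table.shape[0] // tp_size  # assumption: even split
    partials = []
    for rank in range(tp_size):
        offset = rank * rows_per_rank          # first row stored on this rank
        shard = table[offset:offset + rows_per_rank]
        local_ids = ids - offset
        owned = (local_ids >= 0) & (local_ids < rows_per_rank)
        # Clamp ids this rank does not own to a valid row, then zero the result.
        gathered = shard[np.where(owned, local_ids, 0)]
        partials.append(np.where(owned[..., None], gathered, 0.0))
    # Summing the per-rank partial results stands in for the all-reduce.
    return np.sum(partials, axis=0)


def lookup_column_sharded(ids, table, tp_size):
    """Emulate sharding_dim == 1: shard columns (embedding dim), then all-gather."""
    shards = np.split(table, tp_size, axis=1)
    # Each rank gathers its slice of every requested row; concatenating along
    # the embedding axis stands in for the all-gather.
    return np.concatenate([shard[ids] for shard in shards], axis=-1)


table = np.arange(8 * 4, dtype=np.float32).reshape(8, 4)
ids = np.array([[0, 5], [7, 2]])
reference = table[ids]  # unsharded lookup, as a plain gather
row_result = lookup_row_sharded(ids, table, tp_size=2)
col_result = lookup_column_sharded(ids, table, tp_size=2)
```

Both sharded variants should reproduce the plain gather `table[ids]`. The row-sharded variant relies on each index being owned by exactly one rank, so every other rank contributes zeros to the sum.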
def embedding(input: Tensor,
              weight: Tensor,
              tp_size=1,
              tp_group=None,
              sharding_dim=0,
              tp_rank=None,
              per_token_scale=None,
              padding=None) -> Tensor:
    '''
    Add an operation to perform embedding lookup.

    The 'input' tensor contains the identifiers of the rows of 'weight' to
    gather.

    1. Distribute the embedding lookup table over multiple GPUs
    When 'tp_size' is greater than 1 and the 'tp_group' is defined, this
    embedding lookup is distributed among multiple GPUs.

    When 'sharding_dim==0', each GPU stores a subset of the rows of the
    embedding table (the number of rows per GPU is given by weight.shape[0]
    and the offset to the first row stored on the GPU is given by
    tp_rank * weight.shape[0]). Each parallel rank queries all the indices
    and produces zeros for the rows that are not stored on the associated
    GPU. To compute the final result, a parallel all-reduce operation is
    added to the TensorRT graph. That lookup can be performed using either
    the plugin or the operators TensorRT supports.

    When 'sharding_dim==1', each GPU stores a subset of the embedding table's
    columns. Each rank obtains a portion of the embedding results, and the
    full embedding is then collected using an all-gather operation. Related
    transposition operations are also used to obtain the final results.

    2. Store embedding lookup table as a whole
    When 'tp_size' is not greater than 1, the embedding lookup table is not
    divided. In this case, when default_net().plugin_config.lookup_plugin
    is set, the operation is implemented using a plugin (without the
    all-reduce operation). Otherwise, this operation is implemented using
    the standard IGatherLayer in TensorRT.

    Parameters:
        input : Tensor
            The input tensor that contains the indices to look up.

        weight : Tensor
            The table to gather from.

        tp_size : int
            The number of GPUs collaborating to perform that embedding.

        tp_group : Optional[List[int]]
            The group of world ranks participating in the all-reduce when
            tp_size > 1.

        sharding_dim : int
            sharding_dim = 0 means that the embedding table is sharded along
            the vocab dim; sharding_dim = 1 means that it is sharded along
            the embedding dim.

        tp_rank : int
            The tensor parallelism rank. Used to calculate the offset when
            sharding along the vocab dim.
