Add an operation to perform lookup in a tensor. That operation performs the lookup needed by embedding layers. Given a 'weight' tensor of shape [rows, cols], it produces a tensor of shape [input.size(0), cols] where the ith row corresponds to the input[i] row in the weight tensor.
(input: Tensor, weight: Tensor, rank: int,
per_token_scale: Tensor)
| 2668 | |
| 2669 | |
def _lookup_plugin(input: Tensor, weight: Tensor, rank: int,
                   per_token_scale: Tensor) -> Tensor:
    '''
    Add an operation to perform lookup in a tensor.

    That operation performs the lookup needed by embedding layers. Given a
    'weight' tensor of shape [rows, cols], it produces a tensor of shape
    [input.size(0), cols] where the ith row corresponds to the input[i] row in
    the weight tensor.

    It inserts an IPluginV2Layer.

    Parameters:
        input : Tensor
            The input tensor contains the indices to perform the lookup.

        weight : Tensor
            The table to gather from.

        rank : int
            The mpi rank.

        per_token_scale : Tensor or None
            Optional scale tensor. When it is not None it is passed to the
            plugin as a third input, its dtype is used for the plugin's
            'type_id' field, and the dynamic range of 'weight' is set to
            [-127, 127]. When it is None the plugin only receives 'input'
            and 'weight'.

    Returns:
        The output tensor of the lookup layer.
    '''
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'Lookup', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    has_scale = per_token_scale is not None

    # Only read the scale's dtype when a scale was actually supplied;
    # dereferencing it unconditionally raised AttributeError for None and made
    # the None check further down unreachable.
    # NOTE(review): falling back to the weight dtype for 'type_id' when there
    # is no scale is an assumption about what the plugin expects - confirm
    # against the Lookup plugin implementation.
    p_dtype = per_token_scale.dtype if has_scale else weight.dtype
    pf_type = trt.PluginField("type_id", np.array([int(p_dtype)], np.int32),
                              trt.PluginFieldType.INT32)

    # Use a separate name so the integer 'rank' argument is not rebound to a
    # PluginField object.
    pf_rank = trt.PluginField("rank", np.array([int(rank)], np.int32),
                              trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection([pf_type, pf_rank])
    lookup_plug = plg_creator.create_plugin("lookup", pfc)

    plug_inputs = [input.trt_tensor, weight.trt_tensor]
    if has_scale:
        plug_inputs.append(per_token_scale.trt_tensor)
        # presumably the table is int8-quantized whenever a scale is given,
        # hence the symmetric [-127, 127] range - verify against the caller.
        weight.trt_tensor.set_dynamic_range(-127, 127)

    layer = default_trtnet().add_plugin_v2(plug_inputs, lookup_plug)
    _add_plugin_info(layer, plg_creator, "lookup", pfc)
    return _create_tensor(layer.get_output(0), layer)
| 2715 | |
| 2716 | |
| 2717 | def embedding(input: Tensor, |
no test coverage detected