hub / github.com/NVIDIA/TensorRT-LLM / _lookup_plugin

Function _lookup_plugin

tensorrt_llm/functional.py:2670–2714

Add an operation to perform lookup in a tensor. That operation performs the lookup needed by embedding layers. Given a 'weight' tensor of shape [rows, cols], it produces a tensor of shape [input.size(0), cols] where the ith row corresponds to the input[i] row in the weight tensor.
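
The row-gather semantics described above can be sketched in plain NumPy. This is only a reference illustration of the shapes involved, not the plugin's implementation; `lookup_reference` is a hypothetical name, and the sketch assumes a 1-D index tensor and ignores `rank` and `per_token_scale`.

```python
import numpy as np


def lookup_reference(indices: np.ndarray, weight: np.ndarray) -> np.ndarray:
    """Return weight[indices[i]] for each i; output shape is [len(indices), cols]."""
    if indices.ndim != 1:
        raise ValueError("this sketch assumes a 1-D index tensor")
    if weight.ndim != 2:
        raise ValueError("weight must have shape [rows, cols]")
    # Integer-array indexing on axis 0 gathers whole rows of the table.
    return weight[indices]


# A table with 4 rows and 3 columns: row r holds [3r, 3r+1, 3r+2].
weight = np.arange(12, dtype=np.float32).reshape(4, 3)
indices = np.array([2, 0, 2], dtype=np.int32)

out = lookup_reference(indices, weight)
print(out.shape)  # (3, 3)
print(out[0])     # row 2 of the table
```

Each output row is a copy of the table row selected by the matching index, so repeated indices produce repeated rows.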

(input: Tensor, weight: Tensor, rank: int, per_token_scale: Tensor) -> Tensor

Source from the content-addressed store, hash-verified

def _lookup_plugin(input: Tensor, weight: Tensor, rank: int,
                   per_token_scale: Tensor) -> Tensor:
    '''
    Add an operation to perform lookup in a tensor.

    That operation performs the lookup needed by embedding layers. Given a
    'weight' tensor of shape [rows, cols], it produces a tensor of shape
    [inputs.size(0), cols] where the ith row corresponds to the input[i] row in
    the weight tensor.

    It inserts a IPluginV2Layer.

    Parameters:
        input : Tensor
            The input tensor contains the indices to perform the lookup.

        weight : Tensor
            The table to gather from.

        rank : int
            The mpi rank.

    Returns:
        The output tensor of the lookup layer.
    '''
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'Lookup', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    p_dtype = per_token_scale.dtype
    pf_type = trt.PluginField("type_id", np.array([int(p_dtype)], np.int32),
                              trt.PluginFieldType.INT32)

    rank = trt.PluginField("rank", np.array([int(rank)], np.int32),
                           trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection([pf_type, rank])
    lookup_plug = plg_creator.create_plugin("lookup", pfc)
    plug_inputs = [input.trt_tensor, weight.trt_tensor]
    if per_token_scale is not None:
        plug_inputs.append(per_token_scale.trt_tensor)
        weight.trt_tensor.set_dynamic_range(-127, 127)
    layer = default_trtnet().add_plugin_v2(plug_inputs, lookup_plug)
    _add_plugin_info(layer, plg_creator, "lookup", pfc)
    return _create_tensor(layer.get_output(0), layer)
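
The listing reads `per_token_scale.dtype` before it tests `per_token_scale is not None`, so the later test can only ever see a non-None value. The sketch below isolates that ordering with hypothetical stand-ins and needs no TensorRT. `FakeTensor`, `build_inputs_as_listed` and `build_inputs_guarded` are illustrative names, not part of TensorRT-LLM, and the fallback to the weight dtype in the guarded variant is an assumption rather than the library's behavior.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor wrapper; only the fields used here."""
    name: str
    dtype: str


def build_inputs_as_listed(inp: FakeTensor, weight: FakeTensor,
                           per_token_scale: Optional[FakeTensor]
                           ) -> Tuple[str, List[str]]:
    # Same order as the listing: dtype is read first, None is checked later.
    p_dtype = per_token_scale.dtype  # AttributeError when per_token_scale is None
    plug_inputs = [inp.name, weight.name]
    if per_token_scale is not None:
        plug_inputs.append(per_token_scale.name)
    return p_dtype, plug_inputs


def build_inputs_guarded(inp: FakeTensor, weight: FakeTensor,
                         per_token_scale: Optional[FakeTensor]
                         ) -> Tuple[str, List[str]]:
    # Assumption: use the weight dtype when no per-token scale is supplied.
    if per_token_scale is None:
        return weight.dtype, [inp.name, weight.name]
    return per_token_scale.dtype, [inp.name, weight.name, per_token_scale.name]


ids = FakeTensor("ids", "int32")
table = FakeTensor("table", "float16")
scale = FakeTensor("scale", "float32")

try:
    build_inputs_as_listed(ids, table, None)
    raised = False
except AttributeError:
    raised = True

print(raised)
print(build_inputs_guarded(ids, table, None))
print(build_inputs_guarded(ids, table, scale))
```

Whether a None `per_token_scale` is a supported call pattern depends on the caller, `embedding`, which this page does not show.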

Callers 1

embedding · Function · 0.85

Calls 6

default_trtnet · Function · 0.85
_add_plugin_info · Function · 0.85
_create_tensor · Function · 0.85
create_plugin · Method · 0.80
append · Method · 0.45
get_output · Method · 0.45

Tested by

no test coverage detected