hub / github.com/NVIDIA/TensorRT-LLM / run_engine

Function run_engine

examples/python_plugin/run_lookup.py:14–61
(dtype)

Source from the content-addressed store, hash-verified

if __name__ == "__main__":

    def run_engine(dtype):
        output_dir = Path('tmp') / torch_dtype_to_str(dtype)

        engine_path = output_dir / "lookup.engine"

        with engine_path.open('rb') as f:
            session = Session.from_serialized_engine(f.read())

        # meta data
        batch_size = 10
        vocab_size = 1000
        n_embed = 1024

        # test data
        ## input index
        index_shape = (batch_size, )
        index_data = torch.randint(0,
                                   vocab_size,
                                   index_shape,
                                   dtype=torch.int32).cuda()
        weight_data = torch.rand(vocab_size, n_embed, dtype=dtype).cuda()

        inputs = {"x": index_data, "y": weight_data}

        output_info = session.infer_shapes([
            TensorInfo(name, torch_dtype_to_trt(tensor.dtype), tensor.shape)
            for name, tensor in inputs.items()
        ])
        logger.debug(f'output info {output_info}')
        outputs = {
            t.name:
            torch.empty(tuple(t.shape),
                        dtype=trt_dtype_to_torch(t.dtype),
                        device='cuda')
            for t in output_info
        }

        stream = torch.cuda.Stream()
        ok = session.run(inputs=inputs,
                         outputs=outputs,
                         stream=stream.cuda_stream)
        assert ok, 'Engine execution failed'

        embedding = torch.nn.Embedding.from_pretrained(weight_data)
        torch_out = embedding(index_data).to(torch.float32)
        trt_out = outputs['output']

        torch.testing.assert_close(trt_out, torch_out)

    run_engine(torch.bfloat16)
    run_engine(torch.float16)
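
The final check in run_engine compares the engine output against torch.nn.Embedding, which gathers one row of the weight table per input index. The following is a minimal stdlib-only sketch of that reference semantics, not the plugin's implementation; lookup_rows is a hypothetical helper name that does not exist in TensorRT-LLM, and the small sizes are stand-ins for the batch_size, vocab_size, and n_embed values used above.

```python
import random


def lookup_rows(indices, weight):
    """Gather one weight row per index (hypothetical helper, embedding-lookup semantics)."""
    vocab_size = len(weight)
    for i in indices:
        if not 0 <= i < vocab_size:
            raise IndexError(f"index {i} outside vocabulary of size {vocab_size}")
    # Copy each row so the result does not alias the weight table.
    return [list(weight[i]) for i in indices]


# Small stand-ins for batch_size=10, vocab_size=1000, n_embed=1024.
rng = random.Random(0)
batch_size, vocab_size, n_embed = 4, 8, 3
weight = [[rng.random() for _ in range(n_embed)] for _ in range(vocab_size)]
indices = [rng.randrange(vocab_size) for _ in range(batch_size)]

out = lookup_rows(indices, weight)

# The output holds one row per input index: shape (batch_size, n_embed).
assert len(out) == batch_size
assert all(len(row) == n_embed for row in out)
```

The output shape depends only on the index shape and the embedding width, which matches how run_engine sizes its output buffers from session.infer_shapes before calling session.run.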

Callers 5

tllm · Method · 0.85
compare_context · Method · 0.85
compare_generation · Method · 0.85
run_lookup.py · File · 0.85

Calls 12

torch_dtype_to_str · Function · 0.90
TensorInfo · Class · 0.90
torch_dtype_to_trt · Function · 0.90
trt_dtype_to_torch · Function · 0.90
embedding · Function · 0.85
infer_shapes · Method · 0.80
debug · Method · 0.45
empty · Method · 0.45
run · Method · 0.45
from_pretrained · Method · 0.45
to · Method · 0.45

Tested by 4

tllm · Method · 0.68
compare_context · Method · 0.68
compare_generation · Method · 0.68