if __name__ == "__main__":

    def run_engine(dtype):
        """Run the serialized lookup engine for ``dtype`` and check it against torch.

        Loads ``tmp/<torch_dtype_to_str(dtype)>/lookup.engine`` and feeds it two inputs:
        a random batch of int32 indices as ``"x"``, and a random
        ``(vocab_size, n_embed)`` table of ``dtype`` as ``"y"``. It then compares
        the engine's ``"output"`` tensor with ``torch.nn.Embedding`` applied to
        the same indices and table.

        Args:
            dtype: torch dtype of the table. It also selects the engine
                directory via ``torch_dtype_to_str``.

        Raises:
            AssertionError: if ``session.run`` reports failure, or if the engine
                output does not match the torch reference.
        """
        output_dir = Path('tmp') / torch_dtype_to_str(dtype)
        engine_path = output_dir / "lookup.engine"

        with engine_path.open('rb') as f:
            session = Session.from_serialized_engine(f.read())

        # meta data
        batch_size = 10
        vocab_size = 1000
        n_embed = 1024

        # test data
        ## input index
        index_shape = (batch_size, )
        index_data = torch.randint(0,
                                   vocab_size,
                                   index_shape,
                                   dtype=torch.int32).cuda()
        weight_data = torch.rand(vocab_size, n_embed, dtype=dtype).cuda()

        inputs = {"x": index_data, "y": weight_data}

        # Ask the engine for output shapes/dtypes given the concrete inputs,
        # then preallocate matching CUDA buffers for it to write into.
        output_info = session.infer_shapes([
            TensorInfo(name, torch_dtype_to_trt(tensor.dtype), tensor.shape)
            for name, tensor in inputs.items()
        ])
        logger.debug(f'output info {output_info}')
        outputs = {
            t.name:
            torch.empty(tuple(t.shape),
                        dtype=trt_dtype_to_torch(t.dtype),
                        device='cuda')
            for t in output_info
        }

        stream = torch.cuda.Stream()
        # The .cuda() copies above were queued on the current stream. The
        # engine runs on a separate stream, so order it after that work;
        # otherwise it may read inputs that are not yet written.
        stream.wait_stream(torch.cuda.current_stream())
        ok = session.run(inputs=inputs,
                         outputs=outputs,
                         stream=stream.cuda_stream)
        assert ok, 'Engine execution failed'
        # run() is handed a stream handle, so the engine's work is presumably
        # enqueued asynchronously there. Wait for it before reading `outputs`
        # from the current stream.
        stream.synchronize()

        embedding = torch.nn.Embedding.from_pretrained(weight_data)
        # NOTE(review): assert_close also checks dtype by default, so this
        # cast assumes the engine's "output" tensor is float32 — confirm
        # against how lookup.engine was built.
        torch_out = embedding(index_data).to(torch.float32)
        trt_out = outputs['output']

        torch.testing.assert_close(trt_out, torch_out)

    run_engine(torch.bfloat16)
    run_engine(torch.float16)