Session is a managed TensorRT runtime.
| 81 | |
| 82 | |
| 83 | class Session(object): |
| 84 | ''' Session is a managed TensorRT runtime. ''' |
| 85 | |
def __init__(self, **kwargs):
    '''
    Intentionally does nothing; any keyword arguments are accepted and ignored.

    Sessions are meant to be built through Session.from_serialized_engine or
    Session.from_engine, which construct an empty Session and then call _init.
    '''
| 89 | |
def _init(self, engine_buffer=None):
    '''
    @brief: Set up the TensorRT runtime, engine and (if possible) execution context
    @param engine_buffer: a buffer holding a serialized TRT engine; when None,
        the engine must already have been assigned through the `engine`
        setter (as Session.from_engine does)
    @return: self, so factory methods can return the call directly
    @raise RuntimeError: if no usable engine is available, e.g. because
        deserialization of engine_buffer failed
    '''
    self._runtime = trt.Runtime(logger.trt_logger)
    if engine_buffer is not None:
        self._engine = self.runtime.deserialize_cuda_engine(engine_buffer)

    # deserialize_cuda_engine returns None on failure; without this check the
    # failure would surface later as an opaque AttributeError on
    # `None.streamable_weights_size`.
    if self.engine is None:
        raise RuntimeError("Failed to obtain a TensorRT engine!")

    self._context = None
    # The execution context is only created eagerly when the engine reports no
    # streamable weights; otherwise _context is left as None here.
    # NOTE(review): presumably it is created later once weight streaming is
    # configured — confirm against the rest of the class.
    if not self.engine.streamable_weights_size:
        self.__prepare_execution_contexts()
    return self
| 103 | |
def __prepare_execution_contexts(self):
    '''
    Create an execution context for self.engine and select optimization
    profile 0 on it.

    @raise RuntimeError: if TensorRT fails to create an execution context
    '''
    self._context = self.engine.create_execution_context()
    # Explicit check instead of `assert`: assertions are stripped under
    # `python -O`, which would let a None context slip through to the call below.
    if self._context is None:
        raise RuntimeError("Failed to create an execution context!")
    # _scoped_stream is defined elsewhere in this module; presumably it yields
    # a temporary CUDA stream for the async call — TODO confirm.
    with _scoped_stream() as stream:
        self._context.set_optimization_profile_async(0, stream)
| 109 | |
@staticmethod
def from_serialized_engine(engine) -> Session:
    '''
    @brief: Build a Session by deserializing a TensorRT engine
    @param engine: the serialized engine buffer, forwarded to Session._init
    @return: the initialized Session
    '''
    return Session()._init(engine)
| 119 | |
@staticmethod
def from_engine(engine) -> Session:
    '''
    @brief: Wrap an already-built ICudaEngine in a Session
    @param engine: an ICudaEngine, assigned through the `engine` property setter
    @return: the initialized Session
    '''
    wrapped = Session()
    wrapped.engine = engine
    # Nothing to deserialize: _init uses the engine assigned just above.
    return wrapped._init(engine_buffer=None)
| 130 | |
@property
def runtime(self) -> trt.Runtime:
    '''The trt.Runtime instance created by _init (stored in `_runtime`).'''
    return self._runtime
| 134 | |
@property
def engine(self) -> trt.ICudaEngine:
    '''The trt.ICudaEngine backing this session (stored in `_engine`).'''
    return self._engine
| 138 | |
| 139 | @engine.setter |
| 140 | def engine(self, engine: trt.ICudaEngine): |
no outgoing calls
no test coverage detected