MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / Network

Class Network

tensorrt_llm/network.py:124–896  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

122
123
124class Network(object):
125
def __init__(self, **kwargs):
    """Create the bookkeeping state that exists before a TensorRT network is attached.

    This constructor is not meant to be called by users: ``**kwargs`` is
    accepted (and ignored) on purpose, and instances should be obtained
    through ``Builder.create_network()`` instead.
    """
    # HACK: the INetwork Python API offers no way to delete a layer, so the
    # names of "removed" layers are collected here and those layers are
    # treated as disabled during graph rewriting and other phases.
    # TODO: remove this when TensorRT provides a better way to remove a layer
    self._removed_layers: Set[str] = set()

    self.is_graph_altered = False

    # Function-scope import, kept local exactly as in the original code.
    from .graph_rewriting import FLayerInfoMemo

    # Functional metadata of the layers.
    self.flayer_memo = FLayerInfoMemo()
    # Parameter tensors, exposed through the ``parameter_tensors`` property.
    self._parameter_tensors = {}
140
def _init(self, trt_network):
    """Attach ``trt_network`` to this object and reset the per-network state.

    Args:
        trt_network: the TensorRT network definition to wrap.

    Returns:
        ``self``, so the call can be chained.
    """
    self._trt_network = trt_network

    # Fresh, empty bookkeeping for the newly attached network.
    self._inputs = {}
    self._named_parameters = None

    # Layer precision of the current scope; used together with the
    # precision(dtype) context manager.
    self._dtype = None

    self._name_generator = _UniqueNameGenerator()
    self._plugin_config = PluginConfig()
    self._module_call_stack = _TrtLlmModuleCallStack()
    self._registered_ndarrays = []

    # Record whether the wrapped network carries the STRONGLY_TYPED flag.
    strongly_typed_flag = trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED
    self._strongly_typed = trt.INetworkDefinition.get_flag(
        self._trt_network, strongly_typed_flag)

    # layer name -> (weights buffer, values to copy in, or None)
    self._unfilled_weights: Dict[str, Tuple[np.array, np.array]] = {}

    return self
156
def _register_unfilled_weights(self, layer_name: str, weights: np.array,
                               values: np.array):
    """Remember a weight buffer whose contents will be written later.

    Args:
        layer_name: key under which the pair is stored; an existing entry
            with the same key is overwritten.
        weights: the buffer to be filled by ``_fill_weights``.
        values: the data to copy into ``weights``, or ``None``.
    """
    pending_entry = (weights, values)
    self._unfilled_weights[layer_name] = pending_entry
160
def _fill_weights(self):
    """Fill every pending weight buffer and empty ``_unfilled_weights``.

    Each buffer is first handed to ``self.register_ndarray``. It is then
    overwritten with its recorded values, or passed to
    ``Parameter.xavier_init`` when no values were recorded (``None``).
    """
    # Function-scope import, kept local exactly as in the original code.
    from tensorrt_llm.parameter import Parameter

    # Snapshot the keys first: entries are popped while iterating.
    for pending_name in tuple(self._unfilled_weights):
        buffer, source = self._unfilled_weights.pop(pending_name)
        self.register_ndarray(buffer)
        if source is None:
            Parameter.xavier_init(buffer)
            continue
        # casting='no' makes numpy refuse any dtype conversion on copy.
        np.copyto(buffer, source, casting='no')
171
@property
def parameter_tensors(self):
    """The object stored in ``_parameter_tensors``, returned as-is (not copied)."""
    return self._parameter_tensors
175
def get_parameter_tensor(self, param):
    """Return the tensor registered for ``param``, or ``None`` if there is none."""
    registry = self.parameter_tensors
    return registry.get(param)
178
def set_parameter_tensor(self, param, tensor):
    """Register ``tensor`` as the tensor of ``param``.

    A parameter may only be registered once; this is enforced with an
    ``assert``, so the check disappears when Python runs with ``-O``.
    """
    registry = self.parameter_tensors
    assert param not in registry
    registry[param] = tensor

Callers 1

create_networkMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected