* Executes inference for the given input tensors. * @param inputs Tensor map for the model inputs, keyed by the input node * names. * @param outputs Optional. Output node names from the TensorFlow model. If * no outputs are specified, the default outputs of the model are used. * Y
(inputs: NamedTensorMap, outputs?: string[])
| 229 | * outputs array. |
| 230 | */ |
| 231 | execute(inputs: NamedTensorMap, outputs?: string[]): Tensor[] { |
| 232 | // Dispose any tensors from a prior run to avoid leaking them. |
| 233 | this.disposeIntermediateTensors(); |
| 234 | inputs = this.mapInputs(inputs); |
| 235 | const names = Object.keys(inputs).sort(); |
| 236 | this.checkInputs(inputs); |
| 237 | this.checkInputShapeAndType(inputs); |
| 238 | outputs = this.mapOutputs(outputs); |
| 239 | this.checkOutputs(outputs); |
| 240 | const inputNodes = |
| 241 | names.map(name => this.graph.nodes[parseNodeName(name)[0]]); |
| 242 | const outputNodeNames = outputs.map(name => parseNodeName(name)[0]); |
| 243 | const outputNodeNameSet = new Set(outputNodeNames); |
| 244 | let outputNodes = outputNodeNames.map(name => this.graph.nodes[name]); |
| 245 | // If no outputs are specified, then use the default outputs of the model. |
| 246 | if (outputNodes.length === 0) { |
| 247 | outputNodes = this._outputs; |
| 248 | } |
| 249 | |
| 250 | const compilationKey = this.getCompilationKey(inputNodes, outputNodes); |
| 251 | |
| 252 | // Do nothing if the compiled graph cache contains the input. |
| 253 | let compilation = this.compiledMap.get(compilationKey); |
| 254 | if (compilation == null) { |
| 255 | compilation = this.compile(inputs, outputNodes); |
| 256 | this.compiledMap.set(compilationKey, compilation); |
| 257 | } |
| 258 | |
| 259 | // Keep tensors if KEEP_INTERMEDIATE_TENSORS is on. |
| 260 | try { |
| 261 | this.keepIntermediateTensors = env().getBool('KEEP_INTERMEDIATE_TENSORS'); |
| 262 | } catch (e) { |
| 263 | this.keepIntermediateTensors = false; |
| 264 | console.warn(e.message); |
| 265 | } |
| 266 | const tensorArrayMap: TensorArrayMap = {}; |
| 267 | const tensorListMap: TensorListMap = {}; |
| 268 | |
| 269 | return tidy(() => { |
| 270 | const context = new ExecutionContext( |
| 271 | this.weightMap, tensorArrayMap, tensorListMap, |
| 272 | this.functionExecutorMap, this.parseNodeNameCache); |
| 273 | const tensorsMap: NamedTensorsMap = {...this.weightMap}; |
| 274 | if (this.keepIntermediateTensors) { |
| 275 | this.clonedTensorsMap = this.cloneTensorMap(this.weightMap); |
| 276 | } |
| 277 | |
| 278 | Object.keys(inputs).forEach(name => { |
| 279 | const [nodeName, index] = parseNodeName(name, context); |
| 280 | const tensors: Tensor[] = []; |
| 281 | tensors[index] = inputs[name]; |
| 282 | tensorsMap[nodeName] = tensors; |
| 283 | if (this.keepIntermediateTensors) { |
| 284 | this.clonedTensorsMap[nodeName] = this.cloneTensorList(tensors); |
| 285 | } |
| 286 | }); |
| 287 | |
| 288 | const tensorsToKeep = this.getFrozenTensorIds(tensorsMap); |
nothing calls this directly
no test coverage detected