* Generates random input tensors from the given input definitions. * * ```js * const input = generateInputFromDef(inputDefs); * * console.log(`Generated input: ${Object.values(input)}`); * console.log(`Prediction for the generated input: ${prediction}`); * ``` * * @param inputDefs The input definitions that are used to generate the input tensors.
(inputDefs, isForGraphModel = false)
| 78 | * @param isForGraphModel flag for whether to generate inputs for GraphModel |
| 79 | */ |
| 80 | function generateInputFromDef(inputDefs, isForGraphModel = false) { |
| 81 | if (inputDefs == null) { |
| 82 | throw new Error('The inputDef cannot be found.'); |
| 83 | } |
| 84 | |
| 85 | const tensorArray = []; |
| 86 | try { |
| 87 | inputDefs.forEach((inputDef, inputDefIndex) => { |
| 88 | const inputShape = inputDef.shape; |
| 89 | |
| 90 | // Construct the input tensor. |
| 91 | let inputTensor; |
| 92 | if (inputDef.dtype === 'float32' || inputDef.dtype === 'int32') { |
| 93 | // We assume a bell curve normal distribution. In this case, |
| 94 | // we use below approximation: |
| 95 | // mean ~= (min + max) / 2 |
| 96 | // std ~= (max - min) / 4 |
| 97 | // Note: for std, our approximation is based on the fact that |
| 98 | // 95% of the data is within the range of 2 stds above and |
| 99 | // below the mean. So 95% of the data falls in the range of |
| 100 | // 4 stds. |
| 101 | const min = inputDef.range[0]; |
| 102 | const max = inputDef.range[1]; |
| 103 | const mean = (min + max) / 2; |
| 104 | const std = (max - min) / 4; |
| 105 | generatedRaw = tf.randomNormal(inputShape, mean, std, inputDef.dtype); |
| 106 | // We clip the value to be within [min, max], because 5% of |
| 107 | // the data generated maybe outside of [min, max]. |
| 108 | inputTensor = tf.clipByValue(generatedRaw, min, max); |
| 109 | generatedRaw.dispose(); |
| 110 | } else if (inputDef.dtype === 'string') { |
| 111 | size = tf.util.sizeFromShape(inputDef.shape); |
| 112 | data = [...Array(size)].map( |
| 113 | () => Math.random().toString(36).substring(2, 7)); |
| 114 | inputTensor = tf.tensor(data, inputShape, inputDef.dtype); |
| 115 | } else { |
| 116 | throw new Error( |
| 117 | `The ${inputDef.dtype} dtype of '${inputDef.name}' input ` + |
| 118 | `at model.inputs[${inputDefIndex}] is not supported.`); |
| 119 | } |
| 120 | tensorArray.push(inputTensor); |
| 121 | }); |
| 122 | |
| 123 | // Return tensor map for tf.GraphModel. |
| 124 | if (isForGraphModel) { |
| 125 | const tensorMap = inputDefs.reduce((map, inputDef, i) => { |
| 126 | map[inputDef.name] = tensorArray[i]; |
| 127 | return map; |
| 128 | }, {}); |
| 129 | return tensorMap; |
| 130 | } |
| 131 | |
| 132 | return tensorArray; |
| 133 | } catch (e) { |
| 134 | // Dispose all input tensors when the input construction is failed. |
| 135 | tensorArray.forEach(tensor => { |
| 136 | if (tensor instanceof tf.Tensor) { |
| 137 | tensor.dispose(); |
no test coverage detected
searching dependent graphs…