MCPcopy
hub / github.com/tensorflow/tfjs / generateInputFromDef

Function generateInputFromDef

e2e/benchmarks/benchmark_util.js:80–142  ·  view source on GitHub ↗

Generates a random input for each input definition.

```js
const input = generateInput(inputDefs);

console.log(`Generated input: ${Object.values(input)}`);
console.log(`Prediction for the generated input: ${prediction}`);
```

@param inputDefs The input definition that is used …

(inputDefs, isForGraphModel = false)

Source from the content-addressed store, hash-verified

78 * @param isForGraphModel flag for whether to generate inputs for GraphModel
79 */
80function generateInputFromDef(inputDefs, isForGraphModel = false) {
81 if (inputDefs == null) {
82 throw new Error('The inputDef cannot be found.');
83 }
84
85 const tensorArray = [];
86 try {
87 inputDefs.forEach((inputDef, inputDefIndex) => {
88 const inputShape = inputDef.shape;
89
90 // Construct the input tensor.
91 let inputTensor;
92 if (inputDef.dtype === 'float32' || inputDef.dtype === 'int32') {
93 // We assume a bell curve normal distribution. In this case,
94 // we use below approximation:
95 // mean ~= (min + max) / 2
96 // std ~= (max - min) / 4
97 // Note: for std, our approximation is based on the fact that
98 // 95% of the data is within the range of 2 stds above and
99 // below the mean. So 95% of the data falls in the range of
100 // 4 stds.
101 const min = inputDef.range[0];
102 const max = inputDef.range[1];
103 const mean = (min + max) / 2;
104 const std = (max - min) / 4;
105 generatedRaw = tf.randomNormal(inputShape, mean, std, inputDef.dtype);
106 // We clip the value to be within [min, max], because 5% of
107 // the data generated maybe outside of [min, max].
108 inputTensor = tf.clipByValue(generatedRaw, min, max);
109 generatedRaw.dispose();
110 } else if (inputDef.dtype === 'string') {
111 size = tf.util.sizeFromShape(inputDef.shape);
112 data = [...Array(size)].map(
113 () => Math.random().toString(36).substring(2, 7));
114 inputTensor = tf.tensor(data, inputShape, inputDef.dtype);
115 } else {
116 throw new Error(
117 `The ${inputDef.dtype} dtype of '${inputDef.name}' input ` +
118 `at model.inputs[${inputDefIndex}] is not supported.`);
119 }
120 tensorArray.push(inputTensor);
121 });
122
123 // Return tensor map for tf.GraphModel.
124 if (isForGraphModel) {
125 const tensorMap = inputDefs.reduce((map, inputDef, i) => {
126 map[inputDef.name] = tensorArray[i];
127 return map;
128 }, {});
129 return tensorMap;
130 }
131
132 return tensorArray;
133 } catch (e) {
134 // Dispose all input tensors when the input construction is failed.
135 tensorArray.forEach(tensor => {
136 if (tensor instanceof tf.Tensor) {
137 tensor.dispose();

Callers 3

model_config.jsFile · 0.85
generateInputFunction · 0.85

Calls 4

clipByValueMethod · 0.80
toStringMethod · 0.80
disposeMethod · 0.45
pushMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…