MCPcopy Index your code
hub / github.com/tensorflow/tfjs-examples / generateText

Function generateText

lstm-text-generation/model.js:103–144  ·  view source on GitHub ↗
(
    model, textData, sentenceIndices, length, temperature,
    onTextGenerationChar)

Source from the content-addressed store, hash-verified

101 * @returns {string} The generated sentence.
102 */
export async function generateText(
    model, textData, sentenceIndices, length, temperature,
    onTextGenerationChar) {
  // The model consumes one-hot windows of shape [1, sampleLen, charSetSize];
  // read both dimensions from the model's first input.
  const sampleLen = model.inputs[0].shape[1];
  const charSetSize = model.inputs[0].shape[2];

  // Work on a private copy so the caller's seed array is never modified.
  // Because the copy is ours, it can be mutated in place below.
  sentenceIndices = sentenceIndices.slice();

  let generated = '';
  while (generated.length < length) {
    // One-hot encode the current window of character indices.
    const inputBuffer =
        new tf.TensorBuffer([1, sampleLen, charSetSize]);
    for (let i = 0; i < sampleLen; ++i) {
      inputBuffer.set(1, 0, i, sentenceIndices[i]);
    }

    // Predict and sample inside tf.tidy() so that every intermediate tensor
    // is released when the scope ends, even if an op throws. This covers the
    // input, the model output, and the squeezed tensor handed to sample(),
    // which would otherwise never be disposed. Only a plain number escapes
    // the scope. The scope must finish before the `await` below, because
    // tidy() cannot span asynchronous work.
    const winnerIndex = tf.tidy(() => {
      const input = inputBuffer.toTensor();
      const output = model.predict(input);
      return sample(tf.squeeze(output), temperature);
    });

    const winnerChar = textData.getFromCharSet(winnerIndex);
    if (onTextGenerationChar != null) {
      await onTextGenerationChar(winnerChar);
    }

    generated += winnerChar;
    // Slide the window: drop the oldest index and append the sampled one.
    sentenceIndices.shift();
    sentenceIndices.push(winnerIndex);
  }
  return generated;
}
145
146/**
147 * Draw a sample based on probabilities.

Callers 3

main (Function) · 0.90
main (Function) · 0.90
model_test.js (File) · 0.90

Calls 4

sample (Function) · 0.85
onTextGenerationChar (Function) · 0.85
getFromCharSet (Method) · 0.80
predict (Method) · 0.45

Tested by

no test coverage detected