(
model, textData, sentenceIndices, length, temperature,
onTextGenerationChar)
| 101 | * @returns {Promise<string>} Resolves to the generated text. |
| 102 | */ |
export async function generateText(
    model, textData, sentenceIndices, length, temperature,
    onTextGenerationChar) {
  // Generates characters one at a time until `length` characters have been
  // produced. Each step one-hot encodes the current window of character
  // indices, asks `model` for next-character probabilities, draws an index
  // with `sample(probs, temperature)`, and slides the window forward by one.
  // Resolves to the concatenation of all sampled characters.
  const sampleLen = model.inputs[0].shape[1];
  const charSetSize = model.inputs[0].shape[2];

  // Private copy of the seed so the caller's array is never mutated.
  const windowIndices = [...sentenceIndices];

  let generated = '';
  while (generated.length < length) {
    // One-hot encode the current window as a [1, sampleLen, charSetSize]
    // tensor.
    const inputBuffer = new tf.TensorBuffer([1, sampleLen, charSetSize]);
    for (let i = 0; i < sampleLen; ++i) {
      inputBuffer.set(1, 0, i, windowIndices[i]);
    }
    const input = inputBuffer.toTensor();

    let output = null;
    let probs = null;
    let winnerIndex;
    try {
      // Probability values of the next character.
      output = model.predict(input);
      // tf.squeeze() returns a new tensor handle; it must be disposed as
      // well, otherwise one tensor leaks per generated character.
      probs = tf.squeeze(output);
      winnerIndex = sample(probs, temperature);
    } finally {
      // Release this iteration's tensors even if predict()/sample() throws,
      // and before awaiting the callback so they are not held across it.
      input.dispose();
      if (output != null) {
        output.dispose();
      }
      if (probs != null) {
        probs.dispose();
      }
    }

    const winnerChar = textData.getFromCharSet(winnerIndex);
    if (onTextGenerationChar != null) {
      await onTextGenerationChar(winnerChar);
    }

    generated += winnerChar;
    // Slide the window: drop the oldest index, append the sampled one.
    windowIndices.shift();
    windowIndices.push(winnerIndex);
  }
  return generated;
}
| 145 | |
| 146 | /** |
| 147 | * Draw a sample based on probabilities. |
no test coverage detected