()
| 86 | } |
| 87 | |
/**
 * Script entry point: loads a saved text-generation model, builds a TextData
 * object from a named (downloaded) or local text corpus, picks a random seed
 * slice from it and prints the text generated from that seed.
 *
 * Configuration comes from parseArgs(); the fields read here are `gpu`,
 * `modelJSONPath`, `textDatasetName`, `textDatasetPath`, `sampleStep`,
 * `genLength` and `temperature`.
 *
 * @returns {Promise<void>} Resolves once the generated text has been printed.
 * @throws {Error} If `textDatasetName` is not a key of TEXT_DATA_URLS, or if
 *     neither a dataset name nor a local dataset path was supplied.
 */
async function main() {
  const args = parseArgs();

  // The backend package is required only for its side effect of registering
  // a native backend; its return value is intentionally unused.
  if (args.gpu) {
    console.log('Using GPU');
    require('@tensorflow/tfjs-node-gpu');
  } else {
    console.log('Using CPU');
    require('@tensorflow/tfjs-node');
  }

  // Load the model. Fall back across tfjs versions: older releases expose
  // `tf.loadModel`, newer ones `tf.loadLayersModel`.
  const loadModel = tf.loadModel || tf.loadLayersModel;
  const model = await loadModel(`file://${args.modelJSONPath}`);

  // Dimension 1 of the model's first input is used as the sample (window)
  // length. NOTE(review): assumes the input shape is (batch, sampleLen, ...).
  const sampleLen = model.inputs[0].shape[1];

  // Resolve the corpus location: a named dataset is fetched into the OS temp
  // directory via maybeDownload; otherwise the user-supplied path is used.
  let localTextDataPath = args.textDatasetPath;
  if (args.textDatasetName) {
    // Own-property check so an unknown name (or an inherited key such as
    // "constructor") yields a clear error instead of an opaque TypeError.
    if (!Object.prototype.hasOwnProperty.call(
            TEXT_DATA_URLS, args.textDatasetName)) {
      throw new Error(`Unknown text dataset name: ${args.textDatasetName}`);
    }
    const textDataURL = TEXT_DATA_URLS[args.textDatasetName].url;
    localTextDataPath = path.join(os.tmpdir(), path.basename(textDataURL));
    await maybeDownload(textDataURL, localTextDataPath);
  }
  if (!localTextDataPath) {
    throw new Error(
        'No text dataset specified: provide either a dataset name or a ' +
        'local text dataset path.');
  }

  // Create the text data object.
  const text = fs.readFileSync(localTextDataPath, {encoding: 'utf-8'});
  const textData = new TextData('text-data', text, sampleLen, args.sampleStep);

  // Get a seed text (and its index encoding) from the text data object.
  const [seed, seedIndices] = textData.getRandomSlice();

  console.log(`Seed text:\n"${seed}"\n`);

  const generated = await generateText(
      model, textData, seedIndices, args.genLength, args.temperature);

  console.log(`Generated text:\n"${generated}"\n`);
}

// Never leave the top-level promise floating: report any failure and exit
// with a non-zero status instead of emitting an unhandled rejection.
main().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});
no test coverage detected