()
| 116 | } |
| 117 | |
| 118 | async function main() { |
| 119 | const args = parseArgs(); |
| 120 | if (args.gpu) { |
| 121 | console.log('Using GPU'); |
| 122 | require('@tensorflow/tfjs-node-gpu'); |
| 123 | } else { |
| 124 | console.log('Using CPU'); |
| 125 | require('@tensorflow/tfjs-node'); |
| 126 | } |
| 127 | |
| 128 | // Create the text data object. |
| 129 | let localTextDataPath = args.textDatasetPath; |
| 130 | if (args.textDatasetName) { |
| 131 | const textDataURL = TEXT_DATA_URLS[args.textDatasetName].url; |
| 132 | localTextDataPath = path.join(os.tmpdir(), path.basename(textDataURL)); |
| 133 | await maybeDownload(textDataURL, localTextDataPath); |
| 134 | } |
| 135 | const text = fs.readFileSync(localTextDataPath, {encoding: 'utf-8'}); |
| 136 | const textData = |
| 137 | new TextData('text-data', text, args.sampleLen, args.sampleStep); |
| 138 | |
| 139 | // Convert lstmLayerSize from string to number array before handing it |
| 140 | // to `createModel()`. |
| 141 | const lstmLayerSize = args.lstmLayerSize.indexOf(',') === -1 ? |
| 142 | Number.parseInt(args.lstmLayerSize) : |
| 143 | args.lstmLayerSize.split(',').map(x => Number.parseInt(x)); |
| 144 | |
| 145 | const model = createModel( |
| 146 | textData.sampleLen(), textData.charSetSize(), lstmLayerSize); |
| 147 | compileModel(model, args.learningRate); |
| 148 | |
| 149 | // Get a seed text for display in the course of model training. |
| 150 | const [seed, seedIndices] = textData.getRandomSlice(); |
| 151 | console.log(`Seed text:\n"${seed}"\n`); |
| 152 | |
| 153 | const DISPLAY_TEMPERATURES = [0, 0.25, 0.5, 0.75]; |
| 154 | |
| 155 | let epochCount = 0; |
| 156 | await fitModel( |
| 157 | model, textData, args.epochs, args.examplesPerEpoch, args.batchSize, |
| 158 | args.validationSplit, { |
| 159 | onTrainBegin: async () => { |
| 160 | epochCount++; |
| 161 | console.log(`Epoch ${epochCount} of ${args.epochs}:`); |
| 162 | }, |
// After a training run ends, print one generated text sample per display
// temperature, starting from the fixed seed chosen before training.
//
// The samples are generated one at a time with `for...of` + `await`
// rather than `forEach(async ...)`: `forEach` ignores the promises its
// callback returns, so this handler would otherwise resolve before any
// text had been generated. Awaiting here means that:
//   - this callback's promise settles only after every sample has been
//     generated and logged;
//   - samples are printed in the order listed in DISPLAY_TEMPERATURES;
//   - a failure in `generateText` rejects this callback's promise
//     instead of surfacing as an unhandled rejection.
onTrainEnd: async () => {
  for (const temperature of DISPLAY_TEMPERATURES) {
    const generated = await generateText(
        model, textData, seedIndices, args.displayLength, temperature);
    console.log(
        `Generated text (temperature=${temperature}):\n` +
        `"${generated}"\n`);
  }
}
| 172 | }); |
| 173 | |
| 174 | if (args.savePath != null && args.savePath.length > 0) { |
| 175 | await model.save(`file://${args.savePath}`); |
no test coverage detected