MCPcopy
hub / github.com/tensorflow/tfjs-examples / main

Function main

lstm-text-generation/train_node.js:118–178  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

116}
117
118async function main() {
119 const args = parseArgs();
120 if (args.gpu) {
121 console.log('Using GPU');
122 require('@tensorflow/tfjs-node-gpu');
123 } else {
124 console.log('Using CPU');
125 require('@tensorflow/tfjs-node');
126 }
127
128 // Create the text data object.
129 let localTextDataPath = args.textDatasetPath;
130 if (args.textDatasetName) {
131 const textDataURL = TEXT_DATA_URLS[args.textDatasetName].url;
132 localTextDataPath = path.join(os.tmpdir(), path.basename(textDataURL));
133 await maybeDownload(textDataURL, localTextDataPath);
134 }
135 const text = fs.readFileSync(localTextDataPath, {encoding: 'utf-8'});
136 const textData =
137 new TextData('text-data', text, args.sampleLen, args.sampleStep);
138
139 // Convert lstmLayerSize from string to number array before handing it
140 // to `createModel()`.
141 const lstmLayerSize = args.lstmLayerSize.indexOf(',') === -1 ?
142 Number.parseInt(args.lstmLayerSize) :
143 args.lstmLayerSize.split(',').map(x => Number.parseInt(x));
144
145 const model = createModel(
146 textData.sampleLen(), textData.charSetSize(), lstmLayerSize);
147 compileModel(model, args.learningRate);
148
149 // Get a seed text for display in the course of model training.
150 const [seed, seedIndices] = textData.getRandomSlice();
151 console.log(`Seed text:\n"${seed}"\n`);
152
153 const DISPLAY_TEMPERATURES = [0, 0.25, 0.5, 0.75];
154
155 let epochCount = 0;
156 await fitModel(
157 model, textData, args.epochs, args.examplesPerEpoch, args.batchSize,
158 args.validationSplit, {
159 onTrainBegin: async () => {
160 epochCount++;
161 console.log(`Epoch ${epochCount} of ${args.epochs}:`);
162 },
163 onTrainEnd: async () => {
164 DISPLAY_TEMPERATURES.forEach(async temperature => {
165 const generated = await generateText(
166 model, textData, seedIndices, args.displayLength, temperature);
167 console.log(
168 `Generated text (temperature=${temperature}):\n` +
169 `"${generated}"\n`);
170 });
171 }
172 });
173
174 if (args.savePath != null && args.savePath.length > 0) {
175 await model.save(`file://${args.savePath}`);

Callers 1

train_node.js (File) · 0.70

Calls 9

sampleLen (Method) · 0.95
charSetSize (Method) · 0.95
getRandomSlice (Method) · 0.95
maybeDownload (Function) · 0.90
createModel (Function) · 0.90
compileModel (Function) · 0.90
fitModel (Function) · 0.90
generateText (Function) · 0.90
parseArgs (Function) · 0.70

Tested by

no test coverage detected