MCPcopy
hub / github.com/tensorflow/tfjs-examples / main

Function main

lstm-text-generation/gen_node.js:88–124  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

86}
87
/**
 * Command-line entry point: generate text from a saved LSTM Layers model.
 *
 * Steps:
 *  1. Parse command-line arguments (via parseArgs) and require either the
 *     tfjs-node-gpu or the tfjs-node backend package.
 *  2. Load the model from `args.modelJSONPath` through a `file://` URL.
 *  3. Resolve the source text. A named dataset from TEXT_DATA_URLS is
 *     downloaded (via maybeDownload) into the OS temp directory and takes
 *     precedence over `args.textDatasetPath`.
 *  4. Print a random seed slice of the text, then print the text generated
 *     from that seed.
 *
 * @returns {Promise<void>}
 * @throws {Error} If the dataset name is not a key of TEXT_DATA_URLS, or if
 *     neither a text dataset path nor a dataset name was supplied.
 */
async function main() {
  const args = parseArgs();

  // Requiring a tfjs-node package registers its backend (and, per the
  // tfjs-node docs, the `file://` IO handler used below), so this has to
  // happen before the model is loaded.
  if (args.gpu) {
    console.log('Using GPU');
    require('@tensorflow/tfjs-node-gpu');
  } else {
    console.log('Using CPU');
    require('@tensorflow/tfjs-node');
  }

  // Load the model. Prefer the current `loadLayersModel` API. Fall back to
  // the legacy `loadModel` for older tfjs releases that lack it.
  const loadModel = tf.loadLayersModel || tf.loadModel;
  const model = await loadModel(`file://${args.modelJSONPath}`);

  // Dimension 1 of the model's first input is used as the sequence length
  // for TextData (presumably [batch, sampleLen, charSetSize] — confirm).
  const sampleLen = model.inputs[0].shape[1];

  // Create the text data object.
  let localTextDataPath = args.textDatasetPath;
  if (args.textDatasetName) {
    // Own-property check: an unknown name (or an inherited key such as
    // "constructor") would otherwise fail later with an opaque TypeError.
    if (!Object.prototype.hasOwnProperty.call(
            TEXT_DATA_URLS, args.textDatasetName)) {
      throw new Error(
          `Unknown text dataset name "${args.textDatasetName}"; ` +
          `expected one of: ${Object.keys(TEXT_DATA_URLS).join(', ')}`);
    }
    const textDataURL = TEXT_DATA_URLS[args.textDatasetName].url;
    localTextDataPath = path.join(os.tmpdir(), path.basename(textDataURL));
    await maybeDownload(textDataURL, localTextDataPath);
  }
  if (!localTextDataPath) {
    throw new Error(
        'No text data specified: provide either a text dataset path or a ' +
        'text dataset name.');
  }
  const text = fs.readFileSync(localTextDataPath, {encoding: 'utf-8'});
  const textData = new TextData('text-data', text, sampleLen, args.sampleStep);

  // Get a seed text from the text data object.
  const [seed, seedIndices] = textData.getRandomSlice();

  console.log(`Seed text:\n"${seed}"\n`);

  const generated = await generateText(
      model, textData, seedIndices, args.genLength, args.temperature);

  console.log(`Generated text:\n"${generated}"\n`);
}
125
// Run the script. Handle a rejection from main() explicitly instead of
// leaving the promise floating: log the error and exit with a non-zero code.
main().catch((err) => {
  console.error(err);
  process.exitCode = 1;
});

Callers 1

gen_node.js (File) · 0.70

Calls 5

getRandomSlice (Method) · 0.95
maybeDownload (Function) · 0.90
generateText (Function) · 0.90
parseArgs (Function) · 0.70
loadModel (Function) · 0.50

Tested by

no test coverage detected