| 30 | import {generateText} from './model'; |
| 31 | |
/**
 * Parse and validate the command-line arguments of the text-generation
 * script.
 *
 * The positional `textDatasetNameOrPath` argument is resolved into exactly one
 * of two mutually exclusive fields on the returned object:
 *   - `textDatasetName`: when the value is an own key of `TEXT_DATA_URLS`
 *     (a built-in dataset name); this takes precedence over a file path;
 *   - `textDatasetPath`: otherwise, when the value names an existing regular
 *     file (directories are rejected).
 * In both cases the original `textDatasetNameOrPath` field is deleted. If the
 * value is neither, `parser.error()` is called with a message listing the
 * accepted dataset names (argparse presumably prints usage and exits here --
 * NOTE(review): confirm against the argparse version in use).
 *
 * @returns {Object} The parsed arguments: `modelJSONPath`, `genLength`,
 *     `temperature`, `gpu`, `sampleStep`, plus the resolved dataset field.
 */
function parseArgs() {
  const parser = argparse.ArgumentParser({
    description: 'Generate text using a trained lstm-text-generation model.'
  });
  parser.addArgument('textDatasetNameOrPath', {
    type: 'string',
    help: 'Name of the text dataset (one of ' +
        Object.keys(TEXT_DATA_URLS).join(', ') +
        ') or the path to a text file containing a custom dataset'
  });
  parser.addArgument('modelJSONPath', {
    type: 'string',
    help: 'Path to the trained next-char prediction model saved on disk ' +
        '(e.g., ./my-model/model.json)'
  });
  parser.addArgument('--genLength', {
    type: 'int',
    defaultValue: 200,
    help: 'Length of the text to generate.'
  });
  parser.addArgument('--temperature', {
    type: 'float',
    defaultValue: 0.5,
    help: 'Temperature value to use for text generation. Higher values ' +
        'lead to more random-looking generation results.'
  });
  parser.addArgument('--gpu', {
    action: 'storeTrue',
    help: 'Use CUDA GPU for text generation.'
  });
  parser.addArgument('--sampleStep', {
    type: 'int',
    defaultValue: 3,
    help: 'Step length: how many characters to skip between one example ' +
        'extracted from the text data to the next.'
  });

  const args = parser.parseArgs();
  const nameOrPath = args.textDatasetNameOrPath;

  // Own-property check: a plain `TEXT_DATA_URLS[nameOrPath]` lookup would also
  // match inherited keys such as 'constructor' or 'toString' and wrongly treat
  // them as built-in dataset names.
  const isDataset =
      Object.prototype.hasOwnProperty.call(TEXT_DATA_URLS, nameOrPath);
  const isFile =
      fs.existsSync(nameOrPath) && fs.statSync(nameOrPath).isFile();

  if (isDataset) {
    args.textDatasetName = nameOrPath;
    delete args.textDatasetNameOrPath;
  } else if (isFile) {
    args.textDatasetPath = nameOrPath;
    delete args.textDatasetNameOrPath;
  } else {
    parser.error('Argument should be one of ' +
        Object.keys(TEXT_DATA_URLS).join(', ') +
        ' or the path to a dataset text file');
  }
  return args;
}
| 87 | |
| 88 | async function main() { |
| 89 | const args = parseArgs(); |