| 29 | import {createModel, compileModel, fitModel, generateText} from './model'; |
| 30 | |
| 31 | function parseArgs() { |
| 32 | const parser = argparse.ArgumentParser({ |
| 33 | description: 'Train an lstm-text-generation model.' |
| 34 | }); |
| 35 | parser.addArgument('textDatasetNameOrPath', { |
| 36 | type: 'string', |
| 37 | help: 'Name of the text dataset (one of ' + |
| 38 | Object.keys(TEXT_DATA_URLS).join(', ') + |
| 39 | ') or the path to a text file containing a custom dataset' |
| 40 | }); |
| 41 | parser.addArgument('--gpu', { |
| 42 | action: 'storeTrue', |
| 43 | help: 'Use CUDA GPU for training.' |
| 44 | }); |
| 45 | parser.addArgument('--sampleLen', { |
| 46 | type: 'int', |
| 47 | defaultValue: 60, |
| 48 | help: 'Sample length: Length of each input sequence to the model, in ' + |
| 49 | 'number of characters.' |
| 50 | }); |
| 51 | parser.addArgument('--sampleStep', { |
| 52 | type: 'int', |
| 53 | defaultValue: 3, |
| 54 | help: 'Step length: how many characters to skip between one example ' + |
| 55 | 'extracted from the text data to the next.' |
| 56 | }); |
| 57 | parser.addArgument('--learningRate', { |
| 58 | type: 'float', |
| 59 | defaultValue: 1e-2, |
| 60 | help: 'Learning rate to be used during training' |
| 61 | }); |
| 62 | parser.addArgument('--epochs', { |
| 63 | type: 'int', |
| 64 | defaultValue: 150, |
| 65 | help: 'Number of training epochs' |
| 66 | }); |
| 67 | parser.addArgument('--examplesPerEpoch', { |
| 68 | type: 'int', |
| 69 | defaultValue: 10000, |
| 70 | help: 'Number of examples to sample from the text in each training epoch.' |
| 71 | }); |
| 72 | parser.addArgument('--batchSize', { |
| 73 | type: 'int', |
| 74 | defaultValue: 128, |
| 75 | help: 'Batch size for training.' |
| 76 | }); |
| 77 | parser.addArgument('--validationSplit', { |
| 78 | type: 'float', |
| 79 | defaultValue: 0.0625, |
| 80 | help: 'Validation split for training.' |
| 81 | }); |
| 82 | parser.addArgument('--displayLength', { |
| 83 | type: 'int', |
| 84 | defaultValue: 120, |
| 85 | help: 'Length of the sampled text to display after each epoch of training.' |
| 86 | }); |
| 87 | parser.addArgument('--savePath', { |
| 88 | type: 'string', |