MCPcopy
hub / github.com/BrainJS/brain.js / train

Method train

src/recurrent/rnn.js:384–425  ·  view source on GitHub ↗

@param {Object[]|String[]} data — an array of objects shaped `{input: 'string', output: 'string'}`, or an array of strings · @param {Object} [options] — training options · @returns {{error: number, iterations: number}}

(data, options = {})

Source from the content-addressed store, hash-verified

382 * @returns {{error: number, iterations: number}}
383 */
384 train(data, options = {}) {
385 options = Object.assign({}, this.constructor.trainDefaults, options);
386 let iterations = options.iterations;
387 let errorThresh = options.errorThresh;
388 let log = options.log === true ? console.log : options.log;
389 let logPeriod = options.logPeriod;
390 let learningRate = options.learningRate || this.learningRate;
391 let callback = options.callback;
392 let callbackPeriod = options.callbackPeriod;
393 let error = Infinity;
394 let i;
395
396 if (this.hasOwnProperty('setupData')) {
397 data = this.setupData(data);
398 }
399
400 if (!this.model) {
401 this.initialize();
402 }
403
404 for (i = 0; i < iterations && error > errorThresh; i++) {
405 let sum = 0;
406 for (let j = 0; j < data.length; j++) {
407 let err = this.trainPattern(data[j], learningRate);
408 sum += err;
409 }
410 error = sum / data.length;
411
412 if (isNaN(error)) throw new Error('network error rate is unexpected NaN, check network configurations and try again');
413 if (log && (i % logPeriod === 0)) {
414 log(`iterations: ${ i }, training error: ${ error }`);
415 }
416 if (callback && (i % callbackPeriod === 0)) {
417 callback({ error: error, iterations: i });
418 }
419 }
420
421 return {
422 error: error,
423 iterations: i
424 };
425 }
426
427 /**
428 *

Callers 15

browser.min.js · File · 0.45
browser.js · File · 0.45
learn-math.ts · File · 0.45
predict-numbers.ts · File · 0.45
cross-validate.ts · File · 0.45
childrens-book.ts · File · 0.45
browser.test.js · File · 0.45
iris.js · File · 0.45
lstm.js · File · 0.45
gru.js · File · 0.45
rnn-time-step.js · File · 0.45

Calls 3

initialize · Method · 0.95
trainPattern · Method · 0.95
log · Function · 0.85

Tested by

no test coverage detected