@param {Number[]|*} [rawInput=[]] · @param {Boolean} [isSampleI=false] · @param {Number} [temperature=1] · @returns {*}
(rawInput = [], isSampleI = false, temperature = 1)
| 307 | * @returns {*} |
| 308 | */ |
| 309 | run(rawInput = [], isSampleI = false, temperature = 1) { |
| 310 | const maxPredictionLength = this.maxPredictionLength + rawInput.length + (this.dataFormatter ? this.dataFormatter.specialIndexes.length : 0); |
| 311 | if (!this.isRunnable) return null; |
| 312 | const input = this.formatDataIn(rawInput); |
| 313 | const model = this.model; |
| 314 | const output = []; |
| 315 | let i = 0; |
| 316 | while (true) { |
| 317 | let previousIndex = (i === 0 |
| 318 | ? 0 |
| 319 | : i < input.length |
| 320 | ? input[i - 1] + 1 |
| 321 | : output[i - 1]) |
| 322 | ; |
| 323 | while (model.equations.length <= i) { |
| 324 | this.bindEquation(); |
| 325 | } |
| 326 | let equation = model.equations[i]; |
| 327 | // sample predicted letter |
| 328 | let outputMatrix = equation.run(previousIndex); |
| 329 | let logProbabilities = new Matrix(model.output.rows, model.output.columns); |
| 330 | copy(logProbabilities, outputMatrix); |
| 331 | if (temperature !== 1 && isSampleI) { |
| 332 | /** |
| 333 | * scale log probabilities by temperature and re-normalize |
| 334 | * if temperature is high, logProbabilities will go towards zero |
| 335 | * and the softmax outputs will be more diffuse. if temperature is |
| 336 | * very low, the softmax outputs will be more peaky |
| 337 | */ |
| 338 | for (let j = 0, max = logProbabilities.weights.length; j < max; j++) { |
| 339 | logProbabilities.weights[j] /= temperature; |
| 340 | } |
| 341 | } |
| 342 | |
| 343 | let probs = softmax(logProbabilities); |
| 344 | let nextIndex = (isSampleI ? sampleI(probs) : maxI(probs)); |
| 345 | |
| 346 | i++; |
| 347 | if (nextIndex === 0) { |
| 348 | // END token predicted, break out |
| 349 | break; |
| 350 | } |
| 351 | if (i >= maxPredictionLength) { |
| 352 | // something is wrong |
| 353 | break; |
| 354 | } |
| 355 | |
| 356 | output.push(nextIndex); |
| 357 | } |
| 358 | |
| 359 | /** |
| 360 | * we slice the input length here, not because output contains it, but it will be erroneous as we are sending the |
| 361 | * network what is contained in input, so the data is essentially guessed by the network what could be next, till it |
| 362 | * locks in on a value. |
| 363 | * Kind of like this, values are from input: |
| 364 | * 0 -> 4 (or in English: "beginning of input" -> "I have no idea? I'll guess what they want next!") |
| 365 | * 2 -> 2 (oh how interesting, I've narrowed down values...) |
| 366 | * 1 -> 9 (oh how interesting, I now know what the values are...) |
no test coverage detected