* Load IMDB data features from a local file. * * @param {string} filePath Data file on local filesystem. * @param {number} numWords Number of words in the vocabulary. Word indices * that equal or exceed this limit will be marked as `OOV_INDEX`. * @param {number} maxLen Length of each sequence. Longer
(filePath, numWords, maxLen, multihot = false)
| 47 | * shape `[numExamples, numWords]` and dtype `float32`. |
| 48 | */ |
| 49 | function loadFeatures(filePath, numWords, maxLen, multihot = false) { |
| 50 | const buffer = fs.readFileSync(filePath); |
| 51 | const numBytes = buffer.byteLength; |
| 52 | |
| 53 | let sequences = []; |
| 54 | let seq = []; |
| 55 | let index = 0; |
| 56 | |
| 57 | while (index < numBytes) { |
| 58 | const value = buffer.readInt32LE(index); |
| 59 | if (value === 1) { |
| 60 | // A new sequence has started. |
| 61 | if (index > 0) { |
| 62 | sequences.push(seq); |
| 63 | } |
| 64 | seq = []; |
| 65 | } else { |
| 66 | // Sequence continues. |
| 67 | seq.push(value >= numWords ? OOV_INDEX : value); |
| 68 | } |
| 69 | index += 4; |
| 70 | } |
| 71 | if (seq.length > 0) { |
| 72 | sequences.push(seq); |
| 73 | } |
| 74 | |
| 75 | // Get some sequence length stats. |
| 76 | let minLength = Infinity; |
| 77 | let maxLength = -Infinity; |
| 78 | sequences.forEach(seq => { |
| 79 | const length = seq.length; |
| 80 | if (length < minLength) { |
| 81 | minLength = length; |
| 82 | } |
| 83 | if (length > maxLength) { |
| 84 | maxLength = length; |
| 85 | } |
| 86 | }); |
| 87 | console.log(`Sequence length: min = ${minLength}; max = ${maxLength}`); |
| 88 | |
| 89 | if (multihot) { |
| 90 | // If requested by the arg, encode the sequences as multi-hot |
| 91 | // vectors. |
| 92 | const buffer = tf.buffer([sequences.length, numWords]); |
| 93 | sequences.forEach((seq, i) => { |
| 94 | seq.forEach(wordIndex => { |
| 95 | if (wordIndex !== OOV_INDEX) { |
| 96 | buffer.set(1, i, wordIndex); |
| 97 | } |
| 98 | }); |
| 99 | }); |
| 100 | return buffer.toTensor(); |
| 101 | } else { |
| 102 | const paddedSequences = |
| 103 | padSequences(sequences, maxLen, 'pre', 'pre'); |
| 104 | return tf.tensor2d( |
| 105 | paddedSequences, [paddedSequences.length, maxLen], 'int32'); |
| 106 | } |
no test coverage detected