| 123 | * matches that of `y`. |
| 124 | */ |
export async function standardizeWeights(
    y: Tensor, sampleWeight?: Tensor, classWeight?: ClassWeight,
    sampleWeightMode?: 'temporal'): Promise<Tensor> {
  if (sampleWeight != null || sampleWeightMode != null) {
    // TODO(cais): Once 'temporal' mode is implemented, document it in the doc
    // string.
    throw new Error('Support sampleWeight is not implemented yet');
  }

  if (classWeight == null) {
    // No class weights were requested, so there is nothing to standardize;
    // callers receive null instead of a tensor.
    return null;
  }

  // Reduce `y` to a rank-1 tensor holding one class index per sample:
  //   - rank 1: `y` already holds class indices, so it is cloned.
  //   - rank 2 with last dim > 1: treated as one-hot; argMax recovers indices.
  //   - rank 2 with last dim === 1: class indices in a column; flattened.
  // Any other shape is rejected.
  const yClasses: Tensor1D = tidy(() => {
    if (y.shape.length === 1) {
      // Assume class indices.
      return clone(y) as Tensor1D;
    } else if (y.shape.length === 2) {
      if (y.shape[1] > 1) {
        // Assume one-hot encoding of classes.
        const axis = 1;
        return argMax(y, axis);
      } else if (y.shape[1] === 1) {
        // Class index.
        return reshape(y, [y.shape[0]]);
      } else {
        throw new Error(
            `Encountered unexpected last-dimension size (${y.shape[1]}) ` +
            `during handling of class weights. The size is expected to be ` +
            `>= 1.`);
      }
    } else {
      throw new Error(
          `Unexpected rank of target (y) tensor (${y.rank}) during ` +
          `handling of class weights. The rank is expected to be 1 or 2.`);
    }
  });

  // Download the class indices. `yClasses` is disposed in `finally` so the
  // intermediate tensor is released even if the data read rejects.
  let yClassIndices: number[];
  try {
    yClassIndices = Array.from(await yClasses.data());
  } finally {
    dispose(yClasses);
  }

  // Map each sample's class index to its weight. Every class present in the
  // data must have an entry in `classWeight`; otherwise this throws.
  const classSampleWeight = yClassIndices.map(classIndex => {
    const weight = classWeight[classIndex];
    if (weight == null) {
      throw new Error(
          `classWeight must contain all classes in the training data. ` +
          `The class ${classIndex} exists in the data but not in ` +
          `classWeight`);
    }
    return weight;
  });

  return tensor1d(classSampleWeight, 'float32');
}
| 180 | |
| 181 | /** |
| 182 | * Apply per-sample weights on the loss values from a number of samples. |