MCPcopy
hub / github.com/tensorflow/tfjs / standardizeWeights

Function standardizeWeights

tfjs-layers/src/engine/training_utils.ts:125–179  ·  view source on GitHub ↗
(
    y: Tensor, sampleWeight?: Tensor, classWeight?: ClassWeight,
    sampleWeightMode?: 'temporal')

Source from the content-addressed store, hash-verified

123 * matches that of `y`.
124 */
export async function standardizeWeights(
    y: Tensor, sampleWeight?: Tensor, classWeight?: ClassWeight,
    sampleWeightMode?: 'temporal'): Promise<Tensor> {
  // Per-sample weights and weighting modes are not supported; reject them
  // explicitly instead of silently ignoring them.
  if (sampleWeight != null || sampleWeightMode != null) {
    // TODO(cais): Once 'temporal' mode is implemented, document it in the doc
    // string.
    throw new Error('Support sampleWeight is not implemented yet');
  }

  // Without class weights there is nothing to compute; callers get null.
  if (classWeight == null) {
    return null;
  }

  // Reduce `y` to a rank-1 tensor with one class index per sample. Any
  // intermediate tensors created here are released by `tidy`.
  const yClasses: Tensor1D = tidy(() => {
    if (y.shape.length === 1) {
      // Assume class indices.
      return clone(y) as Tensor1D;
    } else if (y.shape.length === 2) {
      if (y.shape[1] > 1) {
        // Assume one-hot encoding of classes; argMax over axis 1 recovers the
        // class index of each row.
        const axis = 1;
        return argMax(y, axis);
      } else if (y.shape[1] === 1) {
        // Class index stored as a single column; flatten it to rank 1.
        return reshape(y, [y.shape[0]]);
      } else {
        throw new Error(
            `Encountered unexpected last-dimension size (${y.shape[1]}) ` +
            `during handling of class weights. The size is expected to be ` +
            `>= 1.`);
      }
    } else {
      throw new Error(
          `Unexpected rank of target (y) tensor (${y.rank}) during ` +
          `handling of class weights. The rank is expected to be 1 or 2.`);
    }
  });

  // Read the class indices back. `yClasses` was created outside any tidy scope
  // that outlives this call, so release it in `finally`; otherwise a rejected
  // or throwing `data()` call would leak the tensor.
  let yClassIndices: number[];
  try {
    yClassIndices = Array.from(await yClasses.data());
  } finally {
    dispose(yClasses);
  }

  // Map each sample's class index to its weight. Every class present in the
  // data must have an entry in `classWeight`.
  const classSampleWeight: number[] = [];
  for (const classIndex of yClassIndices) {
    const weight = classWeight[classIndex];
    if (weight == null) {
      throw new Error(
          `classWeight must contain all classes in the training data. ` +
          `The class ${classIndex} exists in the data but not in ` +
          `classWeight`);
    }
    classSampleWeight.push(weight);
  }

  return tensor1d(classSampleWeight, 'float32');
}
180
181/**
182 * Apply per-sample weights on the loss values from a number of samples.

Callers 3

standardizeUserDataMethod · 0.90
fitDatasetFunction · 0.90

Calls 7

tidyFunction · 0.90
disposeFunction · 0.90
tensor1dFunction · 0.90
dataMethod · 0.65
argMaxFunction · 0.50
reshapeFunction · 0.50
pushMethod · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…