(
labels: Tensor, logits: Tensor)
| 169 | * @param logits The logits. |
| 170 | */ |
export function sigmoidCrossEntropyWithLogits(
    labels: Tensor, logits: Tensor): Tensor {
  // The loss below combines labels and logits elementwise, so their shapes
  // must match exactly.
  if (!util.arraysEqual(labels.shape, logits.shape)) {
    // Print the shapes in the same order the message names them (logits
    // first, then labels) so the reported shapes are not attributed to the
    // wrong tensor.
    throw new ValueError(
        `logits and labels must have the same shape, but got shapes ` +
        `${JSON.stringify(logits.shape)} and ${JSON.stringify(labels.shape)}`);
  }
  return tidy(() => {
    // With x = logits and z = labels, the logistic loss is
    //   x - x * z + log(1 + exp(-x)).
    // For x < 0, exp(-x) grows without bound, so the equivalent form
    //   -x * z + log(1 + exp(x))
    // is more numerically stable there. Both cases combine into
    //   max(x, 0) - x * z + log(1 + exp(-abs(x))),
    // in which exp() is only ever applied to a non-positive argument.
    const reluLogits = tfc.relu(logits);            // max(x, 0)
    const negAbsLogits = tfc.neg(tfc.abs(logits));  // -abs(x)
    return tfc.add(
        tfc.sub(reluLogits, tfc.mul(logits, labels)),
        tfc.log1p(tfc.exp(negAbsLogits)));
  });
}
| 192 | |
| 193 | export function binaryCrossentropy(yTrue: Tensor, yPred: Tensor): Tensor { |
| 194 | return tidy(() => { |
no test coverage detected
searching dependent graphs…