(probs, temperature)
| 155 | * range of `[0, charSetSize - 1]`. |
| 156 | */ |
export function sample(probs, temperature = 1) {
  // An omitted temperature defaults to 1 (sample straight from `probs`).
  // Previously Math.max(undefined, 1e-6) yielded NaN, turning every logit
  // into NaN and making the drawn index meaningless.
  // The clamp keeps temperature <= 0 from dividing by zero or flipping the
  // sign of the logits. A tiny temperature sharpens the distribution toward
  // its most likely entry.
  const safeTemperature = Math.max(temperature, 1e-6);
  return tf.tidy(() => {
    // Temperature-scaled log-probabilities serve as the logits.
    const logits = tf.div(tf.log(probs), safeTemperature);
    // normalized = false: `logits` are unnormalized log-probabilities,
    // not probabilities that already sum to 1.
    const isNormalized = false;
    // Draw a single sample (no fixed seed) and read the index back synchronously.
    return tf.multinomial(logits, 1, null, isNormalized).dataSync()[0];
  });
}
no outgoing calls
no test coverage detected