(state, h, w)
| 345 | */ |
| 346 | |
export function getStateTensor(state, h, w) {
  // A lone state object is handled as a batch containing just that state.
  const examples = Array.isArray(state) ? state : [state];

  // Last axis of the buffer: index 0 holds the snake, index 1 holds the fruit.
  const snakeChannel = 0;
  const fruitChannel = 1;
  // Values written into the buffer. The first snake segment gets its own
  // value so it can be told apart from the remaining segments.
  const headValue = 2;
  const bodyValue = 1;
  const fruitValue = 1;

  // TODO(cais): Maintain only a single buffer for efficiency.
  const buffer = tf.buffer([examples.length, h, w, 2]);

  examples.forEach((example, n) => {
    // Null / undefined examples are skipped, so their slice of the buffer
    // keeps whatever `tf.buffer` initialised it with.
    if (example == null) {
      return;
    }

    // Snake segments: each entry is indexed as [y, x].
    example.s.forEach((yx, segmentIndex) => {
      const cellValue = segmentIndex === 0 ? headValue : bodyValue;
      buffer.set(cellValue, n, yx[0], yx[1], snakeChannel);
    });

    // Fruit positions, also indexed as [y, x].
    example.f.forEach((yx) => {
      buffer.set(fruitValue, n, yx[0], yx[1], fruitChannel);
    });
  });

  return buffer.toTensor();
}
no outgoing calls
no test coverage detected