(key: jnp.ndarray)
| 83 | self._key = key |
| 84 | |
def sample(key: jnp.ndarray) -> Tuple[Any, jnp.ndarray]:
    """Draw one random batch of demonstration transitions.

    This is a closure over the enclosing iterator. It reads ``batch_size``,
    ``self._dataset_size`` and ``self._jax_dataset`` from the outer scope.

    Args:
        key: JAX PRNG key. It is split, and one half is consumed to draw the
            batch indices.

    Returns:
        A ``(demo_transitions, key)`` tuple. ``demo_transitions`` has the same
        pytree structure as ``self._jax_dataset``, with every leaf gathered
        along axis 0 at ``batch_size`` indices. The indices are drawn
        independently (i.e. with replacement) from ``[0, self._dataset_size)``.
        ``key`` is the fresh half of the split, for use on the next call.
    """
    key, key_randint = jax.random.split(key)
    # ``maxval`` is exclusive, so every index is in [0, self._dataset_size).
    # Presumably every leaf has leading dimension self._dataset_size; confirm
    # against where _jax_dataset is built.
    indices = jax.random.randint(
        key_randint, (batch_size,), minval=0, maxval=self._dataset_size)
    # ``jax.tree_map`` is deprecated (and removed in recent JAX releases);
    # ``jax.tree_util.tree_map`` is the supported spelling with identical
    # semantics.
    demo_transitions = jax.tree_util.tree_map(
        lambda leaf: jnp.take(leaf, indices, axis=0), self._jax_dataset)
    return demo_transitions, key
| 92 | self._sample = jax.jit(sample) |
| 93 | |
| 94 | def __next__(self) -> Any: |