(self, batch_size, device)
| 91 | return out * mask[:, :, None, None] |
| 92 | |
def sample(self, batch_size, device):
    """Draw a uniform random minibatch of frame-stacked transitions.

    Only indices whose state stack and next-state stack lie entirely inside
    the written, chronologically contiguous part of the buffer are drawn.
    So no stack mixes the newest frames with the oldest ones across the
    write head, and no next-state reads the slot that is next to be
    overwritten.

    NOTE(review): assumes ``self._stack(i)`` gathers frames
    ``i - stack + 1 .. i``. This is consistent with the ``stack - 1`` lower
    bound used before the buffer fills. Confirm against ``_stack``.

    Args:
        batch_size: number of transitions to draw.
        device: torch device the returned tensors are created on.

    Returns:
        Tuple ``(states, actions, rewards, next_states, dones)`` of tensors
        built with ``torch.as_tensor(..., device=device)``.

    Raises:
        RuntimeError: if the sampleable region holds fewer than
            ``stack + 2`` slots. The region is ``size`` before the buffer
            fills and ``capacity`` afterwards.
    """
    if self.size < self.capacity:
        # Not yet wrapped: slots 0 .. size-1 are already in chronological order.
        oldest, count = 0, self.size
    else:
        # Full ring: self.idx is treated as the next write position, i.e. the
        # oldest slot. Slot self.idx - 1 holds the newest transition.
        oldest, count = self.idx, self.capacity
    if count < self.stack + 2:
        raise RuntimeError("buffer too small to sample yet")
    # k counts slots after the oldest one.
    # - k >= stack - 1 keeps the state stack from reaching behind the oldest
    #   slot, i.e. across the write head into the newest frames.
    # - k <= count - 2 keeps idx + 1 inside the written region.
    # Drawing k directly replaces a rejection loop. That loop rejected the
    # wrong side of the head and never terminated when capacity <= stack.
    offsets = np.random.randint(self.stack - 1, count - 1, size=batch_size)
    idx = (oldest + offsets) % self.capacity
    states = self._stack(idx)
    next_states = self._stack((idx + 1) % self.capacity)
    return (
        torch.as_tensor(states, device=device),
        torch.as_tensor(self.actions[idx], device=device),
        torch.as_tensor(self.rewards[idx], device=device),
        torch.as_tensor(next_states, device=device),
        torch.as_tensor(self.dones[idx], device=device),
    )
| 114 | |
| 115 | |
| 116 | def epsilon(frame): |
no test coverage detected