| 78 | self.size = min(self.size + 1, self.capacity) |
| 79 | |
def _stack(self, idx):
    """Return the frame stack ending at each index in ``idx``.

    For every entry ``i`` of ``idx`` the ring-buffer slots
    ``i - stack + 1 .. i`` (modulo ``self.capacity``) are gathered from
    ``self.frames``, oldest first, so the newest frame sits at the last
    position of axis 1. Slots that lie at or before a ``done`` flag found
    among the ``stack - 1`` older positions are multiplied by zero, so the
    returned stack never mixes in frames from before that boundary.

    Args:
        idx: 1-D integer array (or sequence) of buffer indices.

    Returns:
        Array of shape ``(len(idx), self.stack, *frame_shape)`` where
        ``frame_shape`` is ``self.frames.shape[1:]``.
    """
    idx = np.asarray(idx)
    # Gather frames[idx-stack+1 .. idx]; newest at last channel.
    offsets = np.arange(self.stack)
    gather = (idx[:, None] - (self.stack - 1) + offsets[None, :]) % self.capacity
    out = self.frames[gather]
    # dones at the (stack-1) older positions mark where a prior episode ended;
    # the flag on the newest slot is deliberately ignored.
    older = self.dones[gather[:, :-1]].astype(bool)
    # Once we cross any done walking newest -> oldest, that slot and everything
    # older is invalid (reverse cumulative "any").
    invalid = np.cumsum(older[:, ::-1], axis=1)[:, ::-1] > 0
    newest = np.ones((idx.shape[0], 1), dtype=bool)
    mask = np.concatenate([~invalid, newest], axis=1)
    # Pad the (batch, stack) mask with one singleton axis per frame dimension so
    # it broadcasts for any frame rank, not only 2-D (H, W) frames.
    return out * mask.reshape(mask.shape + (1,) * (out.ndim - 2))
| 92 | |
| 93 | def sample(self, batch_size, device): |
| 94 | # Reject indices whose stack would straddle the write head (stale frames). |