A ring buffer for storing transitions and sampling them for training. Fields of each transition — state: shape (state_dim,); action: shape (action_dim,); reward: scalar, shape (); next_state: shape (state_dim,); done: scalar (0 or 1) or bool (True or False).
| 92 | |
| 93 | |
class ReplayBuffer:
    """
    A ring buffer for storing transitions and sampling them for training.

    Each transition is stored as a tuple (state, action, reward, next_state, done):
    :state: (state_dim,)
    :action: (action_dim,)
    :reward: (,), scalar
    :next_state: (state_dim,)
    :done: (,), scalar (0 and 1) or bool (True and False)

    Once ``capacity`` transitions have been stored, each new push overwrites the
    oldest stored transition.
    """

    def __init__(self, capacity):
        """
        :param capacity: maximum number of transitions kept; must be positive.
        :raises ValueError: if capacity <= 0 (such a buffer could never store
            anything, and push would fail with an IndexError).
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity!r}")
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # index of the slot the next push writes to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one once the buffer is full."""
        if len(self.buffer) < self.capacity:
            # Still growing: reserve a new slot at the end of the list.
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        # Advance and wrap around (ring buffer). int() keeps the index integral
        # even when capacity was passed as a float such as 1e6.
        self.position = int((self.position + 1) % self.capacity)

    def sample(self, batch_size):
        """
        Draw ``batch_size`` distinct stored transitions uniformly at random.

        :param batch_size: number of transitions to draw (without replacement).
        :returns: tuple (state, action, reward, next_state, done) of numpy arrays,
            each stacked along a new leading axis of length ``batch_size``.
        :raises ValueError: if batch_size < 1 or batch_size > len(self).
        """
        if batch_size < 1:
            # An empty batch cannot be split into the five fields below.
            raise ValueError(f"batch_size must be at least 1, got {batch_size!r}")
        batch = random.sample(self.buffer, batch_size)
        # zip(*batch) transposes the list of transition tuples into one tuple per
        # field: [(s0, a0, ...), (s1, a1, ...)] -> [(s0, s1), (a0, a1), ...].
        # np.stack then turns each field's tuple into one array with a leading
        # batch dimension, e.g. np.stack((1, 2)) -> array([1, 2]).
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        """Number of transitions currently stored (never more than capacity)."""
        return len(self.buffer)
| 128 | |
| 129 | |
| 130 | class QNetwork(Model): |
no outgoing calls
no test coverage detected
searching dependent graphs…