A ring buffer that stores transitions and samples them for training. Each transition has the fields: state (state_dim,); action (action_dim,); reward, a scalar; next_state (state_dim,); done, a scalar (0 or 1) or a bool (True or False).
| 80 | |
| 81 | |
class ReplayBuffer:
    """
    a ring buffer for storing transitions and sampling for training
    :state: (state_dim,)
    :action: (action_dim,)
    :reward: (,), scalar
    :next_state: (state_dim,)
    :done: (,), scalar (0 and 1) or bool (True and False)

    Once ``capacity`` transitions are stored, each new push overwrites the
    oldest stored transition.
    """

    def __init__(self, capacity):
        """
        :param capacity: maximum number of transitions kept. A float such as
            ``1e6`` is accepted and truncated to an int.
        :raises ValueError: if the (truncated) capacity is smaller than 1.
        """
        size = int(capacity)  # normalise once so position arithmetic stays integral
        if size < 1:
            raise ValueError(f"capacity must be at least 1, got {capacity!r}")
        self.capacity = size
        self.buffer = []
        self.position = 0  # index the next push will write to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one when the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            # While filling, position always equals len(self.buffer).
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity  # wrap around: ring buffer

    def sample(self, BATCH_SIZE):
        """
        Draw BATCH_SIZE distinct stored transitions uniformly at random.

        :param BATCH_SIZE: number of transitions, 1 <= BATCH_SIZE <= len(self)
        :return: tuple (state, action, reward, next_state, done); each element
            is a numpy array stacked along a new leading axis of length BATCH_SIZE
        :raises ValueError: if BATCH_SIZE < 1 or exceeds the number of stored
            transitions (the latter raised by ``random.sample``)
        """
        if BATCH_SIZE < 1:
            raise ValueError(f"BATCH_SIZE must be at least 1, got {BATCH_SIZE!r}")
        batch = random.sample(self.buffer, BATCH_SIZE)
        # zip(*batch) transposes the list of transitions into one tuple per field:
        #   [(s0, a0, ...), (s1, a1, ...)] -> (s0, s1), (a0, a1), ...
        # np.stack then turns each per-field tuple into a single array.
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        """Number of transitions currently stored (at most ``capacity``)."""
        return len(self.buffer)
| 116 | |
| 117 | |
| 118 | class SoftQNetwork(Model): |
no outgoing calls
no test coverage detected
searching dependent graphs…