MCPcopy
hub / github.com/tensorlayer/TensorLayer / ReplayBuffer

Class ReplayBuffer

examples/reinforcement_learning/tutorial_TD3.py:94–127  ·  view source on GitHub ↗

A ring buffer for storing transitions and sampling them for training. Each transition holds: state (state_dim,); action (action_dim,); reward, a scalar; next_state (state_dim,); done, a scalar (0 or 1) or bool (True or False).

Source from the content-addressed store, hash-verified

92
93
class ReplayBuffer:
    """
    A ring buffer for storing transitions and sampling them for training.

    Each stored transition is a tuple ``(state, action, reward, next_state, done)``:

    :state: (state_dim,)
    :action: (action_dim,)
    :reward: (,), scalar
    :next_state: (state_dim,)
    :done: (,), scalar (0 and 1) or bool (True and False)
    """

    def __init__(self, capacity):
        """
        Create an empty buffer.

        :param capacity: maximum number of transitions kept; must be positive.
        :raises ValueError: if ``capacity`` is not positive.
        """
        # Fail early: with a non-positive capacity the first push() would
        # otherwise die with an IndexError far away from the actual mistake.
        if capacity <= 0:
            raise ValueError("capacity must be positive, got {!r}".format(capacity))
        self.capacity = capacity
        self.buffer = []
        self.position = 0  # index of the slot the next push() writes to

    def push(self, state, action, reward, next_state, done):
        """Store one transition, overwriting the oldest one once the buffer is full."""
        transition = (state, action, reward, next_state, done)
        if len(self.buffer) < self.capacity:
            # Still growing: position always equals len(buffer) in this phase.
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        # Wrap around as a ring buffer. The int() keeps the index usable as a
        # list index even if capacity was given as a float.
        self.position = int((self.position + 1) % self.capacity)

    def sample(self, batch_size):
        """
        Draw ``batch_size`` distinct stored transitions at random.

        :param batch_size: number of transitions to draw.
        :return: tuple ``(state, action, reward, next_state, done)`` where each
            element is the ``np.stack`` of that field over the sampled batch.
        :raises ValueError: raised by ``random.sample`` if ``batch_size`` is
            larger than the number of stored transitions or is negative.
        """
        batch = random.sample(self.buffer, batch_size)
        # How the next line works:
        #   * unpacks:  sum(a, b) <=> batch = (a, b); sum(*batch)
        #   zip:        a = [1, 2], b = [2, 3]; zip(a, b) => [(1, 2), (2, 3)]
        #   map applies the function to each element: map(square, [2, 3]) => [4, 9]
        #   np.stack((1, 2)) => array([1, 2])
        # So zip(*batch) regroups the batch field by field and np.stack turns
        # every field group into a single array.
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)
128
129
130class QNetwork(Model):

Callers 1

tutorial_TD3.pyFile · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…