Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network.
(self,
environment_spec: specs.EnvironmentSpec,
policy_network: snt.Module,
critic_network: snt.Module,
observation_network: types.TensorTransformation = tf.identity,
discount: float = 0.99,
batch_size: int = 256,
prefetch_size: int = 4,
target_update_period: int = 100,
min_replay_size: int = 1000,
max_replay_size: int = 1000000,
samples_per_insert: float = 32.0,
n_step: int = 5,
sigma: float = 0.3,
clipping: bool = True,
logger: Optional[loggers.Logger] = None,
counter: Optional[counting.Counter] = None,
checkpoint: bool = True,
replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE)
| 44 | """ |
| 45 | |
| 46 | def __init__(self, |
| 47 | environment_spec: specs.EnvironmentSpec, |
| 48 | policy_network: snt.Module, |
| 49 | critic_network: snt.Module, |
| 50 | observation_network: types.TensorTransformation = tf.identity, |
| 51 | discount: float = 0.99, |
| 52 | batch_size: int = 256, |
| 53 | prefetch_size: int = 4, |
| 54 | target_update_period: int = 100, |
| 55 | min_replay_size: int = 1000, |
| 56 | max_replay_size: int = 1000000, |
| 57 | samples_per_insert: float = 32.0, |
| 58 | n_step: int = 5, |
| 59 | sigma: float = 0.3, |
| 60 | clipping: bool = True, |
| 61 | logger: Optional[loggers.Logger] = None, |
| 62 | counter: Optional[counting.Counter] = None, |
| 63 | checkpoint: bool = True, |
| 64 | replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE): |
| 65 | """Initialize the agent. |
| 66 | |
| 67 | Args: |
| 68 | environment_spec: description of the actions, observations, etc. |
| 69 | policy_network: the online (optimized) policy. |
| 70 | critic_network: the online critic. |
| 71 | observation_network: optional network to transform the observations before |
| 72 | they are fed into any network. |
| 73 | discount: discount to use for TD updates. |
| 74 | batch_size: batch size for updates. |
| 75 | prefetch_size: size to prefetch from replay. |
| 76 | target_update_period: number of learner steps to perform before updating |
| 77 | the target networks. |
| 78 | min_replay_size: minimum replay size before updating. |
| 79 | max_replay_size: maximum replay size. |
| 80 | samples_per_insert: number of samples to take from replay for every insert |
| 81 | that is made. |
| 82 | n_step: number of steps to squash into a single transition. |
| 83 | sigma: standard deviation of zero-mean, Gaussian exploration noise. |
| 84 | clipping: whether to clip gradients by global norm. |
| 85 | logger: logger object to be used by learner. |
| 86 | counter: counter object used to keep track of steps. |
| 87 | checkpoint: boolean indicating whether to checkpoint the learner. |
| 88 | replay_table_name: string indicating what name to give the replay table. |
| 89 | """ |
| 90 | # Create a replay server to add data to. This uses no limiter behavior in |
| 91 | # order to allow the Agent interface to handle it. |
| 92 | replay_table = reverb.Table( |
| 93 | name=replay_table_name, |
| 94 | sampler=reverb.selectors.Uniform(), |
| 95 | remover=reverb.selectors.Fifo(), |
| 96 | max_size=max_replay_size, |
| 97 | rate_limiter=reverb.rate_limiters.MinSize(1), |
| 98 | signature=adders.NStepTransitionAdder.signature(environment_spec)) |
| 99 | self._server = reverb.Server([replay_table], port=None) |
| 100 | |
| 101 | # The adder is used to insert observations into replay. |
| 102 | address = f'localhost:{self._server.port}' |
| 103 | adder = adders.NStepTransitionAdder( |