Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network.
(
self,
environment_spec: specs.EnvironmentSpec,
policy_network: snt.Module,
critic_network: snt.Module,
observation_network: types.TensorTransformation = tf.identity,
discount: float = 0.99,
batch_size: int = 256,
prefetch_size: int = 4,
target_policy_update_period: int = 100,
target_critic_update_period: int = 100,
min_replay_size: int = 1000,
max_replay_size: int = 1000000,
samples_per_insert: float = 32.0,
policy_loss_module: Optional[snt.Module] = None,
policy_optimizer: Optional[snt.Optimizer] = None,
critic_optimizer: Optional[snt.Optimizer] = None,
n_step: int = 5,
num_samples: int = 20,
clipping: bool = True,
logger: Optional[loggers.Logger] = None,
counter: Optional[counting.Counter] = None,
checkpoint: bool = True,
save_directory: str = '~/acme',
replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
)
| 45 | """ |
| 46 | |
| 47 | def __init__( |
| 48 | self, |
| 49 | environment_spec: specs.EnvironmentSpec, |
| 50 | policy_network: snt.Module, |
| 51 | critic_network: snt.Module, |
| 52 | observation_network: types.TensorTransformation = tf.identity, |
| 53 | discount: float = 0.99, |
| 54 | batch_size: int = 256, |
| 55 | prefetch_size: int = 4, |
| 56 | target_policy_update_period: int = 100, |
| 57 | target_critic_update_period: int = 100, |
| 58 | min_replay_size: int = 1000, |
| 59 | max_replay_size: int = 1000000, |
| 60 | samples_per_insert: float = 32.0, |
| 61 | policy_loss_module: Optional[snt.Module] = None, |
| 62 | policy_optimizer: Optional[snt.Optimizer] = None, |
| 63 | critic_optimizer: Optional[snt.Optimizer] = None, |
| 64 | n_step: int = 5, |
| 65 | num_samples: int = 20, |
| 66 | clipping: bool = True, |
| 67 | logger: Optional[loggers.Logger] = None, |
| 68 | counter: Optional[counting.Counter] = None, |
| 69 | checkpoint: bool = True, |
| 70 | save_directory: str = '~/acme', |
| 71 | replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE, |
| 72 | ): |
| 73 | """Initialize the agent. |
| 74 | |
| 75 | Args: |
| 76 | environment_spec: description of the actions, observations, etc. |
| 77 | policy_network: the online (optimized) policy. |
| 78 | critic_network: the online critic. |
| 79 | observation_network: optional network to transform the observations before |
| 80 | they are fed into any network. |
| 81 | discount: discount to use for TD updates. |
| 82 | batch_size: batch size for updates. |
| 83 | prefetch_size: size to prefetch from replay. |
| 84 | target_policy_update_period: number of updates to perform before updating |
| 85 | the target policy network. |
| 86 | target_critic_update_period: number of updates to perform before updating |
| 87 | the target critic network. |
| 88 | min_replay_size: minimum replay size before updating. |
| 89 | max_replay_size: maximum replay size. |
| 90 | samples_per_insert: number of samples to take from replay for every insert |
| 91 | that is made. |
| 92 | policy_loss_module: configured MPO loss function for the policy |
| 93 | optimization; defaults to sensible values on the control suite. See |
| 94 | `acme/tf/losses/mpo.py` for more details. |
| 95 | policy_optimizer: optimizer to be used on the policy. |
| 96 | critic_optimizer: optimizer to be used on the critic. |
| 97 | n_step: number of steps to squash into a single transition. |
| 98 | num_samples: number of actions to sample when doing a Monte Carlo |
| 99 | integration with respect to the policy. |
| 100 | clipping: whether to clip gradients by global norm. |
| 101 | logger: logging object used to write to logs. |
| 102 | counter: counter object used to keep track of steps. |
| 103 | checkpoint: boolean indicating whether to checkpoint the learner. |
| 104 | save_directory: string indicating where the learner should save |