Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network. (The remaining arguments are listed in the full docstring below.)
(
self,
environment_spec: specs.EnvironmentSpec,
policy_network: snt.Module,
critic_network: snt.Module,
observation_network: types.TensorTransformation = tf.identity,
discount: float = 0.99,
batch_size: int = 256,
prefetch_size: int = 4,
target_update_period: int = 100,
policy_optimizer: Optional[snt.Optimizer] = None,
critic_optimizer: Optional[snt.Optimizer] = None,
min_replay_size: int = 1000,
max_replay_size: int = 1000000,
samples_per_insert: float = 32.0,
n_step: int = 5,
sigma: float = 0.3,
clipping: bool = True,
replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE,
counter: Optional[counting.Counter] = None,
logger: Optional[loggers.Logger] = None,
checkpoint: bool = True,
)
| 250 | """ |
| 251 | |
| 252 | def __init__( |
| 253 | self, |
| 254 | environment_spec: specs.EnvironmentSpec, |
| 255 | policy_network: snt.Module, |
| 256 | critic_network: snt.Module, |
| 257 | observation_network: types.TensorTransformation = tf.identity, |
| 258 | discount: float = 0.99, |
| 259 | batch_size: int = 256, |
| 260 | prefetch_size: int = 4, |
| 261 | target_update_period: int = 100, |
| 262 | policy_optimizer: Optional[snt.Optimizer] = None, |
| 263 | critic_optimizer: Optional[snt.Optimizer] = None, |
| 264 | min_replay_size: int = 1000, |
| 265 | max_replay_size: int = 1000000, |
| 266 | samples_per_insert: float = 32.0, |
| 267 | n_step: int = 5, |
| 268 | sigma: float = 0.3, |
| 269 | clipping: bool = True, |
| 270 | replay_table_name: str = reverb_adders.DEFAULT_PRIORITY_TABLE, |
| 271 | counter: Optional[counting.Counter] = None, |
| 272 | logger: Optional[loggers.Logger] = None, |
| 273 | checkpoint: bool = True, |
| 274 | ): |
| 275 | """Initialize the agent. |
| 276 | |
| 277 | Args: |
| 278 | environment_spec: description of the actions, observations, etc. |
| 279 | policy_network: the online (optimized) policy. |
| 280 | critic_network: the online critic. |
| 281 | observation_network: optional network to transform the observations before |
| 282 | they are fed into any network. |
| 283 | discount: discount to use for TD updates. |
| 284 | batch_size: batch size for updates. |
| 285 | prefetch_size: size to prefetch from replay. |
| 286 | target_update_period: number of learner steps to perform before updating |
| 287 | the target networks. |
| 288 | policy_optimizer: optimizer for the policy network updates. |
| 289 | critic_optimizer: optimizer for the critic network updates. |
| 290 | min_replay_size: minimum replay size before updating. |
| 291 | max_replay_size: maximum replay size. |
| 292 | samples_per_insert: number of samples to take from replay for every insert |
| 293 | that is made. |
| 294 | n_step: number of steps to squash into a single transition. |
| 295 | sigma: standard deviation of zero-mean, Gaussian exploration noise. |
| 296 | clipping: whether to clip gradients by global norm. |
| 297 | replay_table_name: string indicating what name to give the replay table. |
| 298 | counter: counter object used to keep track of steps. |
| 299 | logger: logger object to be used by learner. |
| 300 | checkpoint: boolean indicating whether to checkpoint the learner. |
| 301 | """ |
| 302 | # Create the Builder object which will internally create agent components. |
| 303 | builder = D4PGBuilder( |
| 304 | # TODO(mwhoffman): pass the config dataclass in directly. |
| 305 | # TODO(mwhoffman): use the limiter rather than the workaround below. |
| 306 | # Right now this modifies min_replay_size and samples_per_insert so that |
| 307 | # they are not controlled by a limiter and are instead handled by the |
| 308 | # Agent base class (the above TODO directly references this behavior). |
| 309 | D4PGConfig( |
No direct callers of this method were found.
No test coverage was detected for this method.