MCPcopy
hub / github.com/google-deepmind/acme / __init__

Method __init__

acme/agents/tf/mpo/agent.py:47–192  ·  view source on GitHub ↗

Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network. …

(
      self,
      environment_spec: specs.EnvironmentSpec,
      policy_network: snt.Module,
      critic_network: snt.Module,
      observation_network: types.TensorTransformation = tf.identity,
      discount: float = 0.99,
      batch_size: int = 256,
      prefetch_size: int = 4,
      target_policy_update_period: int = 100,
      target_critic_update_period: int = 100,
      min_replay_size: int = 1000,
      max_replay_size: int = 1000000,
      samples_per_insert: float = 32.0,
      policy_loss_module: Optional[snt.Module] = None,
      policy_optimizer: Optional[snt.Optimizer] = None,
      critic_optimizer: Optional[snt.Optimizer] = None,
      n_step: int = 5,
      num_samples: int = 20,
      clipping: bool = True,
      logger: Optional[loggers.Logger] = None,
      counter: Optional[counting.Counter] = None,
      checkpoint: bool = True,
      save_directory: str = '~/acme',
      replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
  )

Source from the content-addressed store, hash-verified

45 """
46
47 def __init__(
48 self,
49 environment_spec: specs.EnvironmentSpec,
50 policy_network: snt.Module,
51 critic_network: snt.Module,
52 observation_network: types.TensorTransformation = tf.identity,
53 discount: float = 0.99,
54 batch_size: int = 256,
55 prefetch_size: int = 4,
56 target_policy_update_period: int = 100,
57 target_critic_update_period: int = 100,
58 min_replay_size: int = 1000,
59 max_replay_size: int = 1000000,
60 samples_per_insert: float = 32.0,
61 policy_loss_module: Optional[snt.Module] = None,
62 policy_optimizer: Optional[snt.Optimizer] = None,
63 critic_optimizer: Optional[snt.Optimizer] = None,
64 n_step: int = 5,
65 num_samples: int = 20,
66 clipping: bool = True,
67 logger: Optional[loggers.Logger] = None,
68 counter: Optional[counting.Counter] = None,
69 checkpoint: bool = True,
70 save_directory: str = '~/acme',
71 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
72 ):
73 """Initialize the agent.
74
75 Args:
76 environment_spec: description of the actions, observations, etc.
77 policy_network: the online (optimized) policy.
78 critic_network: the online critic.
79 observation_network: optional network to transform the observations before
80 they are fed into any network.
81 discount: discount to use for TD updates.
82 batch_size: batch size for updates.
83 prefetch_size: size to prefetch from replay.
84 target_policy_update_period: number of updates to perform before updating
85 the target policy network.
86 target_critic_update_period: number of updates to perform before updating
87 the target critic network.
88 min_replay_size: minimum replay size before updating.
89 max_replay_size: maximum replay size.
90 samples_per_insert: number of samples to take from replay for every insert
91 that is made.
92 policy_loss_module: configured MPO loss function for the policy
93 optimization; defaults to sensible values on the control suite. See
94 `acme/tf/losses/mpo.py` for more details.
95 policy_optimizer: optimizer to be used on the policy.
96 critic_optimizer: optimizer to be used on the critic.
97 n_step: number of steps to squash into a single transition.
98 num_samples: number of actions to sample when doing a Monte Carlo
99 integration with respect to the policy.
100 clipping: whether to clip gradients by global norm.
101 logger: logging object used to write to logs.
102 counter: counter object used to keep track of steps.
103 checkpoint: boolean indicating whether to checkpoint the learner.
104 save_directory: string indicating where the learner should save

Callers

nothing calls this directly

Calls 1

signature · Method · 0.45

Tested by

no test coverage detected