MCPcopy
hub / github.com/google-deepmind/acme / __init__

Method __init__

acme/agents/tf/ddpg/agent.py:46–174  ·  view source on GitHub ↗

Initialize the agent. Args: environment_spec: description of the actions, observations, etc. policy_network: the online (optimized) policy. critic_network: the online critic. observation_network: optional network to transform the observations before they are fed into any network. …

(self,
               environment_spec: specs.EnvironmentSpec,
               policy_network: snt.Module,
               critic_network: snt.Module,
               observation_network: types.TensorTransformation = tf.identity,
               discount: float = 0.99,
               batch_size: int = 256,
               prefetch_size: int = 4,
               target_update_period: int = 100,
               min_replay_size: int = 1000,
               max_replay_size: int = 1000000,
               samples_per_insert: float = 32.0,
               n_step: int = 5,
               sigma: float = 0.3,
               clipping: bool = True,
               logger: Optional[loggers.Logger] = None,
               counter: Optional[counting.Counter] = None,
               checkpoint: bool = True,
               replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE)

Source from the content-addressed store, hash-verified

44 """
45
46 def __init__(self,
47 environment_spec: specs.EnvironmentSpec,
48 policy_network: snt.Module,
49 critic_network: snt.Module,
50 observation_network: types.TensorTransformation = tf.identity,
51 discount: float = 0.99,
52 batch_size: int = 256,
53 prefetch_size: int = 4,
54 target_update_period: int = 100,
55 min_replay_size: int = 1000,
56 max_replay_size: int = 1000000,
57 samples_per_insert: float = 32.0,
58 n_step: int = 5,
59 sigma: float = 0.3,
60 clipping: bool = True,
61 logger: Optional[loggers.Logger] = None,
62 counter: Optional[counting.Counter] = None,
63 checkpoint: bool = True,
64 replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE):
65 """Initialize the agent.
66
67 Args:
68 environment_spec: description of the actions, observations, etc.
69 policy_network: the online (optimized) policy.
70 critic_network: the online critic.
71 observation_network: optional network to transform the observations before
72 they are fed into any network.
73 discount: discount to use for TD updates.
74 batch_size: batch size for updates.
75 prefetch_size: size to prefetch from replay.
76 target_update_period: number of learner steps to perform before updating
77 the target networks.
78 min_replay_size: minimum replay size before updating.
79 max_replay_size: maximum replay size.
80 samples_per_insert: number of samples to take from replay for every insert
81 that is made.
82 n_step: number of steps to squash into a single transition.
83 sigma: standard deviation of zero-mean, Gaussian exploration noise.
84 clipping: whether to clip gradients by global norm.
85 logger: logger object to be used by learner.
86 counter: counter object used to keep track of steps.
87 checkpoint: boolean indicating whether to checkpoint the learner.
88 replay_table_name: string indicating what name to give the replay table.
89 """
90 # Create a replay server to add data to. This uses no limiter behavior in
91 # order to allow the Agent interface to handle it.
92 replay_table = reverb.Table(
93 name=replay_table_name,
94 sampler=reverb.selectors.Uniform(),
95 remover=reverb.selectors.Fifo(),
96 max_size=max_replay_size,
97 rate_limiter=reverb.rate_limiters.MinSize(1),
98 signature=adders.NStepTransitionAdder.signature(environment_spec))
99 self._server = reverb.Server([replay_table], port=None)
100
101 # The adder is used to insert observations into replay.
102 address = f'localhost:{self._server.port}'
103 adder = adders.NStepTransitionAdder(

Callers

nothing calls this directly

Calls 1

signature · Method · 0.45

Tested by

no test coverage detected