MCPcopy
hub / github.com/google-deepmind/acme / __init__

Method __init__

acme/agents/tf/r2d2/agent.py:46–148  ·  view source on GitHub ↗
(
      self,
      environment_spec: specs.EnvironmentSpec,
      network: snt.RNNCore,
      burn_in_length: int,
      trace_length: int,
      replay_period: int,
      counter: Optional[counting.Counter] = None,
      logger: Optional[loggers.Logger] = None,
      discount: float = 0.99,
      batch_size: int = 32,
      prefetch_size: int = tf.data.experimental.AUTOTUNE,
      target_update_period: int = 100,
      importance_sampling_exponent: float = 0.2,
      priority_exponent: float = 0.6,
      epsilon: float = 0.01,
      learning_rate: float = 1e-3,
      min_replay_size: int = 1000,
      max_replay_size: int = 1000000,
      samples_per_insert: float = 32.0,
      store_lstm_state: bool = True,
      max_priority_weight: float = 0.9,
      checkpoint: bool = True,
  )

Source from the content-addressed store, hash-verified

44 """
45
46 def __init__(
47 self,
48 environment_spec: specs.EnvironmentSpec,
49 network: snt.RNNCore,
50 burn_in_length: int,
51 trace_length: int,
52 replay_period: int,
53 counter: Optional[counting.Counter] = None,
54 logger: Optional[loggers.Logger] = None,
55 discount: float = 0.99,
56 batch_size: int = 32,
57 prefetch_size: int = tf.data.experimental.AUTOTUNE,
58 target_update_period: int = 100,
59 importance_sampling_exponent: float = 0.2,
60 priority_exponent: float = 0.6,
61 epsilon: float = 0.01,
62 learning_rate: float = 1e-3,
63 min_replay_size: int = 1000,
64 max_replay_size: int = 1000000,
65 samples_per_insert: float = 32.0,
66 store_lstm_state: bool = True,
67 max_priority_weight: float = 0.9,
68 checkpoint: bool = True,
69 ):
70
71 if store_lstm_state:
72 extra_spec = {
73 'core_state': tf2_utils.squeeze_batch_dim(network.initial_state(1)),
74 }
75 else:
76 extra_spec = ()
77
78 sequence_length = burn_in_length + trace_length + 1
79 replay_table = reverb.Table(
80 name=adders.DEFAULT_PRIORITY_TABLE,
81 sampler=reverb.selectors.Prioritized(priority_exponent),
82 remover=reverb.selectors.Fifo(),
83 max_size=max_replay_size,
84 rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1),
85 signature=adders.SequenceAdder.signature(
86 environment_spec, extra_spec, sequence_length=sequence_length))
87 self._server = reverb.Server([replay_table], port=None)
88 address = f'localhost:{self._server.port}'
89
90 # Component to add things into replay.
91 adder = adders.SequenceAdder(
92 client=reverb.Client(address),
93 period=replay_period,
94 sequence_length=sequence_length,
95 )
96
97 # The dataset object to learn from.
98 dataset = datasets.make_reverb_dataset(
99 server_address=address,
100 batch_size=batch_size,
101 prefetch_size=prefetch_size)
102
103 target_network = copy.deepcopy(network)

Callers

nothing calls this directly

Calls (3)

sample (method) · 0.80
initial_state (method) · 0.45
signature (method) · 0.45

Tested by

no test coverage detected