(
self,
environment_spec: specs.EnvironmentSpec,
network: snt.RNNCore,
burn_in_length: int,
trace_length: int,
replay_period: int,
counter: Optional[counting.Counter] = None,
logger: Optional[loggers.Logger] = None,
discount: float = 0.99,
batch_size: int = 32,
prefetch_size: int = tf.data.experimental.AUTOTUNE,
target_update_period: int = 100,
importance_sampling_exponent: float = 0.2,
priority_exponent: float = 0.6,
epsilon: float = 0.01,
learning_rate: float = 1e-3,
min_replay_size: int = 1000,
max_replay_size: int = 1000000,
samples_per_insert: float = 32.0,
store_lstm_state: bool = True,
max_priority_weight: float = 0.9,
checkpoint: bool = True,
)
| 44 | """ |
| 45 | |
| 46 | def __init__( |
| 47 | self, |
| 48 | environment_spec: specs.EnvironmentSpec, |
| 49 | network: snt.RNNCore, |
| 50 | burn_in_length: int, |
| 51 | trace_length: int, |
| 52 | replay_period: int, |
| 53 | counter: Optional[counting.Counter] = None, |
| 54 | logger: Optional[loggers.Logger] = None, |
| 55 | discount: float = 0.99, |
| 56 | batch_size: int = 32, |
| 57 | prefetch_size: int = tf.data.experimental.AUTOTUNE, |
| 58 | target_update_period: int = 100, |
| 59 | importance_sampling_exponent: float = 0.2, |
| 60 | priority_exponent: float = 0.6, |
| 61 | epsilon: float = 0.01, |
| 62 | learning_rate: float = 1e-3, |
| 63 | min_replay_size: int = 1000, |
| 64 | max_replay_size: int = 1000000, |
| 65 | samples_per_insert: float = 32.0, |
| 66 | store_lstm_state: bool = True, |
| 67 | max_priority_weight: float = 0.9, |
| 68 | checkpoint: bool = True, |
| 69 | ): |
| 70 | |
| 71 | if store_lstm_state: |
| 72 | extra_spec = { |
| 73 | 'core_state': tf2_utils.squeeze_batch_dim(network.initial_state(1)), |
| 74 | } |
| 75 | else: |
| 76 | extra_spec = () |
| 77 | |
| 78 | sequence_length = burn_in_length + trace_length + 1 |
| 79 | replay_table = reverb.Table( |
| 80 | name=adders.DEFAULT_PRIORITY_TABLE, |
| 81 | sampler=reverb.selectors.Prioritized(priority_exponent), |
| 82 | remover=reverb.selectors.Fifo(), |
| 83 | max_size=max_replay_size, |
| 84 | rate_limiter=reverb.rate_limiters.MinSize(min_size_to_sample=1), |
| 85 | signature=adders.SequenceAdder.signature( |
| 86 | environment_spec, extra_spec, sequence_length=sequence_length)) |
| 87 | self._server = reverb.Server([replay_table], port=None) |
| 88 | address = f'localhost:{self._server.port}' |
| 89 | |
| 90 | # Component to add things into replay. |
| 91 | adder = adders.SequenceAdder( |
| 92 | client=reverb.Client(address), |
| 93 | period=replay_period, |
| 94 | sequence_length=sequence_length, |
| 95 | ) |
| 96 | |
| 97 | # The dataset object to learn from. |
| 98 | dataset = datasets.make_reverb_dataset( |
| 99 | server_address=address, |
| 100 | batch_size=batch_size, |
| 101 | prefetch_size=prefetch_size) |
| 102 | |
| 103 | target_network = copy.deepcopy(network) |
nothing calls this directly
no test coverage detected