(_)
| 125 | |
| 126 | |
def main(_):
  """Builds the behavioural-cloning setup for a bsuite task.

  Loads the bsuite environment selected by the command-line flags, converts
  its demonstration data into a batched transition dataset, and constructs
  the policy network, the evaluation loop and the BC learner.

  Args:
    _: Unused positional argument.
  """
  # Load the environment (results are recorded to CSV) and derive its spec
  # from the single-precision view of it.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Demonstrations are generated from the wrapped `raw_env` when the loaded
  # environment exposes one, otherwise from the loaded environment itself.
  raw_environment = getattr(raw_environment, 'raw_env', raw_environment)
  episodes = bsuite_demonstrations.make_dataset(
      raw_environment, stochastic=False)

  # Map every dataset element through the 1-step transition helper, then
  # batch (dropping any incomplete final batch) and prefetch.
  to_transition = functools.partial(
      _n_step_transition_from_episode, n_step=1, additional_discount=1.)
  dataset = (
      episodes.map(to_transition)
      .batch(FLAGS.batch_size, drop_remainder=True)
      .prefetch(tf.data.experimental.AUTOTUNE))

  # Network that the learner will optimize.
  policy_network = make_policy_network(environment_spec.actions)

  def _sample_epsilon_greedy(q):
    # With FLAGS.epsilon set to 0 this reduces to a greedy policy.
    return trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample()

  # Evaluation-time network: policy outputs followed by action sampling.
  evaluator_network = snt.Sequential([policy_network, _sample_epsilon_greedy])

  # Make sure the variables exist before going further (may be unnecessary).
  tf2_utils.create_variables(policy_network, [environment_spec.observations])

  # The learner's counter is nested under the shared parent counter.
  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # Actor that selects actions during evaluation, and the loop that runs it
  # against the environment.
  eval_actor = actors.FeedForwardActor(evaluator_network)
  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=eval_actor,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # The learner updates (and initializes) the policy parameters.
  learner = learning.BCLearner(
      network=policy_network,
      learning_rate=FLAGS.learning_rate,
      dataset=dataset,
      counter=learner_counter)
| 184 |
nothing calls this directly
no test coverage detected