hub / github.com/google-deepmind/acme / main

Function main

examples/offline/run_bc.py:127–190
Signature: main(_)

Source from the content-addressed store, hash-verified

def main(_):
  # Create an environment and grab the spec.
  raw_environment = bsuite.load_and_record_to_csv(
      bsuite_id=FLAGS.bsuite_id,
      results_dir=FLAGS.results_dir,
      overwrite=FLAGS.overwrite,
  )
  environment = single_precision.SinglePrecisionWrapper(raw_environment)
  environment_spec = specs.make_environment_spec(environment)

  # Build demonstration dataset.
  if hasattr(raw_environment, 'raw_env'):
    raw_environment = raw_environment.raw_env

  batch_dataset = bsuite_demonstrations.make_dataset(raw_environment,
                                                     stochastic=False)
  # Combine with demonstration dataset.
  transition = functools.partial(
      _n_step_transition_from_episode, n_step=1, additional_discount=1.)

  dataset = batch_dataset.map(transition)

  # Batch and prefetch.
  dataset = dataset.batch(FLAGS.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

  # Create the networks to optimize.
  policy_network = make_policy_network(environment_spec.actions)

  # If the agent is non-autoregressive use epsilon=0 which will be a greedy
  # policy.
  evaluator_network = snt.Sequential([
      policy_network,
      lambda q: trfl.epsilon_greedy(q, epsilon=FLAGS.epsilon).sample(),
  ])

  # Ensure that we create the variables before proceeding (maybe not needed).
  tf2_utils.create_variables(policy_network, [environment_spec.observations])

  counter = counting.Counter()
  learner_counter = counting.Counter(counter, prefix='learner')

  # Create the actor which defines how we take actions.
  evaluation_network = actors.FeedForwardActor(evaluator_network)

  eval_loop = acme.EnvironmentLoop(
      environment=environment,
      actor=evaluation_network,
      counter=counter,
      logger=loggers.TerminalLogger('evaluation', time_delta=1.))

  # The learner updates the parameters (and initializes them).
  learner = learning.BCLearner(
      network=policy_network,
      learning_rate=FLAGS.learning_rate,
      dataset=dataset,
      counter=learner_counter)
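
The evaluator in the listing above passes the policy network's outputs through `trfl.epsilon_greedy(...).sample()`, and its comment notes that `epsilon=0` gives a greedy policy. The pure-Python sketch below illustrates why. It is not trfl's implementation: `epsilon_greedy_probs` and `sample_action` are hypothetical helpers written for this note, and splitting the greedy mass evenly across tied maxima is an assumption.

```python
import random


def epsilon_greedy_probs(action_values, epsilon):
    """Returns epsilon-greedy action probabilities (hypothetical helper)."""
    if not 0.0 <= epsilon <= 1.0:
        raise ValueError("epsilon must be in [0, 1]")
    num_actions = len(action_values)
    best = max(action_values)
    # One-hot (or multi-hot, on ties) indicator of the maximal actions.
    greedy = [1.0 if v == best else 0.0 for v in action_values]
    num_greedy = sum(greedy)
    # Uniform exploration mass plus the remaining mass on the greedy actions.
    return [
        epsilon / num_actions + (1.0 - epsilon) * g / num_greedy
        for g in greedy
    ]


def sample_action(action_values, epsilon, rng=random):
    """Draws one action index from the epsilon-greedy distribution."""
    probs = epsilon_greedy_probs(action_values, epsilon)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


q_values = [0.1, 0.7, 0.2]
greedy_probs = epsilon_greedy_probs(q_values, epsilon=0.0)
explore_probs = epsilon_greedy_probs(q_values, epsilon=0.3)
```

With `epsilon=0.0` all probability mass sits on the highest-valued action, so sampling always returns that action. Any positive epsilon leaves every action with a nonzero chance of being chosen.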

Callers

nothing calls this directly

Calls 5

step · Method · 0.95
increment · Method · 0.95
run · Method · 0.95
make_policy_network · Function · 0.85
sample · Method · 0.80
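
The `increment` entry in this list likely refers to the counters that `main` creates, where `learner_counter` is built from `counter` with `prefix='learner'`. The toy class below sketches one way such a prefixed parent/child counter can work. `ToyCounter` is a hypothetical, simplified stand-in and not Acme's `counting.Counter`: the real class may cache, synchronize, or report counts differently.

```python
class ToyCounter:
    """Hypothetical, simplified hierarchical counter (not Acme's class)."""

    def __init__(self, parent=None, prefix=""):
        self._parent = parent
        self._prefix = prefix
        self._counts = {}

    def increment(self, **counts):
        """Adds to local counts and forwards prefixed counts to the parent."""
        for key, value in counts.items():
            self._counts[key] = self._counts.get(key, 0) + value
            if self._parent is not None:
                name = f"{self._prefix}_{key}" if self._prefix else key
                self._parent.increment(**{name: value})
        return dict(self._counts)

    def get_counts(self):
        """Returns a copy of the counts held by this counter."""
        return dict(self._counts)


# Mirror the structure used in main: a root counter and a prefixed child.
counter = ToyCounter()
learner_counter = ToyCounter(counter, prefix="learner")

learner_counter.increment(steps=100)  # Recorded in the root as learner_steps.
counter.increment(episodes=1)         # Recorded in the root only.
```

In this sketch the prefix keeps learner-side counts distinguishable from counts written directly to the shared root, which is the apparent purpose of `prefix='learner'` in the listing.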

Tested by

no test coverage detected