Function train(state)

github.com/horovod/horovod · test/integration/data/elastic_tensorflow2_main.py:118–151


@hvd.elastic.run
def train(state):
    state.rendezvous += 1
    while state.epoch < args.epochs:
        print('epoch {} batch {}'.format(state.epoch, state.batch))

        while state.batch < args.batches_per_epoch:
            check_exit(state.epoch, state.batch)
            step()

            state.batch += 1
            if state.batch % args.batches_per_commit == 0:
                state.commits += 1
                state.commit()

        if hvd.rank() == 0:
            log_state(state)

            current_hosts = epoch_to_hosts.get(state.epoch, default_hosts)
            next_hosts = epoch_to_hosts.get(state.epoch + 1, default_hosts)
            if args.discovery_wait > 0 and current_hosts != next_hosts:
                print('host changes: {} -> {}'.format(current_hosts, next_hosts))
                start = int(time.time())
                while state._host_messages.empty():
                    if int(time.time()) - start > args.discovery_wait:
                        raise TimeoutError('Timed out waiting for notifications from driver.')
                    time.sleep(0.1)

        if args.epoch_wait > 0:
            time.sleep(args.epoch_wait)

        state.epoch += 1
        state.batch = 0
        state.commits += 1
        state.commit()


def on_state_reset():
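
The epoch, batch, and commit bookkeeping in train can be exercised without Horovod. The following is a minimal sketch, not code from the repository: FakeState and run_epochs are hypothetical names, the training step, exit check, logging, and host-discovery wait are omitted, and commit() is a no-op stand-in for the real elastic state's commit. It shows how many commits a run produces and that a run resumed mid-epoch continues from the saved batch counter rather than restarting the epoch.

```python
from dataclasses import dataclass


@dataclass
class FakeState:
    """Hypothetical stand-in for the elastic state object; not a Horovod class."""
    epoch: int = 0
    batch: int = 0
    commits: int = 0
    rendezvous: int = 0

    def commit(self):
        # Assumption: the real commit() persists state; this stand-in does nothing.
        pass


def run_epochs(state, epochs, batches_per_epoch, batches_per_commit):
    """Mirror the counter updates of train() with the actual training removed."""
    state.rendezvous += 1
    while state.epoch < epochs:
        while state.batch < batches_per_epoch:
            state.batch += 1
            # Commit every batches_per_commit batches within the epoch.
            if state.batch % batches_per_commit == 0:
                state.commits += 1
                state.commit()
        # End of epoch: advance the epoch, reset the batch counter, commit once more.
        state.epoch += 1
        state.batch = 0
        state.commits += 1
        state.commit()
    return state


# Fresh run: 2 epochs, each with commits at batches 2 and 4 plus one at epoch end.
fresh = run_epochs(FakeState(), epochs=2, batches_per_epoch=5, batches_per_commit=2)
print(fresh.commits)  # 6

# Resumed run: starts at epoch 1, batch 3, so only batches 4 and 5 remain.
resumed = run_epochs(FakeState(epoch=1, batch=3), epochs=2,
                     batches_per_epoch=5, batches_per_commit=2)
print(resumed.commits)  # 2
```

Because the loops are driven by state.epoch and state.batch rather than local counters, re-entering the function with a restored state picks up where the last commit left off, which is the property the integration test relies on.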

Callers 1

Calls 8

step · Function · 0.85
time · Method · 0.80
check_exit · Function · 0.70
log_state · Function · 0.70
commit · Method · 0.45
rank · Method · 0.45
get · Method · 0.45
sleep · Method · 0.45

Tested by

no test coverage detected
