Repository: github.com/horovod/horovod

Function train

test/integration/data/elastic_tensorflow_main.py:107–140
train(state, step)


```python
# Excerpt: relies on module-level names defined elsewhere in this file
# (args, epoch_to_hosts, default_hosts, check_exit, log_state).
# hvd is assumed to be the Horovod TensorFlow module.
import time

import horovod.tensorflow as hvd


@hvd.elastic.run
def train(state, step):
    # Counts entries into train; hvd.elastic.run re-invokes it after a reset.
    state.rendezvous += 1
    while state.epoch < args.epochs:
        print('epoch {} batch {}'.format(state.epoch, state.batch))

        while state.batch < args.batches_per_epoch:
            check_exit(state.epoch, state.batch)
            step()

            state.batch += 1
            # Commit the in-memory state every batches_per_commit batches.
            if state.batch % args.batches_per_commit == 0:
                state.commits += 1
                state.commit()

        if hvd.rank() == 0:
            log_state(state)

            # If the host set is scheduled to change after this epoch, wait up
            # to discovery_wait seconds for the driver's host-update notification.
            current_hosts = epoch_to_hosts.get(state.epoch, default_hosts)
            next_hosts = epoch_to_hosts.get(state.epoch + 1, default_hosts)
            if args.discovery_wait > 0 and current_hosts != next_hosts:
                print('host changes: {} -> {}'.format(current_hosts, next_hosts))
                start = int(time.time())
                while state._host_messages.empty():
                    if int(time.time()) - start > args.discovery_wait:
                        raise TimeoutError('Timed out waiting for notifications from driver.')
                    time.sleep(0.1)

        if args.epoch_wait > 0:
            time.sleep(args.epoch_wait)

        # End of epoch: advance the counters and commit again.
        state.epoch += 1
        state.batch = 0
        state.commits += 1
        state.commit()


# The file continues with: with tf.Session(config=config) as session:
```
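To make the commit/rollback behavior of this loop concrete, here is a self-contained toy model that runs without Horovod. It is a simplified sketch under stated assumptions, not Horovod's implementation: `ToyState`, `WorkerFailure`, `run_elastic` and `make_train` are hypothetical names. `ToyState.commit()` snapshots the counters and `restore()` rolls them back, and `run_elastic` re-enters the training function after a failure, loosely mirroring what `hvd.elastic.run` does when a worker fails (the real decorator also resynchronizes state across workers and handles host-set changes, both omitted here). The `step` callback and the host-discovery wait are dropped.

```python
import copy


class WorkerFailure(Exception):
    """Hypothetical stand-in for a failed collective op (e.g. a lost worker)."""


class ToyState:
    """Hypothetical in-memory state: commit() snapshots fields, restore() rolls back."""

    def __init__(self, **fields):
        self.__dict__.update(fields)
        self._fields = list(fields)
        self.commit()  # the initial values are the first committed snapshot

    def commit(self):
        self._saved = {k: copy.deepcopy(getattr(self, k)) for k in self._fields}

    def restore(self):
        for k, v in self._saved.items():
            setattr(self, k, copy.deepcopy(v))


def run_elastic(func, state, max_retries=3):
    """Call func(state); after a WorkerFailure, roll back to the last commit and retry."""
    for _ in range(max_retries + 1):
        try:
            return func(state)
        except WorkerFailure:
            state.restore()
    raise RuntimeError('too many failures')


def make_train(epochs, batches_per_epoch, batches_per_commit, fail_at):
    """Build a train(state) with the same loop shape; fail once at fail_at=(epoch, batch)."""
    executed = []  # (epoch, batch) pairs in the order they were processed
    failed = []

    def train(state):
        state.rendezvous += 1
        while state.epoch < epochs:
            while state.batch < batches_per_epoch:
                if (state.epoch, state.batch) == fail_at and not failed:
                    failed.append(fail_at)
                    raise WorkerFailure()
                executed.append((state.epoch, state.batch))
                state.batch += 1
                if state.batch % batches_per_commit == 0:
                    state.commits += 1
                    state.commit()
            state.epoch += 1
            state.batch = 0
            state.commits += 1
            state.commit()

    return train, executed


train, executed = make_train(epochs=2, batches_per_epoch=4,
                             batches_per_commit=2, fail_at=(1, 3))
state = ToyState(epoch=0, batch=0, commits=0, rendezvous=0)
run_elastic(train, state)
print(state.epoch, state.batch, state.commits, state.rendezvous)  # 2 0 6 2
print(executed.count((1, 2)))  # 2: batch (1, 2) was redone after the rollback
```

With 4 batches per epoch and a commit every 2 batches, each epoch commits twice inside the batch loop and once more at the epoch boundary, so two epochs end with 6 commits. The injected failure at epoch 1, batch 3 rolls the counters back to the commit made once batches 0 and 1 of that epoch were done, so batch (1, 2) runs twice; that repeated batch is the work lost since the last commit. Re-entering `train` raises `rendezvous` to 2.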

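The host-change wait in the listing polls a queue until the driver posts a notification without removing it, presumably so that Horovod's own host-update check, which runs on commit, can still see it. A minimal standalone sketch of the same bounded wait follows, assuming a plain `queue.Queue` in place of Horovod's private `state._host_messages`; `wait_for_message` is a hypothetical helper. It uses `time.monotonic()` rather than `int(time.time())`, which avoids the one-second truncation and is unaffected by wall-clock adjustments.

```python
import queue
import time


def wait_for_message(q, timeout, poll_interval=0.1):
    """Block until q is non-empty or timeout seconds pass; never removes items."""
    deadline = time.monotonic() + timeout
    while q.empty():
        if time.monotonic() > deadline:
            raise TimeoutError('Timed out waiting for notifications from driver.')
        time.sleep(poll_interval)


# Usage: a pending notification returns immediately and stays queued.
notifications = queue.Queue()
notifications.put('hosts updated')
wait_for_message(notifications, timeout=1.0)
print(notifications.qsize())  # 1
```
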
Callers 1

Calls 8

step (function) · 0.85
time (method) · 0.80
check_exit (function) · 0.70
log_state (function) · 0.70
commit (method) · 0.45
rank (method) · 0.45
get (method) · 0.45
sleep (method) · 0.45

Tested by

no test coverage detected
