(state, step)
| 105 | |
def _block_until_host_message(state):
    """Poll ``state._host_messages`` until it reports non-empty.

    The queue is checked about every 0.1 s. Elapsed time is counted in whole
    seconds (both timestamps are truncated with ``int``), so the effective
    timeout has one-second granularity.

    Raises:
        TimeoutError: once the elapsed whole seconds exceed
            ``args.discovery_wait`` while the queue is still empty.
    """
    began = int(time.time())
    while state._host_messages.empty():
        waited = int(time.time()) - began
        if waited > args.discovery_wait:
            raise TimeoutError('Timed out waiting for notifications from driver.')
        time.sleep(0.1)


@hvd.elastic.run
def train(state, step):
    """Run the epoch/batch loop, committing ``state`` along the way.

    Every invocation bumps ``state.rendezvous`` first. The loop then resumes
    from whatever ``state.epoch`` / ``state.batch`` currently hold and runs
    until ``state.epoch`` reaches ``args.epochs``.

    Args:
        state: object carrying ``rendezvous``, ``epoch``, ``batch`` and
            ``commits`` counters plus a ``commit()`` method.
        step: zero-argument callable invoked once per batch.

    Raises:
        TimeoutError: on rank 0, when a host change is expected for the next
            epoch and no host message arrives within ``args.discovery_wait``.
    """
    state.rendezvous += 1

    while state.epoch < args.epochs:
        print('epoch {} batch {}'.format(state.epoch, state.batch))

        while state.batch < args.batches_per_epoch:
            check_exit(state.epoch, state.batch)
            step()

            state.batch += 1
            # Commit every `batches_per_commit` completed batches.
            at_commit_boundary = state.batch % args.batches_per_commit == 0
            if at_commit_boundary:
                state.commits += 1
                state.commit()

        # NOTE(review): nesting was reconstructed from a flattened listing;
        # confirm the host-change wait belongs under the rank-0 branch.
        if hvd.rank() == 0:
            log_state(state)

            hosts_now = epoch_to_hosts.get(state.epoch, default_hosts)
            hosts_next = epoch_to_hosts.get(state.epoch + 1, default_hosts)
            if args.discovery_wait > 0 and hosts_now != hosts_next:
                print('host changes: {} -> {}'.format(hosts_now, hosts_next))
                # Hold the epoch open until a host message has been queued.
                _block_until_host_message(state)

        if args.epoch_wait > 0:
            time.sleep(args.epoch_wait)

        # Roll over to the next epoch and record that as a commit.
        state.epoch += 1
        state.batch = 0
        state.commits += 1
        state.commit()
| 141 | |
| 142 | |
| 143 | with tf.Session(config=config) as session: |
no test coverage detected
searching dependent graphs…