(state)
| 116 | |
@hvd.elastic.run
def train(state):
    """Run the elastic training loop, committing state as it progresses.

    Wrapped by ``hvd.elastic.run``. ``state.rendezvous`` is incremented on
    every entry into this function. Presumably the wrapper re-enters it after
    the worker set changes, so this counts rendezvous rounds -- confirm
    against the Horovod elastic documentation.

    While ``state.epoch < args.epochs``, each epoch does the following:

    * Run batches until ``state.batch`` reaches ``args.batches_per_epoch``,
      calling ``check_exit`` and then ``step`` for each one. Commit every
      ``args.batches_per_commit`` batches.
    * On rank 0, log the state via ``log_state``. Then, when
      ``args.discovery_wait > 0`` and the host set configured for the next
      epoch differs from the current one, block until
      ``state._host_messages`` is non-empty.
    * Sleep ``args.epoch_wait`` seconds when that value is positive.
    * Advance the epoch, reset the batch counter and commit.

    Every ``state.commit()`` issued here is preceded by
    ``state.commits += 1``, so ``state.commits`` tracks those commits.

    Args:
        state: The Horovod elastic state object. Its ``epoch``, ``batch``,
            ``commits`` and ``rendezvous`` attributes are read and updated,
            and its private ``_host_messages`` attribute is polled.

    Raises:
        TimeoutError: If rank 0 waits more than ``args.discovery_wait``
            seconds for ``state._host_messages`` to become non-empty.
    """
    state.rendezvous += 1
    while state.epoch < args.epochs:
        print('epoch {} batch {}'.format(state.epoch, state.batch))

        # state.batch is not reset here, so the loop continues from whatever
        # batch the state currently holds (e.g. a mid-epoch value).
        while state.batch < args.batches_per_epoch:
            check_exit(state.epoch, state.batch)
            step()

            state.batch += 1
            if state.batch % args.batches_per_commit == 0:
                state.commits += 1
                state.commit()

        if hvd.rank() == 0:
            log_state(state)

            current_hosts = epoch_to_hosts.get(state.epoch, default_hosts)
            next_hosts = epoch_to_hosts.get(state.epoch + 1, default_hosts)
            if args.discovery_wait > 0 and current_hosts != next_hosts:
                print('host changes: {} -> {}'.format(current_hosts, next_hosts))
                # Measure the wait with a monotonic clock. time.time() can
                # jump under NTP/manual clock adjustments, and truncating it
                # to int made the timeout up to a second off.
                start = time.monotonic()
                # NOTE(review): relies on the private _host_messages attribute;
                # presumably a queue fed by driver host notifications -- confirm.
                while state._host_messages.empty():
                    if time.monotonic() - start > args.discovery_wait:
                        raise TimeoutError('Timed out waiting for notifications from driver.')
                    time.sleep(0.1)

        if args.epoch_wait > 0:
            time.sleep(args.epoch_wait)

        state.epoch += 1
        state.batch = 0
        state.commits += 1
        state.commit()
| 152 | |
| 153 | |
| 154 | def on_state_reset(): |
no test coverage detected
searching dependent graphs…