| 528 | def fn(batches_per_commit, batches_per_epoch, epochs, dir=None): |
@run
def train(state, dir):
    """Run the training loop, committing ``state`` at regular intervals.

    Training resumes from ``state.epoch`` / ``state.batch`` and runs until
    ``epochs`` epochs of ``batches_per_epoch`` batches each are done. The
    state is committed after every batch whose index is a multiple of
    ``batches_per_commit`` and once more at the end of each epoch; every
    commit is preceded by an allgather of the current progress.

    ``epochs``, ``batches_per_epoch``, ``batches_per_commit``, ``model``,
    ``optimizer``, ``data`` and ``target`` are not defined here; they are
    resolved from the surrounding scope.

    Args:
        state: Object exposing ``epoch``, ``batch``, ``commits`` and
            ``rendezvous`` attributes and a ``commit()`` method.
            NOTE(review): presumably a Horovod elastic state object, with
            ``run`` being the matching elastic decorator - confirm against
            the part of the file that creates it.
        dir: Forwarded unchanged to ``check_fail``.

    Returns:
        A ``(gathered, rank)`` tuple: the allgathered progress of all ranks
        as a flat list, and this process's ``hvd.rank()``.
    """

    def gather_progress():
        # Collective call: each rank contributes [rank, epoch, batch, rendezvous].
        # Deliberately contains no logging so log records keep funcName 'train'.
        progress = torch.tensor([hvd.rank(), state.epoch, state.batch, state.rendezvous])
        return hvd.allgather(progress, 'state').tolist()

    # Count every entry into this function.
    state.rendezvous += 1
    logging.info('rank %s: rendezvous %s', hvd.rank(), state.rendezvous)

    # The loop variables are attributes of `state`, so progress is stored on
    # the state object itself and both loops start from the stored values.
    for state.epoch in range(state.epoch, epochs):
        logging.info('rank %s: start epoch %s at batch %s', hvd.rank(), state.epoch, state.batch)

        for state.batch in range(state.batch, batches_per_epoch):
            check_fail(dir, hvd.rank(), state.epoch, state.batch)

            # One optimisation step on the fixed batch.
            optimizer.zero_grad()
            batch_loss = F.cross_entropy(model(data), target)
            batch_loss.backward()
            optimizer.step()

            # TODO: adding `import time; time.sleep(0.2)` here makes the fault
            # tolerant test fail: torch allgather throws a RuntimeError which
            # should be a HorovodInternalError

            # Only every `batches_per_commit`-th batch is committed.
            if state.batch % batches_per_commit != 0:
                continue

            logging.info('rank %s: allgather', hvd.rank())
            gather_progress()
            logging.info('rank %s: commit epoch %s batch %s', hvd.rank(), state.epoch, state.batch)
            state.commits += 1
            state.commit()

        # End of epoch: always commit, regardless of the batch index.
        logging.info('rank %s: allgather', hvd.rank())
        gather_progress()
        logging.info('rank %s: commit epoch %s', hvd.rank(), state.epoch)
        state.commits += 1
        state.commit()
        # Reset only after the commit so the next epoch starts at its first batch.
        state.batch = 0

    gathered = gather_progress()
    logging.info('rank %s: returning', hvd.rank())
    return gathered, hvd.rank()
| 568 | |
| 569 | logging.getLogger().setLevel(logging.DEBUG) |
| 570 | logging.basicConfig(format='%(asctime)-15s %(levelname)1.1s %(filename)s:%(lineno)d %(funcName)s() - %(message)s') |