github.com/horovod/horovod

Function train

test/integration/test_spark_torch.py:530–567
(state, dir)


def fn(batches_per_commit, batches_per_epoch, epochs, dir=None):
    @run
    def train(state, dir):
        state.rendezvous += 1
        logging.info('rank %s: rendezvous %s', hvd.rank(), state.rendezvous)

        for state.epoch in range(state.epoch, epochs):
            logging.info('rank %s: start epoch %s at batch %s', hvd.rank(), state.epoch, state.batch)

            for state.batch in range(state.batch, batches_per_epoch):
                check_fail(dir, hvd.rank(), state.epoch, state.batch)

                optimizer.zero_grad()
                output = model(data)
                loss = F.cross_entropy(output, target)
                loss.backward()
                optimizer.step()

                # TODO: this sleep makes the fault tolerant test fail
                # torch allgather throws a RuntimeError, which should be a HorovodInternalError
                #import time
                #time.sleep(0.2)

                if state.batch % batches_per_commit == 0:
                    logging.info('rank %s: allgather', hvd.rank())
                    hvd.allgather(torch.tensor([hvd.rank(), state.epoch, state.batch, state.rendezvous]), 'state').tolist()
                    logging.info('rank %s: commit epoch %s batch %s', hvd.rank(), state.epoch, state.batch)
                    state.commits += 1
                    state.commit()

            logging.info('rank %s: allgather', hvd.rank())
            hvd.allgather(torch.tensor([hvd.rank(), state.epoch, state.batch, state.rendezvous]), 'state').tolist()
            logging.info('rank %s: commit epoch %s', hvd.rank(), state.epoch)
            state.commits += 1
            state.commit()
            state.batch = 0

        res = hvd.allgather(torch.tensor([hvd.rank(), state.epoch, state.batch, state.rendezvous]), 'state').tolist()
        logging.info('rank %s: returning', hvd.rank())
        return res, hvd.rank()

    logging.getLogger().setLevel(logging.DEBUG)
    logging.basicConfig(format='%(asctime)-15s %(levelname)1.1s %(filename)s:%(lineno)d %(funcName)s() - %(message)s')
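The train loop above depends on the @run decorator for recovery: when a step fails, the state is rolled back to its last commit() and train is called again, which increments rendezvous and resumes the epoch and batch loops from the restored position. The following is a pure-Python simulation of that control flow, not Horovod code. ElasticState, run, SimulatedFailure, make_train and fail_at are hypothetical names introduced for illustration. The model step, the allgather calls and check_fail are replaced by a list append and a one-shot injected exception. With a failure injected at epoch 1, batch 3, the loop resumes from the last commit (epoch 1, batch 2), so that batch runs twice and rendezvous ends at 2.

```python
import copy


class SimulatedFailure(Exception):
    """Hypothetical stand-in for the error a lost worker would trigger."""


class ElasticState:
    """Hypothetical stand-in for an elastic state object: commit() snapshots
    the tracked fields, restore() rolls them back to the last snapshot."""

    def __init__(self, **fields):
        self.__dict__.update(fields)
        self._saved = copy.deepcopy(fields)

    def commit(self):
        self._saved = {k: copy.deepcopy(v)
                       for k, v in self.__dict__.items() if k != '_saved'}

    def restore(self):
        self.__dict__.update(copy.deepcopy(self._saved))


def run(func, max_attempts=10):
    """Hypothetical analogue of @run: on failure, restore the last committed
    state and call func again, up to max_attempts times."""
    def wrapper(state, *args):
        for _ in range(max_attempts):
            try:
                return func(state, *args)
            except SimulatedFailure:
                state.restore()
        raise RuntimeError('giving up after repeated failures')
    return wrapper


def make_train(batches_per_commit, batches_per_epoch, epochs, fail_at):
    failures = {fail_at}  # (epoch, batch) positions that fail exactly once
    processed = []        # every (epoch, batch) that actually executed

    @run
    def train(state):
        state.rendezvous += 1
        for state.epoch in range(state.epoch, epochs):
            for state.batch in range(state.batch, batches_per_epoch):
                if (state.epoch, state.batch) in failures:
                    failures.discard((state.epoch, state.batch))
                    raise SimulatedFailure()
                processed.append((state.epoch, state.batch))  # the "training step"
                if state.batch % batches_per_commit == 0:
                    state.commits += 1
                    state.commit()
            # End-of-epoch commit, then start the next epoch at batch 0.
            state.commits += 1
            state.commit()
            state.batch = 0
        return state

    return train, processed


train, processed = make_train(batches_per_commit=2, batches_per_epoch=4,
                              epochs=2, fail_at=(1, 3))
state = train(ElasticState(epoch=0, batch=0, commits=0, rendezvous=0))
print(processed)
# [(0, 0), (0, 1), (0, 2), (0, 3), (1, 0), (1, 1), (1, 2), (1, 2), (1, 3)]
print(state.epoch, state.batch, state.commits, state.rendezvous)  # 1 0 7 2
```

Because the restored state.batch is the batch at which the commit happened, the committed batch is executed again after recovery. The same holds in the original loop, where the commit follows the optimizer step for that batch.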

Callers (1)

fn · Function · 0.70

Calls (6)

check_fail · Function · 0.70
rank · Method · 0.45
zero_grad · Method · 0.45
backward · Method · 0.45
step · Method · 0.45
commit · Method · 0.45
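check_fail is called once per batch with (dir, rank, epoch, batch), and fn passes dir=None by default. Its body is not part of this listing. The following is a hypothetical sketch of how such a fault-injection hook could work, assuming a marker-file convention: the fail-{rank}-{epoch}-{batch} file name and the InjectedFailure type are invented for illustration and need not match the real helper. Deleting the marker before raising makes each injected failure fire only once, which lets a retried train call get past the same batch.

```python
import os
import tempfile


class InjectedFailure(Exception):
    """Hypothetical error type used to simulate a worker crash."""


def check_fail(dir, rank, epoch, batch):
    """Hypothetical fault-injection hook: if dir is set and a marker file for
    this (rank, epoch, batch) exists, delete it and raise, so it fires once."""
    if dir is None:
        return  # no directory given: fault injection disabled
    marker = os.path.join(dir, f'fail-{rank}-{epoch}-{batch}')
    if os.path.exists(marker):
        os.unlink(marker)  # consume the marker so a retry passes
        raise InjectedFailure(f'rank {rank}: injected failure at epoch {epoch} batch {batch}')


results = []
with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, 'fail-0-1-3'), 'w').close()  # arm one failure
    for batch in (2, 3, 3):
        try:
            check_fail(d, 0, 1, batch)
            results.append('ok')
        except InjectedFailure:
            results.append('failed')
print(results)  # ['ok', 'failed', 'ok']
```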

Tested by

no test coverage detected
