MCPcopy Index your code
hub / github.com/Turing-Project/WriteGPT / MultiRunning

Class MultiRunning

LanguageNetwork/BERT/train.py:32–76  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

30
31
class MultiRunning(object):
    """Launch one training process per GPU and relay child failures to the parent.

    ``args`` is the training configuration object. The code below reads
    ``args.world_size`` and ``args.gpu_ranks`` from it and forwards it to
    ``Running``.
    """

    def __init__(self, args, device_id):
        """Store the training configuration and a default device index.

        :param args: training configuration (reads ``world_size`` and
            ``gpu_ranks``; also passed to ``Running``).
        :param device_id: device index used by ``multi_card_train`` when it is
            not given one explicitly.
        """
        self.args = args
        self.device_id = device_id

    def multi_card_run(self):
        """Spawn one process per GPU (``args.world_size`` of them) and join them.

        Children are started with the ``spawn`` start method as daemons. Each
        child's pid is registered with an ``ErrorHandler`` that is fed by a
        shared ``error_queue``, onto which children put their failures.
        """
        init_logger()

        nb_gpu = self.args.world_size
        mp = torch.multiprocessing.get_context('spawn')

        # ErrorHandler listens on error_queue for failures reported by the
        # children (presumably from a background thread -- see ErrorHandler).
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)

        # Train with multiprocessing.
        processes = []
        for device_id in range(nb_gpu):
            # multi_card_train is a bound method, so the args tuple must match
            # its (error_queue, device_id) parameters exactly. The device index
            # is passed explicitly rather than by mutating self.device_id.
            proc = mp.Process(target=self.multi_card_train,
                              args=(error_queue, device_id),
                              daemon=True)
            processes.append(proc)
            proc.start()
            logger.info(" Starting process pid: %d ", proc.pid)
            error_handler.add_child(proc.pid)
        for proc in processes:
            proc.join()

    def multi_card_train(self, error_queue, device_id=None):
        """Run training for one device; executed inside each spawned child.

        Initialises distributed training for this device and checks that the
        rank returned by ``distributed.multi_init`` equals
        ``args.gpu_ranks[device_id]``. It then trains with ``Running``. Any
        exception other than ``KeyboardInterrupt`` is sent to the parent as
        ``(rank, formatted_traceback)`` through ``error_queue``.

        :param error_queue: queue shared with the parent's ``ErrorHandler``.
        :param device_id: GPU index for this process. When omitted,
            ``self.device_id`` is used (the original calling convention).
        """
        import traceback

        if device_id is not None:
            # Only affects this child's copy of self under the spawn start method.
            self.device_id = device_id

        try:
            self.args.gpu_ranks = [int(i) for i in self.args.gpu_ranks]
            gpu_rank = distributed.multi_init(self.device_id, self.args.world_size, self.args.gpu_ranks)
            print('gpu_rank %d' % gpu_rank)
            if gpu_rank != self.args.gpu_ranks[self.device_id]:
                raise AssertionError("An error occurred in Distributed initialization")
            runner = Running(self.args, self.device_id)
            runner.train()
        except KeyboardInterrupt:
            pass  # killed by parent, do nothing
        except Exception:
            # Propagate the exception to the parent, keeping the original
            # traceback. Capture it first, before any further lookups.
            trace = traceback.format_exc()
            try:
                rank = self.args.gpu_ranks[self.device_id]
            except (IndexError, TypeError, AttributeError):
                # gpu_ranks itself may be the cause of the failure; report
                # the device index instead of losing the error entirely.
                rank = self.device_id
            error_queue.put((rank, trace))
77
78
79class ErrorHandler(object):

Callers 1

train.py · File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected