A class that listens for exceptions in child processes and propagates their tracebacks to the parent process.
| 130 | |
| 131 | |
class ErrorHandler(object):
    """Listen for exceptions in child processes and re-raise them in the parent.

    A daemon thread blocks on ``error_queue`` until an item arrives; items
    are unpacked as ``(rank, original_trace)`` pairs. The thread then puts
    the item back on the queue and sends ``SIGUSR1`` to the current process.
    The handler installed here runs in the main thread, interrupts every
    registered child with ``SIGINT``, and raises an ``Exception`` whose
    message ends with the child's traceback.

    Args:
        error_queue: queue shared with the children. Presumably each child
            puts ``(rank, traceback_str)`` on it when it fails; verify
            against the producer.

    Note:
        ``signal.signal`` may only be called from the main thread, so this
        class must be instantiated there. Relies on ``SIGUSR1`` (POSIX only).
    """

    def __init__(self, error_queue):
        """Start the background listener thread and install the SIGUSR1 handler."""
        import signal
        import threading

        self.error_queue = error_queue
        # PIDs of the child processes to interrupt once an error is reported.
        self.children_pids = []
        # Daemon thread, so a still-waiting listener never blocks interpreter exit.
        self.error_thread = threading.Thread(
            target=self.error_listener, daemon=True)
        self.error_thread.start()
        signal.signal(signal.SIGUSR1, self.signal_handler)

    def add_child(self, pid):
        """Register ``pid`` as a child process to interrupt on error."""
        self.children_pids.append(pid)

    def error_listener(self):
        """Wait for the first reported error and signal the current process.

        Runs on the background thread. The item is put back on the queue so
        that ``signal_handler`` (executed in the main thread) can read it.
        """
        import os
        import signal

        (rank, original_trace) = self.error_queue.get()
        self.error_queue.put((rank, original_trace))
        os.kill(os.getpid(), signal.SIGUSR1)

    def signal_handler(self, signalnum, stackframe):
        """Interrupt all registered children and re-raise the child's traceback.

        Args:
            signalnum: number of the received signal (unused).
            stackframe: interrupted stack frame (unused).

        Raises:
            Exception: always; its message ends with the reported traceback.
        """
        import os
        import signal

        for pid in self.children_pids:
            try:
                os.kill(pid, signal.SIGINT)  # interrupt child processes
            except ProcessLookupError:
                # The child may already have exited and been reaped; that
                # must not mask the original traceback raised below.
                pass
        (_rank, original_trace) = self.error_queue.get()
        msg = "\n\n-- Tracebacks above this line can probably be ignored --\n\n"
        msg += original_trace
        raise Exception(msg)
| 166 | |
| 167 | |
| 168 | def spawned_train(process_fn, opt, device_id, error_queue): # noqa: E501 |