Work loop in the worker
(rpc_config, mp_contexts)
| 69 | |
| 70 | |
def init_process(rpc_config, mp_contexts):
    """Work loop in the worker.

    Initializes RPC in this worker process, then repeatedly pulls
    ``(command, args)`` pairs from the task queue and executes them until an
    ``MpCommand.FINALIZE_POOL`` command is received.

    Parameters
    ----------
    rpc_config : tuple
        Positional arguments forwarded unchanged to ``_init_rpc``.
    mp_contexts : tuple
        ``(data_queue, task_queue, barrier)``. Commands are read from
        ``task_queue``; results of ``CALL_COLLATE_FN`` are put on
        ``data_queue`` as ``(dataloader_name, result)``; ``barrier.wait()``
        is called for ``CALL_BARRIER``.

    Raises
    ------
    Exception
        If a command not handled below is received. Any exception raised
        while initializing or handling a command (including ``KeyError`` for
        an unregistered dataloader name) is printed with its traceback, so it
        is visible from the worker process, and then re-raised.
    """
    try:
        _init_rpc(*rpc_config)
        data_queue, task_queue, barrier = mp_contexts
        # Collate functions registered by the parent, keyed by dataloader name.
        collate_fn_dict = {}

        while True:
            try:
                # Poll with a timeout instead of blocking forever; follows
                # https://github.com/pytorch/pytorch/blob/d57ce8cf8989c0b737e636d8d7abe16c1f08f70b/torch/utils/data/_utils/worker.py#L260
                command, args = task_queue.get(timeout=5)
            except queue.Empty:
                continue

            if command == MpCommand.SET_COLLATE_FN:
                dataloader_name, func = args
                collate_fn_dict[dataloader_name] = func
            elif command == MpCommand.CALL_BARRIER:
                barrier.wait()
            elif command == MpCommand.DELETE_COLLATE_FN:
                (dataloader_name,) = args
                del collate_fn_dict[dataloader_name]
            elif command == MpCommand.CALL_COLLATE_FN:
                dataloader_name, collate_args = args
                result = collate_fn_dict[dataloader_name](collate_args)
                data_queue.put((dataloader_name, result))
            elif command == MpCommand.CALL_FN_ALL_WORKERS:
                func, func_args = args
                func(func_args)
            elif command == MpCommand.FINALIZE_POOL:
                _exit()
                break
            else:
                # Include the offending command so protocol mismatches between
                # the parent and the worker are diagnosable from the log.
                raise Exception("Unknown command: {!r}".format(command))
    except Exception:
        traceback.print_exc()
        # Bare ``raise`` re-raises the active exception with its original
        # traceback intact.
        raise
| 112 | |
| 113 | |
| 114 | class CustomPool: |