This context manager lets one process execute while all other processes in the same process group are blocked until it finishes. This is often useful when a download is required, since only one process should download in order to prevent file corruption.
```python
# Imports assumed from the names used below; the listing starts after the file's import lines.
import torch.distributed as dist

from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc


class barrier_context:
    """
    This context manager is used to allow one process to execute while blocking all
    other processes in the same process group. This is often useful when downloading is required
    as we only want to download in one process to prevent file corruption.

    Args:
        executor_rank (int): the process rank to execute without blocking, all other processes will be blocked
        parallel_mode (ParallelMode): the parallel mode corresponding to a process group

    Usage:
        with barrier_context():
            dataset = CIFAR10(root='./data', download=True)
    """

    def __init__(self, executor_rank: int = 0, parallel_mode: ParallelMode = ParallelMode.GLOBAL):
        # the class name is lowercase by convention
        current_rank = gpc.get_local_rank(parallel_mode=parallel_mode)
        self.should_block = current_rank != executor_rank
        self.group = gpc.get_group(parallel_mode=parallel_mode)

    def __enter__(self):
        # Non-executor ranks wait here until the executor rank reaches its barrier in __exit__.
        if self.should_block:
            dist.barrier(group=self.group)

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # The executor rank joins the barrier after its block has run, releasing the other ranks.
        if not self.should_block:
            dist.barrier(group=self.group)
```
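The ordering that `barrier_context` enforces can be observed without a distributed setup. The following is a minimal sketch that mimics the same enter/exit pattern with threads and `threading.Barrier` from the standard library in place of `dist.barrier`. The names `ThreadBarrierContext` and `run_demo` are hypothetical and exist only for this illustration; they are not part of the library.

```python
import threading


class ThreadBarrierContext:
    """Thread-based analogue of barrier_context (hypothetical, for illustration only)."""

    def __init__(self, rank: int, barrier: threading.Barrier, executor_rank: int = 0):
        self.should_block = rank != executor_rank
        self.barrier = barrier

    def __enter__(self):
        # Non-executor workers wait here until the executor leaves its block.
        if self.should_block:
            self.barrier.wait()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # The executor reaches the barrier only after its body has run,
        # which releases the waiting workers.
        if not self.should_block:
            self.barrier.wait()
        return False


def run_demo(world_size: int = 4, executor_rank: int = 0) -> list:
    """Run world_size workers and return the order in which their bodies executed."""
    barrier = threading.Barrier(world_size)
    order = []
    lock = threading.Lock()

    def worker(rank: int) -> None:
        with ThreadBarrierContext(rank, barrier, executor_rank):
            with lock:
                order.append(rank)

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return order


if __name__ == "__main__":
    print(run_demo())
```

In this sketch the executor's rank is always the first entry of the returned list, while the remaining ranks may appear in any order. This matches the download use case: the other workers enter the block only after the executor has finished, so they would find the files already on disk.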