(streaming)
| 95 | @pytest.mark.skipif(os.name == "nt", reason="execute_subprocess_async doesn't support windows") |
| 96 | @pytest.mark.integration |
| 97 | def test_torch_distributed_run(streaming): |
| 98 | nproc_per_node = 2 |
| 99 | master_port = get_torch_dist_unique_port() |
| 100 | test_script = Path(__file__).resolve().parent / "distributed_scripts" / "run_torch_distributed.py" |
| 101 | distributed_args = f""" |
| 102 | -m torch.distributed.run |
| 103 | --nproc_per_node={nproc_per_node} |
| 104 | --master_port={master_port} |
| 105 | {test_script} |
| 106 | """.split() |
| 107 | args = f""" |
| 108 | --streaming={streaming} |
| 109 | """.split() |
| 110 | cmd = [sys.executable] + distributed_args + args |
| 111 | execute_subprocess_async(cmd, env=os.environ.copy()) |
| 112 | |
| 113 | |
| 114 | @pytest.mark.parametrize( |
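Nothing calls this directly; as a `test_`-prefixed function, it is collected and run by pytest rather than invoked from other code.
No test coverage detected (this function is itself a test).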