(self, func, num_gpus, assert_callback=None, save_all_ranks=False, *args, **kwargs)
| 311 | return script_res |
| 312 | |
def start(self, func, num_gpus, assert_callback=None, save_all_ranks=False, *args, **kwargs):
    """Pick a distributed launch command and delegate to ``self._start``.

    Args:
        func: Forwarded unchanged to ``self._start`` as ``func``.
        num_gpus: Forwarded as ``num_gpus``; also used as the
            ``--nproc_per_node`` value of the default launch command.
        assert_callback: Forwarded unchanged to ``self._start``.
        save_all_ranks: Forwarded unchanged to ``self._start``.
        *args: Extra positional arguments forwarded to ``self._start``.
        **kwargs: Extra keyword arguments forwarded to ``self._start``.
            An optional ``dist_start_cmd`` entry is removed from the
            mapping and used as the launch command instead of the default
            ``torch.distributed.launch`` invocation.

    Returns:
        Whatever ``self._start`` returns.
    """
    if 'dist_start_cmd' in kwargs:
        dist_start_cmd = kwargs.pop('dist_start_cmd')
    else:
        # The host address and a free port are only needed for the default
        # command. Resolving them lazily means a caller-supplied command
        # never fails because the local hostname cannot be resolved.
        from .torch_utils import _find_free_port
        ip = socket.gethostbyname(socket.gethostname())
        dist_start_cmd = '%s -m torch.distributed.launch --nproc_per_node=%d ' \
            '--master_addr=\'%s\' --master_port=%s' % (sys.executable, num_gpus, ip, _find_free_port())

    # NOTE(review): ``*args`` is unpacked after keyword arguments, so any
    # extra positionals are bound to the leading parameters of ``_start``.
    # If those overlap with the keywords passed here the call raises
    # TypeError -- confirm against the signature of ``_start``.
    return self._start(
        dist_start_cmd=dist_start_cmd,
        func=func,
        num_gpus=num_gpus,
        assert_callback=assert_callback,
        save_all_ranks=save_all_ranks,
        *args,
        **kwargs)
| 330 | |
| 331 | def clean_tmp(self, tmp_file_list): |
| 332 | for file in tmp_file_list: |
no test coverage detected