| 252 | """ |
| 253 | |
def _start(self, dist_start_cmd, func, num_gpus, assert_callback=None, save_all_ranks=False, *args, **kwargs):
    """Run ``func`` in a separately launched process and return its pickled result(s).

    A temporary Python script is generated from ``_DIST_SCRIPT_TEMPLATE``.
    The template is formatted with the module name of ``func`` (twice), its
    ``__qualname__`` and the source text of the call arguments. The script is
    launched through the shell as ``<dist_start_cmd> <script> --save_all_ranks=...
    --save_file=...``. The launched script is expected to pickle its result to
    ``--save_file``, or to ``<save_file><rank>`` per rank when ``save_all_ranks``
    is True.

    Args:
        dist_start_cmd: Shell command prefix used to launch the generated script.
        func: Function to run. Its source directory is appended to
            ``PYTHONPATH`` and its file stem is passed to the template
            (presumably as the module to import).
        num_gpus: In this method, only used to name the per-rank result files
            when ``save_all_ranks`` is True.
        assert_callback: Optional callable invoked with the loaded result(s).
        save_all_ranks: If True, load one result file per rank and return a
            list. Otherwise return the single loaded result.
        *args, **kwargs: Call arguments for ``func``. They are embedded into
            the generated script as source text. Strings are emitted via
            ``repr()``; other values via ``str()``/``format()``, so they must
            render as valid Python expressions.

    Returns:
        The unpickled result, or a list of per-rank results if ``save_all_ranks``.

    Fails the test (via ``self.assertEqual``) if the launched command exits
    with a non-zero status.
    """
    import shlex

    script_path = func.__code__.co_filename
    script_dir, script_name = os.path.split(script_path)
    script_name = os.path.splitext(script_name)[0]
    func_name = func.__qualname__

    def _as_source(value):
        # repr() yields a correctly escaped literal for any string (including
        # ones containing quotes or backslashes); bare '...' wrapping did not.
        return repr(value) if isinstance(value, str) else value

    func_params = ','.join(
        [str(_as_source(arg)) for arg in args]
        + ['{}={}'.format(k, _as_source(v)) for k, v in kwargs.items()])

    # mkstemp atomically creates a unique file owned by us. The previous
    # NamedTemporaryFile(...).name idiom deleted the file once the object was
    # garbage-collected and then re-created it by name.
    run_fd, tmp_run_file = tempfile.mkstemp(suffix='.py')
    # The result path reuses the run file's unique stem; the launched script
    # is the one that creates it (or its per-rank variants).
    tmp_res_file = os.path.splitext(tmp_run_file)[0] + '.pkl'
    if save_all_ranks:
        tmp_res_files = [tmp_res_file + str(i) for i in range(num_gpus)]
    else:
        tmp_res_files = [tmp_res_file]
    # Register cleanup before anything else can fail so the files never leak.
    self.addCleanup(self.clean_tmp, [tmp_run_file] + tmp_res_files)

    # Python decodes source files as UTF-8 by default, so write them as UTF-8.
    with os.fdopen(run_fd, 'w', encoding='utf-8') as f:
        print('save temporary run file to : {}'.format(tmp_run_file))
        print('save results to : {}'.format(tmp_res_file))
        run_file_content = _DIST_SCRIPT_TEMPLATE.format(script_name, script_name, func_name, func_params)
        f.write(run_file_content)

    # os.environ.copy() returns a plain dict. copy.deepcopy(os.environ) yields
    # another os._Environ whose item assignment still calls os.putenv(), which
    # leaked these settings into the current process environment.
    tmp_env = os.environ.copy()
    tmp_env['PYTHONPATH'] = os.pathsep.join((tmp_env.get('PYTHONPATH', ''), script_dir)).lstrip(os.pathsep)
    # avoid distributed test hang
    tmp_env['NCCL_P2P_DISABLE'] = '1'
    # Quote the paths because the command line is interpreted by a shell.
    script_params = '--save_all_ranks=%s --save_file=%s' % (save_all_ranks, shlex.quote(tmp_res_file))
    script_cmd = '%s %s %s' % (dist_start_cmd, shlex.quote(tmp_run_file), script_params)
    print('script command: %s' % script_cmd)
    res = subprocess.call(script_cmd, shell=True, env=tmp_env)

    # Check the exit status before reading results: after a failed run the
    # result files are typically missing, and the resulting I/O error would
    # hide the real failure.
    self.assertEqual(res, 0, msg='The test function ``{}`` in ``{}`` run failed!'.format(func_name, script_name))

    script_res = []
    for res_file in tmp_res_files:
        # These pickles are produced by the subprocess launched above, not by
        # external input.
        with open(res_file, 'rb') as f:
            script_res.append(pickle.load(f))
    if not save_all_ranks:
        script_res = script_res[0]

    if assert_callback:
        assert_callback(script_res)

    return script_res