()
| 296 | |
| 297 | |
def create_deepspeed_args():
    """Return an ``argparse.Namespace`` with DeepSpeed switched on, for unit tests.

    The namespace comes from an ``ArgumentParser`` with no registered
    arguments, parsed against an empty argument list, so it starts out empty.
    The ``deepspeed`` attribute is then set to ``True``.

    If the distributed backend is already initialized, the helper also checks
    that the world size fits on the available accelerator devices. It then
    stores the process's global rank in ``local_rank``.

    Returns:
        argparse.Namespace: the populated argument namespace.
    """
    parsed = argparse.ArgumentParser().parse_args(args='')
    parsed.deepspeed = True

    if not dist.is_initialized():
        return parsed

    # Unit tests are assumed to run on at most one full node, so every rank
    # must map to a local device and the global rank is used as local_rank.
    assert dist.get_world_size() <= get_accelerator().device_count()
    parsed.local_rank = dist.get_rank()
    return parsed
| 307 | |
| 308 | |
| 309 | def args_from_dict(tmpdir, config_dict): |
no test coverage detected
searching dependent graphs…