(checkpoints_path, iteration, release=False, zero=False)
| 192 | |
| 193 | |
def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
    """Return the path of the model-states checkpoint file for this rank.

    The file is placed in a sub-directory of ``checkpoints_path``. That
    sub-directory is named ``release`` when ``release`` is truthy, and is
    otherwise the formatted ``iteration`` value. When ``zero`` is truthy,
    the sub-directory name is extended with ``_zero_dp_rank_<rank>``, where
    the rank comes from ``mpu.get_data_parallel_rank()``.

    The file name is ``mp_rank_<NN>_model_states.pt``, where ``<NN>`` is the
    value of ``mpu.get_model_parallel_rank()`` zero-padded to two digits.
    """
    # NOTE(review): the zero suffix is applied to both the release and the
    # iteration directory names here -- confirm this matches the intended
    # nesting of the `if zero` branch.
    subdir = 'release' if release else '{}'.format(iteration)
    if zero:
        subdir += '_zero_dp_rank_{}'.format(mpu.get_data_parallel_rank())

    mp_rank = mpu.get_model_parallel_rank()
    filename = 'mp_rank_{:02d}_model_states.pt'.format(mp_rank)
    return os.path.join(checkpoints_path, subdir, filename)
| 203 | |
| 204 | |
| 205 | def ensure_directory_exists(filename): |
no outgoing calls
no test coverage detected