(checkpoints_path, iteration, release=False, zero=False)
| 156 | |
| 157 | |
def get_checkpoint_name(checkpoints_path, iteration, release=False, zero=False):
    """Return the path of this rank's model-state checkpoint file.

    The path has the form::

        <checkpoints_path>/<tag>/mp_rank_<MM>_model_states.pt

    where ``<tag>`` is ``'release'`` when *release* is true, otherwise the
    integer *iteration* rendered in decimal. When *zero* is true, the tag
    is suffixed with ``_zero_dp_rank_<dp>``, so each data-parallel rank
    gets its own directory. ``<MM>`` is the model-parallel rank,
    zero-padded to two digits.

    Args:
        checkpoints_path: Root directory that holds all checkpoints.
        iteration: Training iteration. It must be an integer because it is
            formatted with ``{:d}``. It is ignored when *release* is true.
        release: Use the fixed ``'release'`` tag instead of the iteration.
        zero: Append the data-parallel rank to the tag.

    Returns:
        The joined filesystem path as a string.
    """
    # Only format the iteration when it is actually used, so a non-integer
    # iteration is tolerated for release checkpoints.
    tag = 'release' if release else '{:d}'.format(iteration)

    if zero:
        # Per-data-parallel-rank directory for the sharded state.
        tag = '{}_zero_dp_rank_{}'.format(tag, mpu.get_data_parallel_rank())

    file_name = 'mp_rank_{:02d}_model_states.pt'.format(
        mpu.get_model_parallel_rank())
    return os.path.join(checkpoints_path, tag, file_name)
| 167 | |
| 168 | |
| 169 | def ensure_directory_exists(filename): |
no outgoing calls
no test coverage detected