(file, mem_info, ranks_num=1)
| 127 | |
| 128 | |
| 129 | def _check_mem_usage(file, mem_info, ranks_num=1): |
| 130 | if file is None or not TEST_MEM_USAGE: |
| 131 | return |
| 132 | delta = 0.3 # 0.3 GB as buffer |
| 133 | peak, model_size, kv_mem_size, extra, fraction, total_memory, activation_memory, tmp_kv, start_time_mem = _get_mem_info_from_log( |
| 134 | file, ranks_num) |
| 135 | |
| 136 | peak = max(peak) |
| 137 | min_total = min(total_memory) |
| 138 | e_peak, e_model_size, e_kv_mem_size, e_extra = mem_info |
| 139 | import torch |
| 140 | _, total = torch.cuda.mem_get_info() |
| 141 | e_kv_mem_size = _get_kv_mem_size_candidate(min_total, |
| 142 | (e_peak + start_time_mem), |
| 143 | fraction) |
| 144 | print( |
| 145 | f"Expected memory usage: peak mem {e_peak + start_time_mem}, model mem {e_model_size}, kv mem {e_kv_mem_size:.2f}, extra {e_extra}, total {total / (1 << 30):.2f}" |
| 146 | ) |
| 147 | print( |
| 148 | f"Running memory information: peak mem {peak}, model mem {model_size}, kv mem {kv_mem_size}, extra {extra}, total {min_total}, activation {activation_memory}, tmp_kv {tmp_kv}, fraction {fraction}, none-torch memory at starttime {start_time_mem}" |
| 149 | ) |
| 150 | |
| 151 | increased_peak_mem = peak - tmp_kv - e_peak - start_time_mem - delta |
| 152 | assert increased_peak_mem <= 0, ( |
| 153 | f"increased peak memory {increased_peak_mem} is larger than 0," |
| 154 | f" which is calculated as peak ({peak}) - tmp_kv ({tmp_kv}) -" |
| 155 | f" e_peak ({e_peak}) - start_time_mem ({start_time_mem}) - delta ({delta})." |
| 156 | ) |
| 157 | assert kv_mem_size >= e_kv_mem_size - delta, f"kv memory size {kv_mem_size} is smaller than expected {e_kv_mem_size}" |
| 158 | # assert model_size <= e_model_size + delta, f"model memory {model_size} is larger than expected {e_model_size}" |
| 159 | # assert max(extra) <= e_extra + delta, f"extra memory size {extra} is larger than expected {e_extra}" |
| 160 | |
| 161 | |
| 162 | def test_gpt3_175b_1layers_build_only(llm_root, llm_venv, engine_dir): |
no test coverage detected