MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / _check_mem_usage

Function _check_mem_usage

tests/integration/defs/test_e2e.py:129–159  ·  view source on GitHub ↗
(file, mem_info, ranks_num=1)

Source from the content-addressed store, hash-verified

127
128
def _check_mem_usage(file, mem_info, ranks_num=1):
    """Check memory figures parsed from a run log against expected values.

    The check is skipped entirely (returns ``None``) when ``file`` is ``None``
    or when the module-level ``TEST_MEM_USAGE`` flag is falsy.

    Args:
        file: Log source forwarded to ``_get_mem_info_from_log``.
        mem_info: Four-item sequence unpacked as (expected peak, expected
            model size, expected kv size, expected extra). The kv entry is
            ignored; the expected kv size is recomputed via
            ``_get_kv_mem_size_candidate``.
        ranks_num: Forwarded unchanged to ``_get_mem_info_from_log``.

    Raises:
        AssertionError: If ``peak - tmp_kv - e_peak - start_time_mem - delta``
            is greater than 0, or if the kv memory size from the log is below
            the recomputed expected kv size minus ``delta``.
    """
    if file is None:
        return
    if not TEST_MEM_USAGE:
        return

    slack = 0.3  # 0.3 GB as buffer

    parsed = _get_mem_info_from_log(file, ranks_num)
    (peak_values, model_size, kv_mem_size, extra, fraction, total_values,
     activation_memory, tmp_kv, start_time_mem) = parsed

    # The log yields collections for peak and total memory; reduce them to the
    # largest peak and the smallest total.
    observed_peak = max(peak_values)
    min_total = min(total_values)

    # The kv entry of mem_info is discarded here because the expected kv size
    # is derived below from the measured totals instead.
    e_peak, e_model_size, _, e_extra = mem_info

    import torch
    _free, device_total = torch.cuda.mem_get_info()

    expected_peak = e_peak + start_time_mem
    e_kv_mem_size = _get_kv_mem_size_candidate(min_total, expected_peak,
                                               fraction)

    print(f"Expected memory usage: peak mem {expected_peak},"
          f" model mem {e_model_size}, kv mem {e_kv_mem_size:.2f},"
          f" extra {e_extra}, total {device_total / (1 << 30):.2f}")
    print(f"Running memory information: peak mem {observed_peak},"
          f" model mem {model_size}, kv mem {kv_mem_size}, extra {extra},"
          f" total {min_total}, activation {activation_memory},"
          f" tmp_kv {tmp_kv}, fraction {fraction},"
          f" none-torch memory at starttime {start_time_mem}")

    # Keep the left-to-right subtraction order so the floating-point result
    # matches the formula spelled out in the failure message.
    increased_peak_mem = observed_peak - tmp_kv - e_peak - start_time_mem - slack
    assert increased_peak_mem <= 0, (
        f"increased peak memory {increased_peak_mem} is larger than 0,"
        f" which is calculated as peak ({observed_peak}) - tmp_kv ({tmp_kv}) -"
        f" e_peak ({e_peak}) - start_time_mem ({start_time_mem})"
        f" - delta ({slack}).")

    kv_floor = e_kv_mem_size - slack
    assert kv_mem_size >= kv_floor, (
        f"kv memory size {kv_mem_size} is smaller than expected {e_kv_mem_size}"
    )
    # Model-size and extra-memory checks are currently disabled:
    # assert model_size <= e_model_size + delta, f"model memory {model_size} is larger than expected {e_model_size}"
    # assert max(extra) <= e_extra + delta, f"extra memory size {extra} is larger than expected {e_extra}"
160
161
162def test_gpt3_175b_1layers_build_only(llm_root, llm_venv, engine_dir):

Calls 3

_get_mem_info_from_log — Function · 0.85
max — Function · 0.85

Tested by

no test coverage detected