| 259 | |
| 260 | |
class VRAMPeakMonitor:
    """Context manager that reports CUDA peak allocated memory around a block.

    On entry it samples ``torch.cuda.max_memory_allocated()``. On exit, only
    when the module-level flag ``TRACE_VRAM`` is truthy, it synchronizes the
    device, samples the peak again and prints both values (in units of
    ``1024**3`` bytes) in yellow ANSI text. Exceptions raised inside the
    ``with`` block are never suppressed.

    The allocator peak is never reset by this class, so "before" and "after"
    are cumulative peaks for the process, not the peak of the wrapped region
    alone.

    Args:
        tag: Label included in the printed report to identify the region.
    """

    # ANSI escape sequences used to colour the report line.
    _YELLOW = "\033[93m"
    _RESET = "\033[0m"

    def __init__(self, tag: str) -> None:
        self.tag = tag

    def __enter__(self):
        # Always recorded so ``peak_before`` exists on the instance regardless
        # of the tracing flag; this is a plain query with no device sync.
        self.peak_before = torch.cuda.max_memory_allocated() / (1024**3)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Gate the device-wide sync on the tracing flag so that disabled
        # tracing neither stalls the GPU nor needs a usable CUDA device.
        if TRACE_VRAM:
            torch.cuda.synchronize()
            peak_after = torch.cuda.max_memory_allocated() / (1024**3)
            print(
                f"{self._YELLOW}VRAM peak before {self.tag}: {self.peak_before:.2f} GB, "
                f"after: {peak_after:.2f} GB{self._RESET}"
            )
        # False: let any exception from the with-block propagate.
        return False
| 281 | |
| 282 | |
| 283 | def log_txt_as_img(wh, xc): |
no outgoing calls
no test coverage detected