| 40 | |
| 41 | |
def log_gpu_memory_usage(head: str, logger: logging.Logger = None, level=logging.DEBUG, rank: int = 0):
    """Emit a one-line device-memory summary, optionally restricted to one rank.

    The figures come from ``_get_current_mem_info()`` and are reported under
    the labels "memory allocated", "memory reserved" and
    "device memory used/total", all tagged as GB in the message.

    Args:
        head: Text placed at the start of the message, before the figures.
        logger: Logger that receives the message. When ``None`` the message
            is written with ``print`` instead.
        level: Logging level passed to ``logger.log``; unused when printing.
        rank: Rank allowed to emit the message. ``None`` lets every rank
            emit. The rank check is skipped entirely when
            ``dist.is_initialized()`` is false.
    """
    # Stay silent on every rank except the selected one. The checks run in
    # this order so that get_rank() is only called once the process group is
    # initialized and a specific rank was requested.
    if dist.is_initialized() and rank is not None and dist.get_rank() != rank:
        return

    allocated, reserved, used, total = _get_current_mem_info()
    report = f"{head}, memory allocated (GB): {allocated}, memory reserved (GB): {reserved}, device memory used/total (GB): {used}/{total}"

    if logger is not None:
        logger.log(level, report)
        return

    # No logger supplied: fall back to stdout.
    print(report)
| 51 | |
| 52 | |
| 53 | class GPUMemoryLogger(DecoratorLoggerBase): |