(garbage_collect=True)
| 9 | |
| 10 | |
def flush(garbage_collect=True):
    """Release unused accelerator memory, optionally running the garbage collector first.

    Args:
        garbage_collect: When true, run ``gc.collect()`` before emptying the
            device caches.

    Returns:
        None.
    """
    # Collect first: tensors kept alive only by reference cycles are freed by
    # the collector, and only memory that has already been freed can be handed
    # back to the device by ``empty_cache()``. Doing this in the opposite order
    # leaves those blocks sitting in the caching allocator.
    if garbage_collect:
        gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # On Apple-silicon (MPS) builds, also clear the MPS allocator cache.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
| 19 | |
| 20 | |
| 21 | def get_mean_std(tensor): |
no outgoing calls
no test coverage detected