MCPcopy
hub / github.com/stas00/ml-engineering / detect_overflow

Function detect_overflow

debug/underflow_overflow.py:301–344  ·  view source on GitHub ↗

Report whether the tensor contains any `nan` or `inf` entries. This is useful for detecting overflows/underflows and best to call right after the function that did some math that modified the tensor in question. This function contains a few other helper features that you can enable and tweak directly if you want to track various other things.

(var, ctx)

Source from the content-addressed store, hash-verified

299
300
def detect_overflow(var, ctx, *, report_large=False, large_thresholds=(100, 1000, 10000), report_stats=False):
    """
    Report whether the tensor contains any `nan` or `inf` entries.

    This is useful for detecting overflows/underflows and best to call right after the function that did some math that
    modified the tensor in question.

    A few extra diagnostics can be switched on with keyword arguments if you want to track various other things. All of
    them are off by default, so the default call prints only the `nan`/`inf` messages.

    Args:
        var: the tensor variable to check
        ctx: the message to print as a context
        report_large (`bool`, *optional*, defaults to `False`):
            For each value in `large_thresholds`, print how many elements of `var` have an absolute value greater than
            or equal to it, as `"{ctx}: n{threshold}={count}"`. Thresholds matched by no element print nothing.
        large_thresholds (iterable of numbers, *optional*, defaults to `(100, 1000, 10000)`):
            The thresholds used by `report_large`.
        report_stats (`bool`, *optional*, defaults to `False`):
            Print the min, max, (unbiased) variance and mean of `var`, followed by `ctx`. Skipped for empty tensors.

    Return:
        `True` if `inf` or `nan` was detected, `False` otherwise
    """
    detected = False
    if torch.isnan(var).any().item():
        detected = True
        print(f"{ctx} has nans")
    if torch.isinf(var).any().item():
        detected = True
        print(f"{ctx} has infs")

    if report_large:
        # Compute |var| once and count matches directly instead of materializing the selected elements.
        # Note: nan entries never compare >= a threshold, while inf entries always do.
        abs_var = var.abs()
        for threshold in large_thresholds:
            count = torch.ge(abs_var, threshold).sum().item()
            if count > 0:
                print(f"{ctx}: n{threshold}={count}")

    # min()/max() raise on empty tensors, so stats are only meaningful when there is at least one element.
    if report_stats and var.numel() > 0:
        # Upcast so var()/mean() also work on integer tensors and are not themselves computed in half precision.
        stats = var.detach().float()
        print(
            f"min={stats.min():9.2e} max={stats.max():9.2e} "
            f"var={stats.var():9.2e} mean={stats.mean():9.2e} ({ctx})"
        )

    return detected
345
346
347class DebugOption(ExplicitEnum):

Callers 1

analyse_variable (method) · 0.85

Calls 1

print (function) · 0.85

Tested by

no test coverage detected