MCPcopy Index your code
hub / github.com/pytorch/pytorch / checkScript

Method checkScript

torch/testing/_internal/jit_utils.py:445–523  ·  view source on GitHub ↗

Checks that a given script generates the same output as the Python version using the given inputs.

(self,
                    script,
                    inputs,
                    name='func',
                    optimize=True,
                    inputs_requires_grad=False,
                    capture_output=False,
                    frames_up=1,
                    profiling=ProfilingMode.PROFILING,
                    atol=None,
                    rtol=None)

Source from the content-addressed store, hash-verified

443 self.assertEqual(bailout_outputs, expected)
444
445 def checkScript(self,
446 script,
447 inputs,
448 name='func',
449 optimize=True,
450 inputs_requires_grad=False,
451 capture_output=False,
452 frames_up=1,
453 profiling=ProfilingMode.PROFILING,
454 atol=None,
455 rtol=None):
456 """
457 Checks that a given script generates the same output as the Python
458 version using the given inputs.
459 """
460 with torch.jit.optimized_execution(optimize):
461 with enable_profiling_mode_for_profiling_tests():
462 extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs)
463 if isinstance(script, str):
464 # Compile the string to a Script function
465 # with enable_profiling_mode():
466 cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
467
468 # Execute the Python function so we can run it later and get its
469 # outputs
470
471 frame = self.get_frame_vars(frames_up)
472 the_locals: Dict[str, Any] = {}
473 execWrapper(script, glob=frame, loc=the_locals)
474 frame.update(the_locals)
475
476 python_fn = frame[name]
477 scripted_fn = getattr(cu, name)
478 else:
479
480 # Check the string frontend first
481 source = textwrap.dedent(inspect.getsource(script))
482 self.checkScript(
483 source,
484 inputs,
485 script.__name__,
486 optimize=optimize,
487 inputs_requires_grad=inputs_requires_grad,
488 capture_output=capture_output,
489 profiling=profiling,
490 frames_up=2)
491
492 # Continue checking the Python frontend
493 scripted_fn = torch.jit.script(script, _frames_up=1)
494 python_fn = script
495
496 if inputs_requires_grad:
497 recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)
498 else:
499 recording_inputs = inputs
500
501 if capture_output:
502 with self.capture_stdout() as script_stdout:

Callers 15

test_typecheck (Method) · 0.80
test_sum_simple (Method) · 0.80
test_sum_dim (Method) · 0.80
test_sum_keepdim_cast (Method) · 0.80
test_abs (Method) · 0.80
test_broadcast (Method) · 0.80
test_chunk (Method) · 0.80

Calls 11

get_frame_vars (Method) · 0.95
checkBailouts (Method) · 0.95
isinstance (Function) · 0.85
do_input_map (Function) · 0.85
assertExpected (Method) · 0.80
execWrapper (Function) · 0.70
any (Function) · 0.50
update (Method) · 0.45
requires_grad_ (Method) · 0.45
assertEqual (Method) · 0.45

Tested by 15

test_typecheck (Method) · 0.64
test_sum_simple (Method) · 0.64
test_sum_dim (Method) · 0.64
test_sum_keepdim_cast (Method) · 0.64
test_abs (Method) · 0.64
test_broadcast (Method) · 0.64
test_chunk (Method) · 0.64