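Checks that the TorchScript-compiled version of a given script (either a source string or a Python function) produces the same output as the plain Python version when both are run on the given inputs.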
(self,
script,
inputs,
name='func',
optimize=True,
inputs_requires_grad=False,
capture_output=False,
frames_up=1,
profiling=ProfilingMode.PROFILING,
atol=None,
rtol=None)
| 443 | self.assertEqual(bailout_outputs, expected) |
| 444 | |
| 445 | def checkScript(self, |
| 446 | script, |
| 447 | inputs, |
| 448 | name='func', |
| 449 | optimize=True, |
| 450 | inputs_requires_grad=False, |
| 451 | capture_output=False, |
| 452 | frames_up=1, |
| 453 | profiling=ProfilingMode.PROFILING, |
| 454 | atol=None, |
| 455 | rtol=None): |
| 456 | """ |
| 457 | Checks that a given script generates the same output as the Python |
| 458 | version using the given inputs. |
| 459 | """ |
| 460 | with torch.jit.optimized_execution(optimize): |
| 461 | with enable_profiling_mode_for_profiling_tests(): |
| 462 | extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs) |
| 463 | if isinstance(script, str): |
| 464 | # Compile the string to a Script function |
| 465 | # with enable_profiling_mode(): |
| 466 | cu = torch.jit.CompilationUnit(script, _frames_up=frames_up) |
| 467 | |
| 468 | # Execute the Python function so we can run it later and get its |
| 469 | # outputs |
| 470 | |
| 471 | frame = self.get_frame_vars(frames_up) |
| 472 | the_locals: Dict[str, Any] = {} |
| 473 | execWrapper(script, glob=frame, loc=the_locals) |
| 474 | frame.update(the_locals) |
| 475 | |
| 476 | python_fn = frame[name] |
| 477 | scripted_fn = getattr(cu, name) |
| 478 | else: |
| 479 | |
| 480 | # Check the string frontend first |
| 481 | source = textwrap.dedent(inspect.getsource(script)) |
| 482 | self.checkScript( |
| 483 | source, |
| 484 | inputs, |
| 485 | script.__name__, |
| 486 | optimize=optimize, |
| 487 | inputs_requires_grad=inputs_requires_grad, |
| 488 | capture_output=capture_output, |
| 489 | profiling=profiling, |
| 490 | frames_up=2) |
| 491 | |
| 492 | # Continue checking the Python frontend |
| 493 | scripted_fn = torch.jit.script(script, _frames_up=1) |
| 494 | python_fn = script |
| 495 | |
| 496 | if inputs_requires_grad: |
| 497 | recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs) |
| 498 | else: |
| 499 | recording_inputs = inputs |
| 500 | |
| 501 | if capture_output: |
| 502 | with self.capture_stdout() as script_stdout: |