Checks that a given function throws the expected exception when executed with plain Python, the string frontend, and the AST frontend. The logic is taken from `checkScript` (see the comments there for details).
(self, script, inputs, exception, regex,
name=None, outputs=None, capture_output=False,
frames_up=1, profiling=ProfilingMode.PROFILING)
| 392 | return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight) |
| 393 | |
def checkScriptRaisesRegex(self, script, inputs, exception, regex,
                           name=None, outputs=None, capture_output=False,
                           frames_up=1, profiling=ProfilingMode.PROFILING):
    """
    Assert that `script` raises `exception` with a message matching `regex`
    under each available execution path.

    Three paths are exercised in order, each inside its own
    `assertRaisesRegex` context:

    1. plain Python execution,
    2. the string frontend (`torch.jit.CompilationUnit`),
    3. the Python AST frontend (`torch.jit.script`), only when `script`
       is not a string.

    Args:
        script: either source text or a Python callable.
        inputs: positional arguments unpacked into every invocation.
        exception: exception type expected from every path.
        regex: pattern the exception message must match.
        name: attribute/key used to look the function up when `script`
            is source text.
        frames_up: forwarded to `self.get_frame_vars` and to
            `CompilationUnit` as `_frames_up`.
        outputs, capture_output, profiling: accepted but not referenced
            in this method body.

    Set-up (exec / compilation) deliberately happens inside the
    `assertRaisesRegex` blocks, so an exception raised while preparing
    the function is treated the same as one raised while calling it.
    """
    script_is_source = isinstance(script, str)

    with enable_profiling_mode_for_profiling_tests():
        # 1) Plain Python
        with self.assertRaisesRegex(exception, regex):
            if script_is_source:
                # NOTE: get_frame_vars must be called directly from this
                # method so that `frames_up` counts from the same frame.
                caller_vars = self.get_frame_vars(frames_up)
                exec_locals: Dict[str, Any] = {}
                execWrapper(script, glob=caller_vars, loc=exec_locals)
                caller_vars.update(exec_locals)
                eager_fn = caller_vars[name]
            else:
                eager_fn = script
            eager_fn(*inputs)

        # 2) String frontend
        with self.assertRaisesRegex(exception, regex):
            source_text = (
                script
                if script_is_source
                else textwrap.dedent(inspect.getsource(script))
            )
            unit = torch.jit.CompilationUnit(source_text, _frames_up=frames_up)
            compiled_fn = getattr(
                unit, name if script_is_source else script.__name__
            )
            compiled_fn(*inputs)

        # 3) Python AST frontend -- only meaningful for real callables
        if script_is_source:
            return
        with self.assertRaisesRegex(exception, regex):
            scripted_fn = torch.jit.script(script)
            scripted_fn(*inputs)
| 436 | def checkBailouts(self, model, inputs, expected): |
| 437 | state = model.get_debug_state() |