Context manager for testing that code prints certain text. Examples -------- >>> with AssertPrints("abc", suppress=False): ... print("abcd") ... print("def") ... abcd def
| 309 | """ |
| 310 | |
class AssertPrints:
    """Context manager asserting that the wrapped code prints given text.

    On entry, ``sys.<channel>`` is swapped for a capture object.  On a clean
    exit the original stream is put back and every expected item is checked
    against the captured text.  If the wrapped code raised, no check is made
    and the exception propagates.

    Parameters
    ----------
    s : str
        Text that must occur in the captured output.  Although annotated as
        ``str``, a ``_re_type`` instance (matched with ``.search``) or an
        iterable of such items is handled as well.
    channel : str
        Name of the ``sys`` attribute to capture; defaults to ``'stdout'``.
    suppress : bool
        If true (the default), ``sys.<channel>`` is replaced by the capture
        buffer alone; otherwise it is replaced by the ``Tee`` object, and
        the output stays visible, as the example below shows.

    Examples
    --------
    >>> with AssertPrints("abc", suppress=False):
    ...     print("abcd")
    ...     print("def")
    ...
    abcd
    def
    """

    def __init__(self, s: str, channel: str='stdout', suppress: bool=True):
        # A lone string / pattern is wrapped in a list so that __exit__ can
        # always iterate over the expectations.
        self.s = [s] if isinstance(s, (str, _re_type)) else s
        self.channel = channel
        self.suppress = suppress

    def __enter__(self):
        """Start capturing: remember the current stream and replace it."""
        # The original stream is looked up before Tee is constructed.
        self.orig_stream = getattr(sys, self.channel)
        self.buffer = MyStringIO()
        self.tee = Tee(self.buffer, channel=self.channel)
        replacement = self.buffer if self.suppress else self.tee
        setattr(sys, self.channel, replacement)

    def __exit__(self, etype: Optional[Type[BaseException]], value: Optional[BaseException], traceback: Optional[TracebackType]):
        """Restore the stream and verify the captured output.

        Always returns ``False`` so an exception raised inside the ``with``
        body is never swallowed.  Raises ``AssertionError`` when an expected
        item is missing from the captured text.
        """
        # Presumably the pytest marker that hides this frame in tracebacks;
        # it must remain a local of this function.
        __tracebackhide__ = True

        try:
            # Only inspect the output when the body finished without error.
            if value is None:
                self.tee.flush()
                setattr(sys, self.channel, self.orig_stream)
                captured = self.buffer.getvalue()
                for expected in self.s:
                    is_pattern = isinstance(expected, _re_type)
                    if is_pattern:
                        hit = expected.search(captured)
                    else:
                        hit = expected in captured
                    assert hit, notprinted_msg.format(
                        expected.pattern if is_pattern else expected,
                        self.channel,
                        captured,
                    )
            return False
        finally:
            # The tee is closed on every path, including failures.
            self.tee.close()
| 354 | |
| 355 | printed_msg = """Found {0!r} in printed output (on {1}): |
| 356 | ------- |
no outgoing calls
searching dependent graphs…