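To run a test in a subprocess. In particular, this can avoid (GPU) memory issues. Args: test_case: The test case object that will run `target_func`. target_func (`Callable`): The function implementing the actual testing logic. inputs (`dict`, *optional*, defaults to `None`): The inputs that will be passed to `target_func` through an (input) queue. timeout (`int`, *optional*, defaults to `None`): The timeout (in seconds) that will be passed to the input and output queues; if not specified, the env. variable `PYTEST_TIMEOUT` is checked, falling back to `600`.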
(test_case, target_func, inputs=None, timeout=None)
| 1275 | |
| 1276 | # Taken from: https://github.com/huggingface/transformers/blob/3658488ff77ff8d45101293e749263acf437f4d5/src/transformers/testing_utils.py#L1787 |
def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
    """
    To run a test in a subprocess. In particular, this can avoid (GPU) memory issues.

    Args:
        test_case:
            The test case object that will run `target_func`. Its `fail` method is called to report errors.
        target_func (`Callable`):
            The function implementing the actual testing logic. It is called in the child process as
            `target_func(input_queue, output_queue, timeout)` and is expected to put a dict with an `"error"` key on
            `output_queue`; any value other than `None` under `"error"` is reported as a test failure.
        inputs (`dict`, *optional*, defaults to `None`):
            The inputs that will be passed to `target_func` through an (input) queue.
        timeout (`int`, *optional*, defaults to `None`):
            The timeout (in seconds) that will be passed to the input and output queues. If not specified, the env.
            variable `PYTEST_TIMEOUT` will be checked. If still `None`, its value will be set to `600`.
    """
    if timeout is None:
        timeout = int(os.environ.get("PYTEST_TIMEOUT", 600))

    # "spawn" starts a fresh interpreter for the child instead of forking the parent process.
    start_method = "spawn"
    ctx = multiprocessing.get_context(start_method)

    input_queue = ctx.Queue(1)
    output_queue = ctx.JoinableQueue(1)

    # We can't send test case objects to the child, otherwise we get issues regarding pickle.
    input_queue.put(inputs, timeout=timeout)

    process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
    process.start()
    # Kill the child process if we can't get outputs from it in time: otherwise, the hanging subprocess prevents
    # the test from exiting properly.
    try:
        results = output_queue.get(timeout=timeout)
        output_queue.task_done()
    except Exception as e:
        process.terminate()
        # Reap the terminated child *before* failing: `fail` normally raises, so any join placed after this
        # branch would never run and the child would be left un-joined.
        process.join(timeout=timeout)
        test_case.fail(e)
    else:
        # Only reached when `results` was actually received, so it is always bound here.
        process.join(timeout=timeout)
        if results["error"] is not None:
            test_case.fail(f"{results['error']}")
| 1318 | |
| 1319 | |
| 1320 | class CaptureLogger: |
no test coverage detected
searching dependent graphs…