(trainer, num_microbatches, tensor_shape, metric_hook_list)
| 32 | |
@contextmanager
def switch_evaluation_pipeline_scheduler(trainer, num_microbatches, tensor_shape, metric_hook_list):
    """Temporarily reconfigure ``trainer.schedule`` for a pipeline-parallel evaluation pass.

    When pipeline parallelism is active (``gpc.is_using_pp()`` is true), the
    schedule's ``data_process_func`` is cleared and its ``num_microbatches``,
    ``tensor_shape`` and ``_hooks`` are replaced for the duration of the
    ``with`` block. The previous values are restored on exit, including when
    the body of the ``with`` block raises.

    When pipeline parallelism is not active, the schedule is left untouched and
    this context manager is a no-op.

    Args:
        trainer: Object whose ``schedule`` attribute is reconfigured.
        num_microbatches: Value assigned to ``trainer.schedule.num_microbatches``.
        tensor_shape: Value assigned to ``trainer.schedule.tensor_shape``.
        metric_hook_list: Value assigned to ``trainer.schedule._hooks``
            (presumably the metric hooks to run during evaluation -- confirm
            against callers).

    Yields:
        None.
    """
    if not gpc.is_using_pp():
        # A @contextmanager generator must yield exactly once; previously this
        # path fell through without yielding, so ``with`` raised
        # RuntimeError("generator didn't yield") instead of doing nothing.
        yield
        return

    # Capture the current configuration before touching anything, so the
    # finally-block can restore it even if one of the assignments below fails.
    prev_data_process_func = trainer.schedule.data_process_func
    prev_num_microbatches = trainer.schedule.num_microbatches
    prev_tensor_shape = trainer.schedule.tensor_shape
    prev_metric_hooks = trainer.schedule._hooks
    try:
        trainer.schedule.data_process_func = None
        trainer.schedule.num_microbatches = num_microbatches
        trainer.schedule.tensor_shape = tensor_shape
        trainer.schedule._hooks = metric_hook_list
        yield
    finally:
        trainer.schedule.data_process_func = prev_data_process_func
        trainer.schedule.num_microbatches = prev_num_microbatches
        trainer.schedule.tensor_shape = prev_tensor_shape
        trainer.schedule._hooks = prev_metric_hooks
| 51 | |
| 52 | |
| 53 | @contextmanager |
no test coverage detected