| 93 | |
| 94 | |
def _run_config_batch(test_func, test_kwargs, pending_configs, best_config, best_cost_time):
    """Benchmark ``pending_configs`` in one fresh worker subprocess and fold in its results.

    ``workder`` is started with the whole pending list and joined; afterwards one
    cost per config is read back from the result queue, matched to the configs in
    list order.  NOTE(review): this assumes ``workder`` puts exactly one cost per
    config, in order, and stops at the first config that kills it -- confirm
    against ``workder``.

    Every config whose cost was read is removed from ``pending_configs`` in place.
    If the queue runs dry before the list does, the first still-pending config is
    presumed to have crashed the worker: it is dropped too, and the configs after
    it are left in ``pending_configs`` for a later run.  When ``pending_configs``
    is non-empty at least one config is always removed.

    Returns:
        The updated ``(best_config, best_cost_time)`` pair.
    """
    # multiprocessing.Queue.get_nowait() signals "nothing left" with queue.Empty.
    from queue import Empty

    # Fresh queue per subprocess: a worker that dies mid-put can leave a shared
    # queue with a held lock or a partial message, poisoning later runs.
    result_queue = mp.Queue()
    try:
        proc = mp.Process(
            target=workder,
            args=(test_func, pending_configs, test_kwargs, result_queue),
        )
        proc.start()
        # Joining before draining relies on the worker's output fitting in the pipe
        # buffer; batches hold at most 64 configs, so presumably only 64 small results.
        proc.join()

        while pending_configs:
            try:
                cost_time = result_queue.get_nowait()
            except Empty:
                # No result for the head config: treat it as the one that crashed.
                logger.info(
                    f"no result for {pending_configs[0]} "
                    f"(worker exit code {proc.exitcode}), skipping it"
                )
                logger.info(f"current best: {best_config}, cost_time: {best_cost_time}")
                del pending_configs[0]
                break

            logger.info(f"get {pending_configs[0]} cost_time: {cost_time}")
            if cost_time < best_cost_time:
                best_config, best_cost_time = pending_configs[0], cost_time
                logger.info(f"current best: {best_config}, cost_time: {best_cost_time}")
            del pending_configs[0]
    finally:
        result_queue.close()

    return best_config, best_cost_time


def tuning_configs(device_id, device_count, **configs):
    """Benchmark every candidate config for one device and log the fastest.

    Required keyword arguments, popped from ``configs``:
        test_func: callable handed to each ``workder`` subprocess.
        test_func_args: dict of keyword arguments, forwarded both to
            ``get_test_configs_func`` and to ``workder``.
        get_test_configs_func: called as
            ``get_test_configs_func(device_id, device_count, **test_func_args)``;
            its result is iterated to obtain the candidate configs.

    Candidates are run in worker subprocesses, 64 at a time, so a config that
    crashes its worker costs only that config; the rest are retried in the next
    run.  ``CUDA_VISIBLE_DEVICES`` is set to ``device_id`` in this process's
    environment before any worker starts, so every worker inherits it.

    Raises:
        KeyError: if one of the three required keys is missing from ``configs``.
    """
    test_func = configs.pop("test_func")
    test_kwargs = configs.pop("test_func_args")
    get_test_configs_func = configs.pop("get_test_configs_func")

    # Must happen before any worker is started so the subprocesses inherit it.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
    batch_size = 64  # configs per worker subprocess
    best_config, best_cost_time = None, float("inf")

    test_configs_list = []
    for t_config in get_test_configs_func(device_id, device_count, **test_kwargs):
        test_configs_list.append(t_config)
        if len(test_configs_list) < batch_size:
            continue
        # Configs left over after a worker crash stay in the list and are
        # re-run together with the next full batch (or in the tail loop below).
        best_config, best_cost_time = _run_config_batch(
            test_func, test_kwargs, test_configs_list, best_config, best_cost_time
        )

    # Drain the remainder; every run removes at least one config, so this ends.
    while test_configs_list:
        best_config, best_cost_time = _run_config_batch(
            test_func, test_kwargs, test_configs_list, best_config, best_cost_time
        )

    logger.info(f"Final best config: {best_config}, cost_time: {best_cost_time}")