(
M: int,
Q_HEAD_NUM: int,
K_HEAD_NUM: int,
HEAD_DIM: int,
dtype: torch.dtype,
test_count: int,
test_configs,
queue,
)
| 76 | |
| 77 | |
def worker(
    M: int,
    Q_HEAD_NUM: int,
    K_HEAD_NUM: int,
    HEAD_DIM: int,
    dtype: torch.dtype,
    test_count: int,
    test_configs,
    queue,
):
    """Benchmark every config in ``test_configs`` and put each timing on ``queue``.

    For each entry of ``test_configs``, calls ``test_kernel`` with the fixed
    problem-shape arguments plus that entry splatted as keyword arguments,
    then puts the returned ``cost_time`` on ``queue``. Results are put in the
    same order as ``test_configs``.

    A ``Watchdog(timeout=10)`` is started before the loop and receives a
    heartbeat after each completed config. NOTE(review): presumably the
    watchdog aborts the process when one config takes longer than the
    timeout; confirm against the ``Watchdog`` implementation.

    Args:
        M, Q_HEAD_NUM, K_HEAD_NUM, HEAD_DIM, dtype, test_count: forwarded
            unchanged to every ``test_kernel`` call.
        test_configs: sequence of keyword-argument mappings, one per config.
        queue: object with a ``put`` method that receives one ``cost_time``
            per completed config.

    On any ``Exception`` the message and traceback are logged once and the
    process exits via ``sys.exit(-1)``.
    """
    import sys

    dog = Watchdog(timeout=10)
    dog.start()
    # NOTE(review): the watchdog is never stopped on the success path; confirm
    # whether Watchdog requires an explicit shutdown.
    try:
        for config in test_configs:
            cost_time = test_kernel(
                M=M,
                Q_HEAD_NUM=Q_HEAD_NUM,
                K_HEAD_NUM=K_HEAD_NUM,
                HEAD_DIM=HEAD_DIM,
                dtype=dtype,
                test_count=test_count,
                **config,
            )
            dog.heartbeat()
            queue.put(cost_time)  # one result per config, in config order
    except Exception as ex:
        # logger.exception logs at ERROR level and appends the traceback, so a
        # separate logger.error call would only duplicate the message.
        logger.exception(str(ex))
        sys.exit(-1)
| 111 | |
| 112 | |
| 113 | def get_test_configs(split_id, split_count): |
nothing calls this directly
no test coverage detected