Test main function of RequestHandler
()
| 54 | |
| 55 | |
def check_request_handler():
    """
    Exercise the add / abort / re-add / schedule flow of RequestHandler.

    A single five-token sequence is queued, checked to be waiting, aborted,
    checked to be gone from the waiting state, and finally queued again and
    handed to ``schedule()``.
    """
    # Build the handler from a small inference config and a tiny Llama config.
    handler = RequestHandler(
        InferenceConfig(
            max_input_len=10,
            max_output_len=10,
            block_size=8,
        ),
        LlamaConfig(
            hidden_size=32,
            num_hidden_layers=2,
            num_attention_heads=4,
        ),
    )

    # NOTE(review): block_size here (16) differs from the InferenceConfig
    # value above (8) -- confirm this mismatch is intentional.
    seq = Sequence(
        request_id=1,
        prompt="abc",
        input_token_id=[1, 2, 3, 4, 5],
        block_size=16,
        eos_token_id=0,
        pad_token_id=0,
        sample_params=None,
    )

    # Queue the sequence; it is expected to be the first entry of
    # waiting_list[1] (the priority should be 1).
    handler.add_sequence(seq)
    first_waiting = handler.waiting_list[1][0]
    assert first_waiting == seq
    assert handler._has_waiting()

    # Aborting by request id must leave nothing in the waiting state.
    handler.abort_sequence(seq.request_id)
    assert not handler._has_waiting()

    # Put the sequence back into WAITING before re-queueing it
    # (presumably abort changed its status -- verify against RequestHandler).
    seq.status = RequestStatus.WAITING
    handler.add_sequence(seq)

    # Only checks that scheduling runs; its result is not inspected.
    handler.schedule()
| 90 | |
| 91 | |
| 92 | def run_dist(rank, world_size, port): |
no test coverage detected
searching dependent graphs…