(
ws: websocket.WebSocket,
tokenizer: transformers.PreTrainedTokenizer,
work_request: inference.WorkRequest,
worker_config: inference.WorkerConfig,
)
| 101 | |
| 102 | |
| 103 | def handle_work_request( |
| 104 | ws: websocket.WebSocket, |
| 105 | tokenizer: transformers.PreTrainedTokenizer, |
| 106 | work_request: inference.WorkRequest, |
| 107 | worker_config: inference.WorkerConfig, |
| 108 | ): |
| 109 | prompt, parameters = make_prompt_and_parameters(tokenizer=tokenizer, work_request=work_request) |
| 110 | logger.debug(f"Prompt: {prompt}") |
| 111 | |
| 112 | model_config = worker_config.model_config |
| 113 | |
| 114 | # Only send safety request if work request safety level is not 0 |
| 115 | if settings.enable_safety and work_request.safety_parameters.level: |
| 116 | safety_request = inference.SafetyRequest(inputs=prompt, parameters=work_request.safety_parameters) |
| 117 | safety_response = get_safety_server_response(safety_request) |
| 118 | prompt = get_safety_opinion(prompt, safety_response.outputs, work_request.safety_parameters.level) |
| 119 | logger.debug(f"Safe prompt: {prompt}") |
| 120 | |
| 121 | stream_response = None |
| 122 | token_buffer = utils.TokenBuffer(stop_sequences=parameters.stop) |
| 123 | if model_config.is_lorem: |
| 124 | stream_events = utils.lorem_events(parameters.seed) |
| 125 | else: |
| 126 | prompt = truncate_prompt(tokenizer, worker_config, parameters, prompt) |
| 127 | stream_request = interface.GenerateStreamRequest( |
| 128 | inputs=prompt, |
| 129 | parameters=parameters, |
| 130 | ) |
| 131 | stream_events = get_inference_server_stream_events(stream_request) |
| 132 | |
| 133 | generated_ids = [] |
| 134 | decoded_text = "" |
| 135 | for stream_response in stream_events: |
| 136 | if stream_response.is_error: |
| 137 | logger.error(f"Error from inference server: {stream_response.error}") |
| 138 | utils.send_response( |
| 139 | ws, |
| 140 | inference.ErrorResponse( |
| 141 | request_id=work_request.id, |
| 142 | error=stream_response.error, |
| 143 | metrics=inference.WorkerMetricsInfo(), |
| 144 | ), |
| 145 | ) |
| 146 | raise RuntimeError(f"Error from inference server: {stream_response.error}") |
| 147 | token = stream_response.token |
| 148 | |
| 149 | if model_config.is_llama: |
| 150 | generated_ids.append(token.id) |
| 151 | try: |
| 152 | with tokenizer_lock: |
| 153 | text = tokenizer.decode(generated_ids, skip_special_tokens=True) |
| 154 | new_text = text[len(decoded_text) :] |
| 155 | if not decoded_text: |
| 156 | new_text = new_text.lstrip() |
| 157 | except Exception: |
| 158 | text = decoded_text |
| 159 | new_text = "" |
| 160 | token.text = new_text |
nothing calls this directly
no test coverage detected