MCPcopy
hub / github.com/LAION-AI/Open-Assistant / handle_work_request

Function handle_work_request

inference/worker/work.py:103–187  ·  view source on GitHub ↗
(
    ws: websocket.WebSocket,
    tokenizer: transformers.PreTrainedTokenizer,
    work_request: inference.WorkRequest,
    worker_config: inference.WorkerConfig,
)

Source from the content-addressed store, hash-verified

101
102
103def handle_work_request(
104 ws: websocket.WebSocket,
105 tokenizer: transformers.PreTrainedTokenizer,
106 work_request: inference.WorkRequest,
107 worker_config: inference.WorkerConfig,
108):
109 prompt, parameters = make_prompt_and_parameters(tokenizer=tokenizer, work_request=work_request)
110 logger.debug(f"Prompt: {prompt}")
111
112 model_config = worker_config.model_config
113
114 # Only send safety request if work request safety level is not 0
115 if settings.enable_safety and work_request.safety_parameters.level:
116 safety_request = inference.SafetyRequest(inputs=prompt, parameters=work_request.safety_parameters)
117 safety_response = get_safety_server_response(safety_request)
118 prompt = get_safety_opinion(prompt, safety_response.outputs, work_request.safety_parameters.level)
119 logger.debug(f"Safe prompt: {prompt}")
120
121 stream_response = None
122 token_buffer = utils.TokenBuffer(stop_sequences=parameters.stop)
123 if model_config.is_lorem:
124 stream_events = utils.lorem_events(parameters.seed)
125 else:
126 prompt = truncate_prompt(tokenizer, worker_config, parameters, prompt)
127 stream_request = interface.GenerateStreamRequest(
128 inputs=prompt,
129 parameters=parameters,
130 )
131 stream_events = get_inference_server_stream_events(stream_request)
132
133 generated_ids = []
134 decoded_text = ""
135 for stream_response in stream_events:
136 if stream_response.is_error:
137 logger.error(f"Error from inference server: {stream_response.error}")
138 utils.send_response(
139 ws,
140 inference.ErrorResponse(
141 request_id=work_request.id,
142 error=stream_response.error,
143 metrics=inference.WorkerMetricsInfo(),
144 ),
145 )
146 raise RuntimeError(f"Error from inference server: {stream_response.error}")
147 token = stream_response.token
148
149 if model_config.is_llama:
150 generated_ids.append(token.id)
151 try:
152 with tokenizer_lock:
153 text = tokenizer.decode(generated_ids, skip_special_tokens=True)
154 new_text = text[len(decoded_text) :]
155 if not decoded_text:
156 new_text = new_text.lstrip()
157 except Exception:
158 text = decoded_text
159 new_text = ""
160 token.text = new_text

Callers

nothing calls this directly

Calls 8

add (Method) · 0.95
finish (Method) · 0.95
get_safety_opinion (Function) · 0.85
truncate_prompt (Function) · 0.85
to_token_response (Method) · 0.80

Tested by

no test coverage detected