| 51 | self.max_batch_size = 0 # model_config["max_batch_size"] |
| 52 | |
def execute(self, requests):
    """Run sampled generation for each incoming Triton request.

    Requests are processed one at a time (no cross-request batching yet).
    For every request this reads the ``input_ids``, ``input_lengths``,
    ``request_output_len``, ``runtime_top_k``, ``runtime_top_p`` and
    ``temperature`` inputs, calls ``self.model.generate`` with
    ``do_sample=True``, and builds one ``pb_utils.InferenceResponse`` with
    two output tensors:

    * ``output_ids``: the ids returned by ``generate`` with a singleton axis
      inserted at dim 1 (the beam axis the client expects; beam search is
      not supported).
    * ``sequence_length``: an int32 CPU tensor of shape
      ``output_ids.shape[:2]`` whose ``[i, 0]`` entry is the number of ids in
      row ``i`` that differ from ``self.model.config.eos_token_id``.

    Args:
        requests: the requests Triton hands to ``execute``; inputs are read
            from each via ``pb2torch`` / ``pb_utils.get_input_tensor_by_name``.

    Returns:
        A list of ``pb_utils.InferenceResponse``, one per request, in order.
    """
    # TODO: don't just loop over requests. batch them up

    def _first_scalar(request, name):
        # Read a per-request scalar parameter as a plain Python number.
        # Flattening first makes this work whether the input arrives with
        # shape [1] or [batch, 1]; a bare ``tolist()[0]`` on a [batch, 1]
        # array yields a one-element *list*, which is not a valid value for
        # top_k / top_p / temperature.
        tensor = pb_utils.get_input_tensor_by_name(request, name)
        return tensor.as_numpy().ravel().tolist()[0]

    # Loop-invariant: look the EOS id up once rather than once per output row.
    eos_token_id = self.model.config.eos_token_id
    responses = []

    for request in requests:
        input_ids_torch = pb2torch(request, "input_ids")
        input_lengths_torch = pb2torch(request, "input_lengths")
        request_output_len_torch = pb2torch(request, "request_output_len")
        device = input_ids_torch.device

        # Attention mask: only needed when the prompts differ in length.
        # Positions [0, length) of each row are 1, the rest 0 (i.e. this
        # assumes right padding -- TODO confirm against the client). Built on
        # the same device as input_ids so generate() never sees mixed devices.
        attention_mask = None
        if input_lengths_torch.min() != input_lengths_torch.max():
            lengths = input_lengths_torch.reshape(-1, 1).to(
                device=device, dtype=torch.long
            )
            positions = torch.arange(input_ids_torch.shape[1], device=device)
            attention_mask = (positions.unsqueeze(0) < lengths).long()

        # Generation length as a plain int; the first element applies to the
        # whole request.
        max_new_tokens = int(request_output_len_torch.flatten()[0])

        top_k = _first_scalar(request, "runtime_top_k")
        top_p = _first_scalar(request, "runtime_top_p")
        temperature = _first_scalar(request, "temperature")
        # TODO: the client doesn't send "n" yet; it duplicates the request n
        # times instead, so always draw a single sample per request here.
        n_samples = 1

        output_ids = self.model.generate(
            input_ids=input_ids_torch,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=top_k,
            top_p=top_p,
            num_return_sequences=n_samples,
            temperature=temperature,
        )

        # Client wants batch x beam_width x seq_len and beam search isn't
        # supported, so insert a singleton beam axis.
        output_ids = output_ids.unsqueeze(1)
        out_tensor_pb = torch2pb("output_ids", output_ids)

        # sequence_length[i, 0] = count of ids in row i that are not EOS.
        # NOTE(review): this is a count, not the index of the first EOS, so
        # any non-EOS padding after EOS is included -- confirm the client
        # expects exactly this.
        sequence_length = (
            (output_ids[:, 0] != eos_token_id)
            .sum(dim=-1, keepdim=True)
            .to(device="cpu", dtype=torch.int32)
        )
        sequence_length_pb = torch2pb("sequence_length", sequence_length)

        responses.append(
            pb_utils.InferenceResponse([out_tensor_pb, sequence_length_pb])
        )

    return responses