(
self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0
)
| 38 | return out |
| 39 | |
def generate(
    self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0
):
    """Generate a completion for the first sequence in ``input_ids``.

    Args:
        input_ids: Batch of prompt token ids. Only ``input_ids[0]`` is read,
            and it must provide ``.tolist()`` (presumably a tensor -- confirm
            against the caller).
        do_sample: Accepted for signature compatibility; this method never
            reads it. Sampling behaviour is driven by ``temperature`` only.
        temperature: Forwarded to ``generate_stream`` as ``"temperature"``.
        max_new_tokens: Forwarded to ``generate_stream`` as
            ``"max_new_tokens"``.
        repetition_penalty: Forwarded to ``generate_stream`` as
            ``"repetition_penalty"``.

    Returns:
        A one-element list holding the prompt token ids followed by the
        token ids obtained by re-encoding the generated text.

    Raises:
        RuntimeError: If ``generate_stream`` yields nothing, so there is no
            generated text to return.
    """
    # This function is used by fastchat.llm_judge.
    # Because RWKV does not support huggingface generation API,
    # we reuse fastchat.serve.inference.generate_stream as a workaround.
    from transformers import AutoTokenizer

    from fastchat.serve.inference import generate_stream
    from fastchat.conversation import get_conv_template

    # Lazily create the tokenizer the first time generation is requested.
    if self.tokenizer is None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            "EleutherAI/pythia-160m", use_fast=True
        )

    # Convert once; the same list is decoded into the prompt and reused as
    # the prefix of the returned sequence.
    prompt_ids = input_ids[0].tolist()
    prompt = self.tokenizer.decode(prompt_ids)
    conv = get_conv_template("rwkv")

    gen_params = {
        "model": self.model_path,
        "prompt": prompt,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }

    # Drain the stream and keep only the final item.
    # NOTE(review): presumably the last item carries the complete generated
    # text -- confirm against generate_stream. The device is hard-coded to
    # "cuda".
    last_chunk = None
    for last_chunk in generate_stream(self, self.tokenizer, gen_params, "cuda"):
        pass

    # Without this guard an empty stream would surface as an opaque
    # UnboundLocalError on the loop variable.
    if last_chunk is None:
        raise RuntimeError("generate_stream yielded no output for the prompt")

    output_ids = self.tokenizer.encode(last_chunk["text"])

    return [prompt_ids + output_ids]
no test coverage detected