MCPcopy
hub / github.com/lm-sys/FastChat / generate

Method generate

fastchat/model/rwkv_model.py:40–76  ·  view source on GitHub ↗
(
        self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0
    )

Source from the content-addressed store, hash-verified

38 return out
39
def generate(
    self, input_ids, do_sample, temperature, max_new_tokens, repetition_penalty=1.0
):
    """Generate a completion for a prompt and return prompt + completion ids.

    This function is used by fastchat.llm_judge. Because RWKV does not
    support the huggingface generation API, it reuses
    fastchat.serve.inference.generate_stream as a workaround: the prompt
    ids are decoded to text, the stream is drained to its final result,
    and the resulting text is re-encoded with ``self.tokenizer``.

    Args:
        input_ids: Indexable batch of prompt token ids. Only ``input_ids[0]``
            is used, and it must provide ``.tolist()`` (presumably a tensor
            of shape (batch, seq) -- TODO confirm against callers).
        do_sample: Accepted for signature compatibility; not used here.
        temperature: Forwarded to ``generate_stream`` as ``"temperature"``.
        max_new_tokens: Forwarded as ``"max_new_tokens"``.
        repetition_penalty: Forwarded as ``"repetition_penalty"``.

    Returns:
        A list holding a single list of token ids: the prompt ids followed
        by the ids obtained from encoding the generated text. If the stream
        yields nothing, only the prompt ids are returned.
    """
    from transformers import AutoTokenizer

    from fastchat.serve.inference import generate_stream
    from fastchat.conversation import get_conv_template

    # Lazily create the tokenizer the first time it is needed.
    if self.tokenizer is None:
        self.tokenizer = AutoTokenizer.from_pretrained(
            "EleutherAI/pythia-160m", use_fast=True
        )

    # Convert once; reused for both decoding and the returned sequence.
    prompt_ids = input_ids[0].tolist()
    prompt = self.tokenizer.decode(prompt_ids)
    conv = get_conv_template("rwkv")

    gen_params = {
        "model": self.model_path,
        "prompt": prompt,
        "temperature": temperature,
        "repetition_penalty": repetition_penalty,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }
    # NOTE(review): the device is hard-coded to "cuda" here.
    res_iter = generate_stream(self, self.tokenizer, gen_params, "cuda")

    # Drain the stream and keep only the last item. Pre-binding `res` avoids
    # an UnboundLocalError when the stream yields nothing at all.
    res = None
    for res in res_iter:
        pass

    if res is None:
        # Nothing was generated: return the prompt unchanged.
        return [prompt_ids]

    output = res["text"]
    output_ids = self.tokenizer.encode(output)

    return [prompt_ids + output_ids]

Callers 1

load_xft_model · Function · 0.45

Calls 2

get_conv_template · Function · 0.90
generate_stream · Function · 0.90

Tested by

no test coverage detected