| 123 | |
| 124 | |
| 125 | class HFClient(Client): |
def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | None = None):
    """Load the tokenizer and model, optionally restoring prefix-encoder weights.

    Args:
        model_path: Name or path handed to ``AutoModel.from_pretrained``
            (and to ``AutoConfig.from_pretrained`` when a checkpoint is used).
        tokenizer_path: Name or path handed to ``AutoTokenizer.from_pretrained``.
        pt_checkpoint: Directory expected to contain ``pytorch_model.bin``.
            Entries whose keys start with ``transformer.prefix_encoder.`` are
            loaded into ``self.model.transformer.prefix_encoder``. When this
            is ``None`` or the path does not exist, the model is loaded
            without any prefix-encoder weights.
    """
    self.model_path = model_path
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)

    # To use the int4 model in either branch below: remove device_map="auto"
    # and add .quantize(bits=4, device="cuda").cuda() before .eval().
    # The int4 model must be loaded on cuda.
    if pt_checkpoint is not None and os.path.exists(pt_checkpoint):
        config = AutoConfig.from_pretrained(
            model_path,
            trust_remote_code=True,
            pre_seq_len=PRE_SEQ_LEN,
        )
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            config=config,
            device_map="auto",
        ).eval()

        # NOTE(security): torch.load unpickles the file; only point
        # pt_checkpoint at checkpoints from a trusted source.
        prefix_state_dict = torch.load(os.path.join(pt_checkpoint, "pytorch_model.bin"))

        # Keep only the prefix-encoder tensors and strip the module prefix so
        # the keys line up with the prefix encoder's own state dict.
        key_prefix = "transformer.prefix_encoder."
        new_prefix_state_dict = {
            k.removeprefix(key_prefix): v
            for k, v in prefix_state_dict.items()
            if k.startswith(key_prefix)
        }
        print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
        self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        # Load from the caller-supplied path (previously this branch ignored
        # the argument and always used the module-level MODEL_PATH).
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            device_map="auto",
        ).eval()
| 154 | |
| 155 | def generate_stream( |
| 156 | self, |
| 157 | system: str | None, |
| 158 | tools: list[dict] | None, |
| 159 | history: list[Conversation], |
| 160 | **parameters: Any |
| 161 | ) -> Iterable[TextGenerationStreamResponse]: |
| 162 | chat_history = [{ |
| 163 | 'role': 'system', |
| 164 | 'content': system if not tools else TOOL_PROMPT, |
| 165 | }] |
| 166 | |
| 167 | if tools: |
| 168 | chat_history[0]['tools'] = tools |
| 169 | |
| 170 | for conversation in history[:-1]: |
| 171 | chat_history.append({ |
| 172 | 'role': str(conversation.role).removeprefix('<|').removesuffix('|>'), |
| 173 | 'content': conversation.content, |
| 174 | }) |
| 175 | |
| 176 | query = history[-1].content |
| 177 | role = str(history[-1].role).removeprefix('<|').removesuffix('|>') |
| 178 | text = '' |
| 179 | for new_text, _ in stream_chat( |
| 180 | self.model, |
| 181 | self.tokenizer, |
| 182 | query, |