MCPcopy
hub / github.com/zai-org/ChatGLM3 / HFClient

Class HFClient

composite_demo/client.py:125–198  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

123
124
125class HFClient(Client):
def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | None = None):
    """Load the tokenizer and model, optionally restoring prefix-encoder weights.

    Args:
        model_path: Location handed to ``AutoModel.from_pretrained`` (and to
            ``AutoConfig.from_pretrained`` when a checkpoint is applied).
        tokenizer_path: Location handed to ``AutoTokenizer.from_pretrained``.
        pt_checkpoint: Optional directory holding a ``pytorch_model.bin``
            file. When it is ``None`` or the path does not exist, the model
            is loaded without a ``pre_seq_len`` config and no prefix-encoder
            weights are restored.
    """
    self.model_path = model_path
    self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)

    if pt_checkpoint is not None and os.path.exists(pt_checkpoint):
        config = AutoConfig.from_pretrained(
            model_path,
            trust_remote_code=True,
            pre_seq_len=PRE_SEQ_LEN,
        )
        # add .quantize(bits=4, device="cuda").cuda() before .eval() and remove device_map="auto" to use int4 model
        # must use cuda to load int4 model
        self.model = AutoModel.from_pretrained(
            model_path,
            trust_remote_code=True,
            config=config,
            device_map="auto",
        ).eval()

        # NOTE(review): torch.load unpickles the file, so only point
        # pt_checkpoint at checkpoints from a trusted source.
        prefix_state_dict = torch.load(os.path.join(pt_checkpoint, "pytorch_model.bin"))

        # Keep only the prefix-encoder entries and strip the module prefix so
        # the keys match what prefix_encoder.load_state_dict expects.
        encoder_prefix = "transformer.prefix_encoder."
        new_prefix_state_dict = {
            key.removeprefix(encoder_prefix): value
            for key, value in prefix_state_dict.items()
            if key.startswith(encoder_prefix)
        }
        print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
        self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
    else:
        # Load from the caller-supplied model_path (previously the module-level
        # MODEL_PATH was used here, silently ignoring the argument).
        # add .quantize(bits=4, device="cuda").cuda() before .eval() and remove device_map="auto" to use int4 model
        # must use cuda to load int4 model
        self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, device_map="auto").eval()
154
155 def generate_stream(
156 self,
157 system: str | None,
158 tools: list[dict] | None,
159 history: list[Conversation],
160 **parameters: Any
161 ) -> Iterable[TextGenerationStreamResponse]:
162 chat_history = [{
163 'role': 'system',
164 'content': system if not tools else TOOL_PROMPT,
165 }]
166
167 if tools:
168 chat_history[0]['tools'] = tools
169
170 for conversation in history[:-1]:
171 chat_history.append({
172 'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
173 'content': conversation.content,
174 })
175
176 query = history[-1].content
177 role = str(history[-1].role).removeprefix('<|').removesuffix('|>')
178 text = ''
179 for new_text, _ in stream_chat(
180 self.model,
181 self.tokenizer,
182 query,

Callers 1

get_client · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected