MCPcopy
hub / github.com/hpcaitech/ColossalAI / _run_engine

Function _run_engine

tests/test_infer/test_models/test_custom_model.py:98–149  ·  view source on GitHub ↗
(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None)

Source from the content-addressed store, hash-verified

96
97
def _run_engine(model, use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None, policy=None):
    """Generate text for a fixed prompt with either the inference engine or plain ``generate``.

    The tokenizer and model classes, plus the checkpoint path, are looked up
    in ``MODEL_MAP`` under the key ``model``. The loaded model is cast to half
    precision, moved to CUDA and switched to eval mode before generating.

    Args:
        model: Key into ``MODEL_MAP`` selecting the model configuration.
        use_engine: If true, generate through ``InferenceEngine``; otherwise
            call the loaded model's own ``generate`` and decode the result
            with the tokenizer.
        do_sample: Forwarded to ``GenerationConfig``. When true, ``top_p=0.5``
            and ``top_k=50`` are used; otherwise both are ``None``.
        use_cuda_kernel: Forwarded to ``InferenceConfig`` (engine path only).
        prompt_template: Forwarded to ``InferenceConfig`` on the engine path.
            On the non-engine path, a truthy value is used as a key into
            ``_DEFAULT_PROMPT_TEMPLATES`` to format the prompt.
        policy: Passed to ``InferenceEngine`` as ``model_policy`` (engine path
            only).

    Returns:
        Whatever ``InferenceEngine.generate`` returns on the engine path, or
        the result of ``tokenizer.batch_decode`` on the non-engine path.
    """
    # Seed first so both generation paths start from the same RNG state.
    setup_seed(20)

    entry = MODEL_MAP[model]
    checkpoint = entry["model_name_or_path"]
    tokenizer = entry["tokenizer"].from_pretrained(checkpoint, use_fast=False, trust_remote_code=True)
    hf_model = entry["model"].from_pretrained(checkpoint, trust_remote_code=True).half().cuda()
    hf_model = hf_model.eval()

    prompts = ["Introduce some landmarks in Paris:"]
    max_new_tokens = 38

    # Sampling knobs are only set when sampling is requested.
    top_p, top_k = (0.5, 50) if do_sample else (None, None)

    if use_engine:
        engine_config = InferenceConfig(
            max_output_len=max_new_tokens,
            prompt_template=prompt_template,
            use_cuda_kernel=use_cuda_kernel,
            tp_size=dist.get_world_size(),
        )
        engine = InferenceEngine(hf_model, tokenizer, engine_config, verbose=True, model_policy=policy)
        assert engine.generation_config.max_new_tokens == max_new_tokens
        engine.add_request(prompts=prompts)
        assert engine.request_handler._has_waiting()
        gen_config = GenerationConfig(
            do_sample=do_sample,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
        )
        return engine.generate(generation_config=gen_config)

    if prompt_template:
        # The engine path hands the template to InferenceConfig; here it is applied by hand.
        template = _DEFAULT_PROMPT_TEMPLATES[prompt_template]
        prompts = [template.format(input_text=text) for text in prompts]

    # Use EOS as the padding token so that padding=True has a token to pad with.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    encoded = tokenizer.batch_encode_plus(prompts, padding=True, return_tensors="pt")
    input_ids = encoded["input_ids"].cuda()

    gen_config = GenerationConfig(
        do_sample=do_sample,
        top_p=top_p,
        top_k=top_k,
        pad_token_id=tokenizer.pad_token_id,
        max_new_tokens=max_new_tokens,
    )
    generated = hf_model.generate(input_ids, generation_config=gen_config)
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
150
151
152def setup_seed(seed):

Callers

nothing calls this directly

Calls 12

add_requestMethod · 0.95
generateMethod · 0.95
InferenceConfigClass · 0.90
InferenceEngineClass · 0.90
halfMethod · 0.80
_has_waitingMethod · 0.80
setup_seedFunction · 0.70
from_pretrainedMethod · 0.45
cudaMethod · 0.45
evalMethod · 0.45
get_world_sizeMethod · 0.45
generateMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…