(model_name, seg_files, device="cuda:0", hub="ms")
| 110 | |
| 111 | |
def run_vllm(model_name, seg_files, device="cuda:0", hub="ms"):
    """Time a batched ``generate`` call over ``seg_files`` for one model.

    The engine class is chosen by substring match on ``model_name``:
    ``"Fun-ASR-Nano"`` -> ``FunASRNanoVLLM``, ``"GLM-ASR"`` ->
    ``GLMASRVLLMEngine``, anything else -> ``AutoModelVLLM``. Engine
    modules are imported lazily so only the selected backend is loaded.

    One untimed warmup ``generate`` call on the first segment precedes
    the measured call, so that first-call overhead is kept out of the
    reported time.

    Args:
        model_name: Model identifier passed to the engine as ``model``.
        seg_files: Non-empty sequence of inputs forwarded to
            ``engine.generate(inputs=...)`` (presumably audio segment
            paths -- confirm against the caller).
        device: Device string forwarded to the engine constructor.
        hub: Hub identifier forwarded to the engine constructor.

    Returns:
        A ``(elapsed, texts)`` tuple: ``elapsed`` is the
        ``time.perf_counter`` delta around the measured ``generate``
        call only, and ``texts`` is the ``"text"`` entry of each result.

    Raises:
        ValueError: If ``seg_files`` is empty. Checked before any engine
            is constructed, so the failure is immediate and explicit
            rather than an ``IndexError`` raised at the warmup call.
    """
    if not seg_files:
        raise ValueError("seg_files must contain at least one segment")

    # Generation options specific to the selected engine; they are passed
    # to both the warmup call and the measured call.
    gen_kwargs = {}

    if "Fun-ASR-Nano" in model_name:
        from funasr.models.fun_asr_nano.inference_vllm import FunASRNanoVLLM
        engine = FunASRNanoVLLM.from_pretrained(
            model=model_name, hub=hub, device=device, dtype="bf16",
            max_model_len=4096, gpu_memory_utilization=0.5)
        # "中文" is the literal language value ("Chinese") this engine is given.
        gen_kwargs["language"] = "中文"
    elif "GLM-ASR" in model_name:
        from funasr.models.glm_asr.inference_vllm import GLMASRVLLMEngine
        engine = GLMASRVLLMEngine.from_pretrained(
            model=model_name, hub=hub, device=device, dtype="bf16",
            gpu_memory_utilization=0.4, max_model_len=4096)
    else:
        from funasr.auto.auto_model_vllm import AutoModelVLLM
        engine = AutoModelVLLM(model=model_name, hub=hub, device=device)

    # Warmup: a single-segment call whose duration is deliberately not measured.
    engine.generate(inputs=[seg_files[0]], **gen_kwargs)

    # Only the batched call sits between the two clock reads; extracting the
    # texts happens afterwards so it does not count toward the elapsed time.
    t0 = time.perf_counter()
    results = engine.generate(inputs=seg_files, **gen_kwargs, max_new_tokens=500)
    t1 = time.perf_counter()

    return t1 - t0, [r["text"] for r in results]
| 145 | |
| 146 | |
| 147 | if __name__ == '__main__': |
no test coverage detected
searching dependent graphs…