MCPcopy
hub / github.com/z-lab/dflash / _run_transformers

Function _run_transformers

dflash/benchmark.py:198–268  ·  view source on GitHub ↗
(args: argparse.Namespace)

Source from the content-addressed store, hash-verified

196
197
198def _run_transformers(args: argparse.Namespace) -> None:
199 import torch
200 from torch import distributed as torch_dist
201 from transformers import AutoModelForCausalLM, AutoTokenizer
202
203 from .model import DFlashDraftModel, dflash_generate
204
205 _check_transformers_model(args.model)
206
207 random.seed(0)
208 np.random.seed(0)
209 torch.manual_seed(0)
210 torch.cuda.manual_seed_all(0)
211 torch.backends.cudnn.deterministic = True
212 torch.backends.cudnn.benchmark = False
213
214 _dist_init(torch_dist)
215 torch.cuda.set_device(_dist_local_rank())
216 device = torch.device(f"cuda:{_dist_local_rank()}")
217 attn_impl = _get_transformers_attn_impl()
218
219 target = AutoModelForCausalLM.from_pretrained(
220 args.model, attn_implementation=attn_impl, dtype=torch.bfloat16,
221 ).to(device).eval()
222
223 draft_model = DFlashDraftModel.from_pretrained(
224 args.draft_model, attn_implementation=attn_impl, dtype=torch.bfloat16,
225 ).to(device).eval()
226
227 block_size = args.block_size if args.block_size is not None else draft_model.block_size
228 tokenizer = AutoTokenizer.from_pretrained(args.model)
229 dataset = load_and_process_dataset(args.dataset)
230
231 dataset = _limit_dataset(dataset, args.max_samples)
232
233 responses = []
234 indices = range(_dist_rank(), len(dataset), _dist_size())
235 for idx in tqdm(indices, disable=not _dist_is_main()):
236 instance = dataset[idx]
237 messages = []
238 for user_content in instance["turns"]:
239 messages.append({"role": "user", "content": user_content})
240 input_text = _apply_chat_template(tokenizer, messages, args.enable_thinking)
241 input_ids = tokenizer.encode(input_text, return_tensors="pt").to(target.device)
242
243 response = {}
244 for bs in [1, block_size]:
245 response[bs] = dflash_generate(
246 draft_model,
247 target=target,
248 input_ids=input_ids,
249 max_new_tokens=args.max_new_tokens,
250 stop_token_ids=[tokenizer.eos_token_id],
251 temperature=args.temperature,
252 block_size=bs,
253 return_stats=True,
254 )
255

Callers 1

main (Function) · 0.85

Calls 13

_dist_init (Function) · 0.85
_dist_local_rank (Function) · 0.85
load_and_process_dataset (Function) · 0.85
_limit_dataset (Function) · 0.85
_dist_rank (Function) · 0.85
_dist_size (Function) · 0.85
_dist_is_main (Function) · 0.85
_apply_chat_template (Function) · 0.85
dflash_generate (Function) · 0.85
_dist_gather (Function) · 0.85

Tested by

no test coverage detected