(args: argparse.Namespace)
| 196 | |
| 197 | |
| 198 | def _run_transformers(args: argparse.Namespace) -> None: |
| 199 | import torch |
| 200 | from torch import distributed as torch_dist |
| 201 | from transformers import AutoModelForCausalLM, AutoTokenizer |
| 202 | |
| 203 | from .model import DFlashDraftModel, dflash_generate |
| 204 | |
| 205 | _check_transformers_model(args.model) |
| 206 | |
| 207 | random.seed(0) |
| 208 | np.random.seed(0) |
| 209 | torch.manual_seed(0) |
| 210 | torch.cuda.manual_seed_all(0) |
| 211 | torch.backends.cudnn.deterministic = True |
| 212 | torch.backends.cudnn.benchmark = False |
| 213 | |
| 214 | _dist_init(torch_dist) |
| 215 | torch.cuda.set_device(_dist_local_rank()) |
| 216 | device = torch.device(f"cuda:{_dist_local_rank()}") |
| 217 | attn_impl = _get_transformers_attn_impl() |
| 218 | |
| 219 | target = AutoModelForCausalLM.from_pretrained( |
| 220 | args.model, attn_implementation=attn_impl, dtype=torch.bfloat16, |
| 221 | ).to(device).eval() |
| 222 | |
| 223 | draft_model = DFlashDraftModel.from_pretrained( |
| 224 | args.draft_model, attn_implementation=attn_impl, dtype=torch.bfloat16, |
| 225 | ).to(device).eval() |
| 226 | |
| 227 | block_size = args.block_size if args.block_size is not None else draft_model.block_size |
| 228 | tokenizer = AutoTokenizer.from_pretrained(args.model) |
| 229 | dataset = load_and_process_dataset(args.dataset) |
| 230 | |
| 231 | dataset = _limit_dataset(dataset, args.max_samples) |
| 232 | |
| 233 | responses = [] |
| 234 | indices = range(_dist_rank(), len(dataset), _dist_size()) |
| 235 | for idx in tqdm(indices, disable=not _dist_is_main()): |
| 236 | instance = dataset[idx] |
| 237 | messages = [] |
| 238 | for user_content in instance["turns"]: |
| 239 | messages.append({"role": "user", "content": user_content}) |
| 240 | input_text = _apply_chat_template(tokenizer, messages, args.enable_thinking) |
| 241 | input_ids = tokenizer.encode(input_text, return_tensors="pt").to(target.device) |
| 242 | |
| 243 | response = {} |
| 244 | for bs in [1, block_size]: |
| 245 | response[bs] = dflash_generate( |
| 246 | draft_model, |
| 247 | target=target, |
| 248 | input_ids=input_ids, |
| 249 | max_new_tokens=args.max_new_tokens, |
| 250 | stop_token_ids=[tokenizer.eos_token_id], |
| 251 | temperature=args.temperature, |
| 252 | block_size=bs, |
| 253 | return_stats=True, |
| 254 | ) |
| 255 |
no test coverage detected