MCP · copy
hub / github.com/apple/ml-mgie / preprocess_mpt

Function preprocess_mpt

mgie_train.py:373–442  ·  view source on GitHub ↗
(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
)

Source from the content-addressed store, hash-verified

371 )
372
373def preprocess_mpt(
374 sources,
375 tokenizer: transformers.PreTrainedTokenizer,
376) -> Dict:
377 conv = conversation_lib.default_conversation.copy()
378 roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
379
380 # Apply prompt templates
381 conversations = []
382 for i, source in enumerate(sources):
383 if roles[source[0]["from"]] != conv.roles[0]:
384 # Skip the first one if it is not from human
385 source = source[1:]
386
387 conv.messages = []
388 for j, sentence in enumerate(source):
389 role = roles[sentence["from"]]
390 assert role == conv.roles[j % 2], f"{i}"
391 conv.append_message(role, sentence["value"])
392 conversations.append(conv.get_prompt())
393
394 # Tokenize conversations
395 input_ids = tokenizer(
396 conversations,
397 return_tensors="pt",
398 padding="longest",
399 max_length=tokenizer.model_max_length,
400 truncation=True,
401 ).input_ids
402 targets = input_ids.clone()
403 assert conv.sep_style == conversation_lib.SeparatorStyle.MPT
404
405 # Mask targets
406 sep = conv.sep + conv.roles[1]
407 for conversation, target in zip(conversations, targets):
408 total_len = int(target.ne(tokenizer.pad_token_id).sum())
409
410 rounds = conversation.split(conv.sep)
411 re_rounds = [conv.sep.join(rounds[:3])] # system + user + gpt
412 for conv_idx in range(3, len(rounds), 2):
413 re_rounds.append(conv.sep.join(rounds[conv_idx:conv_idx+2])) # user + gpt
414 cur_len = 0
415 target[:cur_len] = IGNORE_INDEX
416 for i, rou in enumerate(re_rounds):
417 if rou == "":
418 break
419
420 parts = rou.split(sep)
421 if len(parts) != 2:
422 break
423 parts[0] += sep
424 round_len = len(tokenizer(rou).input_ids) + len(tokenizer(conv.sep).input_ids)
425 instruction_len = len(tokenizer(parts[0]).input_ids)
426 target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
427
428 cur_len += round_len
429 target[cur_len:] = IGNORE_INDEX
430

Callers 1

preprocess · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected