MCPcopy
hub / github.com/lm-sys/FastChat / preprocess

Function preprocess

fastchat/train/train.py:92–177  ·  view source on GitHub ↗
(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
)

Source from the content-addressed store, hash-verified

90
91
92def preprocess(
93 sources,
94 tokenizer: transformers.PreTrainedTokenizer,
95) -> Dict:
96 conv = get_conversation_template("vicuna")
97 roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
98
99 # Apply prompt templates
100 conversations = []
101 for i, source in enumerate(sources):
102 if roles[source[0]["from"]] != conv.roles[0]:
103 # Skip the first one if it is not from human
104 source = source[1:]
105
106 conv.messages = []
107 for j, sentence in enumerate(source):
108 role = roles[sentence["from"]]
109 assert role == conv.roles[j % 2], f"{i}"
110 conv.append_message(role, sentence["value"])
111 conversations.append(conv.get_prompt())
112
113 # Tokenize conversations
114 input_ids = tokenizer(
115 conversations,
116 return_tensors="pt",
117 padding="max_length",
118 max_length=tokenizer.model_max_length,
119 truncation=True,
120 ).input_ids
121 targets = input_ids.clone()
122
123 assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
124
125 # Mask targets. Only compute loss on the assistant outputs.
126 sep = conv.sep + conv.roles[1] + ": "
127 for conversation, target in zip(conversations, targets):
128 total_len = int(target.ne(tokenizer.pad_token_id).sum())
129
130 turns = conversation.split(conv.sep2)
131 cur_len = 1
132 target[:cur_len] = IGNORE_TOKEN_ID
133 for i, turn in enumerate(turns):
134 if turn == "":
135 break
136 turn_len = len(tokenizer(turn).input_ids)
137
138 parts = turn.split(sep)
139 if len(parts) != 2:
140 break
141 parts[0] += sep
142 # "-2" is hardcoded for the Llama tokenizer to make the offset correct.
143 instruction_len = len(tokenizer(parts[0]).input_ids) - 2
144
145 if i != 0 and not tokenizer.legacy:
146 # The legacy and non-legacy modes handle special tokens differently
147 instruction_len -= 1
148
149 # Ignore the user instructions

Callers 2

__init__Method · 0.70
__getitem__Method · 0.70

Calls 4

append_messageMethod · 0.80
get_promptMethod · 0.80
rank0_printFunction · 0.70

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…