MCPcopy
hub / github.com/lm-sys/FastChat / preprocess

Function preprocess

fastchat/train/train_yuan2.py:100–313  ·  view source on GitHub ↗
(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
)

Source from the content-addressed store, hash-verified

98
99
100def preprocess(
101 sources,
102 tokenizer: transformers.PreTrainedTokenizer,
103 data_args,
104) -> Dict:
105 conv = get_conversation_template("yuan2") # wpf
106 roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
107
108 # Apply prompt templates
109 conversations = []
110 for i, source in enumerate(sources):
111 if roles[source[0]["from"]] != conv.roles[0]:
112 # Skip the first one if it is not from human
113 source = source[1:]
114
115 conv.messages = []
116 for j, sentence in enumerate(source):
117 role = roles[sentence["from"]]
118 assert role == conv.roles[j % 2], f"{i}"
119 conv.append_message(role, sentence["value"])
120 conversations.append(conv.get_prompt())
121 if data_args.last_response_loss:
122 a = conversations[0].replace("<sep>", "<eod>")
123 a = right_replace(a, "<n>", "<sep>")
124 # a=right_replace(a,"<n>","\n",max=20)
125 conversations[0] = a
126 if data_args.split_example_loss:
127 a = conversations[0].replace("<sep>", "")
128 a = a.split("<n>")
129 for i in range(int(len(a) / 2)):
130 if i == 0:
131 conversations[i] = ""
132 if i != 0:
133 conversations.append("")
134 for j in range(i * 2):
135 conversations[i] = conversations[i] + a[j] + "<n>"
136 conversations[i] = (
137 conversations[i] + a[i * 2] + "<sep>" + a[i * 2 + 1] + "<eod>"
138 )
139
140 if data_args.efficient_loss:
141 a = conversations[0].replace("<sep>", "<eod>")
142 conversations[0] = a
143
144 print(conversations)
145
146 # Tokenize conversations
147 input_ids = tokenizer(
148 conversations,
149 return_tensors="pt",
150 padding="max_length",
151 max_length=tokenizer.model_max_length,
152 truncation=True,
153 ).input_ids
154 targets = input_ids.clone()
155
156 # assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO #wpf
157 # Mask targets. Only compute loss on the assistant outputs.

Callers 2

__init__ — Method · 0.70
__getitem__ — Method · 0.70

Calls 5

right_replace — Function · 0.85
append_message — Method · 0.80
get_prompt — Method · 0.80
rank0_print — Function · 0.70

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…