(self, element)
| 210 | drop_inputs: bool = True |
| 211 | |
def map(self, element):
    """Turn one preference example into stacked chosen/rejected token fields.

    Reads the prompt and the two candidate responses from `element` at the
    paths `self.in_prompt`, `self.in_chosen` and `self.in_rejected`. It then:

    - decodes any `bytes` values;
    - wraps each text in the dialog template;
    - tokenizes the texts, adding BOS to the prompt only;
    - builds next-token (seq2seq) fields for both (prompt, response) pairs;
    - pads both field sets with a single `_functional.pad` call;
    - stacks them along a new leading axis, with index 0 = chosen and
      index 1 = rejected.

    The stacked `input`, `target` and `target_mask` fields are written to
    `element` at `self.out_tokens`, `self.out_targets` and `self.out_mask`.
    When `self.drop_inputs` is true, the three raw input keys are deleted.

    Args:
        element: The example container; it is mutated in place.

    Returns:
        The same `element`, updated with the output fields.
    """
    get = kd.kontext.get_by_path
    raw_texts = [
        get(element, self.in_prompt),
        get(element, self.in_chosen),
        get(element, self.in_rejected),
    ]

    # Some datasets (e.g. TFDS) yield `bytes` rather than `str`.
    prompt_txt, chosen_txt, rejected_txt = [
        _decode_bytes(text) for text in raw_texts
    ]

    # Wrap every piece in the dialog template.
    # TODO(epot): Move this in a separate FormatDialog transform.
    prompt_txt = _template.PROMPT.format(prompt_txt)
    chosen_txt, rejected_txt = (
        _template.ANSWER.format(chosen_txt),
        _template.ANSWER.format(rejected_txt),
    )

    # The prompt opens the sequence, so it alone gets the BOS token.
    prompt_ids = self.tokenizer.encode(prompt_txt, add_bos=True)
    chosen_ids = self.tokenizer.encode(chosen_txt)
    rejected_ids = self.tokenizer.encode(rejected_txt)

    # Next-token fields for (prompt, chosen) and (prompt, rejected).
    pair = tuple(
        _functional.make_seq2seq_fields(prompt=prompt_ids, response=ids)
        for ids in (chosen_ids, rejected_ids)
    )

    # Pad both field sets in one call.
    chosen_fields, rejected_fields = _functional.pad(
        pair,
        max_length=self.max_length,
        truncate=self.truncate,
    )

    # Stack chosen (index 0) and rejected (index 1) leaf-by-leaf.
    stacked = jax.tree.map(
        lambda c, r: np.stack([c, r], axis=0),
        chosen_fields,
        rejected_fields,
    )

    # Write outputs; for flat paths this is `element[self.out_tokens] = ...`.
    for path, value in (
        (self.out_tokens, stacked.input),
        (self.out_targets, stacked.target),
        (self.out_mask, stacked.target_mask),
    ):
        kd.kontext.set_by_path(element, path, value)

    if not self.drop_inputs:
        return element

    # TODO(epot): Supports nested drop
    # NOTE(review): `del` treats each input path as a top-level key, so a
    # nested path presumably raises KeyError here.
    for key in (self.in_prompt, self.in_chosen, self.in_rejected):
        del element[key]
    return element
nothing calls this directly
no test coverage detected