| 97 | sampling: bool = False |
| 98 | |
def map(self, element):
  """Turns one prompt/response example into model-ready fields.

  The example is copied with `etree.copy` first; all reads, deletions and
  writes below happen on that copy, not on `element` itself.

  Args:
    element: The example to transform. The prompt and response are read from
      the `self.in_prompt` / `self.in_response` paths via
      `kd.kontext.get_by_path` (which also supports nested dicts and
      dataclasses).

  Returns:
    The copied example, with new fields written via `kd.kontext.set_by_path`:

    * When `self.sampling` is set: the template-formatted prompt and response
      strings at `self.out_input` and `self.out_target` (no tokenization).
    * Otherwise: the padded model input at `self.out_input`, and the target
      and target mask (each with a trailing axis of size 1 appended) at
      `self.out_target` and `self.out_target_mask`.

    When `self.drop_inputs` is set, the `self.in_prompt` and
    `self.in_response` entries are deleted from the copy before the new
    fields are written.
  """
  # Work on a copy so the caller's example is not mutated.
  example = etree.copy(element)

  # `get_by_path(example, path)` acts like `example[path]`, but can also walk
  # into nested dicts and dataclasses.
  raw_prompt = kd.kontext.get_by_path(example, self.in_prompt)
  raw_response = kd.kontext.get_by_path(example, self.in_response)

  # TODO(epot): Supports nested drop
  # NOTE: `del example[path]` only removes top-level keys; nested paths are
  # not handled yet.
  if self.drop_inputs:
    for input_path in (self.in_prompt, self.in_response):
      del example[input_path]

  # Some datasets (e.g. TFDS) return `bytes` rather than `str`: decode both
  # sides first, then wrap them in the dialog template.
  # TODO(epot): Add a `template` protocol to allow customizing this.
  decoded_prompt = _decode_bytes(raw_prompt)
  decoded_response = _decode_bytes(raw_response)
  prompt_text = _template.PROMPT.format(decoded_prompt)
  response_text = _template.ANSWER.format(decoded_response)

  if self.sampling:
    # Sampling works on the formatted text directly, so tokenization is
    # skipped.
    new_fields = [
        (self.out_input, prompt_text),
        (self.out_target, response_text),
    ]
  else:
    # Only the prompt is encoded with a BOS token.
    prompt_ids = self.tokenizer.encode(prompt_text, add_bos=True)
    response_ids = self.tokenizer.encode(response_text)

    # Build the input/target/target_mask fields, then pad them to
    # `self.max_length` (forwarding the `self.truncate` setting).
    seq2seq = _functional.make_seq2seq_fields(
        prompt=prompt_ids,
        response=response_ids,
    )
    padded = _functional.pad(
        seq2seq,
        max_length=self.max_length,
        truncate=self.truncate,
    )

    # A trailing singleton axis keeps the target shapes compatible with the
    # loss.
    new_fields = [
        (self.out_input, padded.input),
        (self.out_target, einops.rearrange(padded.target, "... -> ... 1")),
        (
            self.out_target_mask,
            einops.rearrange(padded.target_mask, "... -> ... 1"),
        ),
    ]

  # Equivalent to `example[out_path] = value`, but path-aware.
  for out_path, value in new_fields:
    kd.kontext.set_by_path(example, out_path, value)
  return example
| 156 | |