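Dataset for processing single or multi-task data with task-specific tokenization and processing. Args: dataset: Input dataset containing raw data tokenizer: Tokenizer for text processing default_task_data_spec: Default task processing specifications. In the case of single-task, this is the spec used for processing all entries. In the case of multi-task, any values not specified in the task-specific specs will be taken from the default spec.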
| 31 | |
| 32 | # TODO @sahilj handle too-long prompts and masking them out throughout the whole process and renormalizing on loss |
| 33 | class AllTaskProcessedDataset: |
| 34 | """Dataset for processing single or multi-task data with task-specific tokenization and processing. |
| 35 | |
| 36 | Args: |
| 37 | dataset: Input dataset containing raw data |
| 38 | tokenizer: Tokenizer for text processing |
| 39 | default_task_data_spec: Default task processing specifications. |
| 40 | In the case of single-task, this is the spec used for processing all entries. |
| 41 | In the case of multi-task, any values not specified in the task-specific specs will be taken from the default spec. |
| 42 | task_data_processors: Either a single TaskDataProcessFnCallable for single-task, |
| 43 | or a dict mapping task names to (TaskDataSpec, TaskDataProcessFnCallable) for multi-task |
| 44 | max_seq_length: Maximum sequence length for tokenized outputs |
| 45 | """ |
| 46 | |
| 47 | def __init__( |
| 48 | self, |
| 49 | dataset: Dataset | Any, |
| 50 | tokenizer: TokenizerType, |
| 51 | default_task_data_spec: TaskDataSpec, |
| 52 | task_data_processors: ( |
| 53 | dict[str, tuple[TaskDataSpec, TaskDataProcessFnCallable]] |
| 54 | | TaskDataProcessFnCallable |
| 55 | ), |
| 56 | task_data_preprocessors: Optional[ |
| 57 | Union[dict[str, TaskDataPreProcessFnCallable], TaskDataPreProcessFnCallable] |
| 58 | ] = None, |
| 59 | max_seq_length: Optional[int] = None, |
| 60 | ): |
| 61 | self.dataset = dataset |
| 62 | self.tokenizer = tokenizer |
| 63 | # TODO @yukih: will be removed once eval datasets are adapted |
| 64 | self.default_task_data_spec = default_task_data_spec |
| 65 | self.task_data_processors = task_data_processors |
| 66 | self.task_data_preprocessors = task_data_preprocessors |
| 67 | self.max_seq_length = max_seq_length |
| 68 | self._bos_checked = False |
| 69 | |
| 70 | if ( |
| 71 | isinstance(task_data_processors, dict) |
| 72 | and default_task_data_spec is not None |
| 73 | ): |
| 74 | # apply defaults to all task data specs |
| 75 | for _, (task_data_spec, _) in task_data_processors.items(): |
| 76 | task_data_spec.copy_defaults(self.default_task_data_spec) |
| 77 | |
| 78 | def __len__(self) -> int: |
| 79 | return len(self.dataset) |
| 80 | |
| 81 | def encode_single( |
| 82 | self, text: Union[str, list[str]] |
| 83 | ) -> tuple[list[int] | torch.Tensor, int]: |
| 84 | """Takes either a single string or a list of strings that represent multiple turns for the same conversation. |
| 85 | |
| 86 | Returns a single (concatenated) list of tokenized ids and the length of the tokenized ids. |
| 87 | """ |
| 88 | if isinstance(text, str): |
| 89 | text_ids = self.tokenizer.text_to_ids(text) |
| 90 | return text_ids, len(text_ids) |
no outgoing calls