class DataSpec:
    """Specifications for operating and training on data.

    An example of constructing a :class:`DataSpec` object with a ``batch_transforms``
    callable and then using it with :class:`~.Trainer`:

    .. doctest::

        >>> # Define a single-argument batch transform that subtracts the mean from the inputs
        >>> def batch_transform_fn(batch):
        ...     xs, ys = batch
        ...     return xs.sub_(xs.mean()), ys
        >>> train_dspec = DataSpec(train_dataloader, batch_transforms=batch_transform_fn)
        >>> # The same function can be used for the eval dataloader as well
        >>> eval_dspec = DataSpec(eval_dataloader, batch_transforms=batch_transform_fn)
        >>> # Use these DataSpec objects to construct the trainer
        >>> trainer = Trainer(
        ...     model=model,
        ...     train_dataloader=train_dspec,
        ...     eval_dataloader=eval_dspec,
        ...     optimizers=optimizer,
        ...     max_duration="1ep",
        ... )

    Args:
        dataloader (Union[Iterable, torch.utils.data.DataLoader]): The dataloader, which can be any iterable that
            yields batches.

        num_samples (int, optional): The total number of samples in an epoch, across all ranks. This field is used by
            the :class:`.Timestamp` (training progress tracker). If not specified, then ``len(dataloader.dataset)`` is
            used (if this property is available). Otherwise, the dataset is assumed to be unsized.

        num_tokens (int, optional): The total number of tokens in an epoch. This field is used by the
            :class:`.Timestamp` (training progress tracker).

        batch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
            batch before it is moved onto the device. For example, this function can be used for CPU-based
            normalization. It can modify the batch in-place, and it should return the modified batch. If not specified,
            the batch is not modified.

        microbatch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
            microbatch after it is moved onto the device. For example, this function can be used for GPU-based
            normalization. It can modify the microbatch in-place, and it should return the modified microbatch. If not
            specified, the microbatch is not modified.

        split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
            split a batch (the first parameter) into microbatches of a given size (the second parameter). If
            the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, tuple, or list, then
            this function must be specified.

        get_num_samples_in_batch ((Batch) -> Union[int, float], optional): Function that is called by the
            :class:`.Trainer` to get the number of samples in the provided batch.

            By default, if the batch contains tensors that all have the same 0th dim, then the value of the 0th dim will
            be returned. If the 0th dims of the tensors in the batch differ, then this function must be specified.

        get_num_tokens_in_batch ((Batch) -> int, optional): Function that is called by the :class:`.Trainer` to
            get the number of tokens in the provided batch.

            By default, it first checks for HuggingFace-style dictionary batches containing ``input_ids``, then for
            ``dataset.max_seq_len``. If both checks fail, it returns 0, meaning that the number of tokens processed is
            not tracked as part of training progress.
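The ``batch_transforms`` contract has three parts: the callable takes one positional argument, may edit that argument in place, and returns the batch. A framework-free illustration follows. It is a minimal sketch built on stated assumptions: plain Python lists stand in for tensors, and ``subtract_mean`` and ``apply_batch_transforms`` are hypothetical helpers, not part of this module's API.

```python
from typing import Callable, Iterable, Iterator, List, Optional, Tuple

# Assumption for this sketch: a batch is (inputs, targets), with lists standing in for tensors.
Batch = Tuple[List[float], List[int]]


def subtract_mean(batch: Batch) -> Batch:
    # A batch transform receives the whole batch as ONE argument and returns it.
    xs, ys = batch
    mean = sum(xs) / len(xs)
    for i in range(len(xs)):
        xs[i] -= mean  # in-place update, analogous to Tensor.sub_
    return xs, ys


def apply_batch_transforms(
    batches: Iterable[Batch],
    batch_transforms: Optional[Callable[[Batch], Batch]] = None,
) -> Iterator[Batch]:
    # Hypothetical loop: transform each batch, or pass it through unchanged if no transform is given.
    for batch in batches:
        yield batch if batch_transforms is None else batch_transforms(batch)
```

For example, ``list(apply_batch_transforms([([1.0, 2.0, 3.0], [0, 1, 0])], subtract_mean))`` gives ``[([-1.0, 0.0, 1.0], [0, 1, 0])]``. Because the callable receives the batch as a single argument, it unpacks ``(xs, ys)`` itself. Passing ``None`` leaves every batch unchanged, which matches the documented default.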
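The default sample-count rule and the splitting contract of ``split_batch`` can be sketched together. This dependency-free approximation is illustrative only and is not the library's implementation. It assumes that a Python list stands in for a tensor (its 0th dim is ``len(list)``), that dicts and tuples are containers of such "tensors", and that the microbatch size is an integer. ``num_samples_in_batch`` and ``split_into_microbatches`` are hypothetical names.

```python
from collections.abc import Mapping
from typing import Any, List

# Assumptions for this sketch: a Python list stands in for a tensor (its 0th dim is
# len(list)); dicts and tuples are containers of such "tensors".


def _leaves(batch: Any) -> List[list]:
    # Collect the tensor-like leaves of a possibly nested batch.
    if isinstance(batch, list):
        return [batch]
    if isinstance(batch, Mapping):
        return [leaf for value in batch.values() for leaf in _leaves(value)]
    if isinstance(batch, tuple):
        return [leaf for value in batch for leaf in _leaves(value)]
    raise TypeError(f"unsupported batch type {type(batch).__name__}; supply custom functions")


def num_samples_in_batch(batch: Any) -> int:
    # Default rule: every leaf must share the same 0th dim, which is the sample count.
    sizes = {len(leaf) for leaf in _leaves(batch)}
    if len(sizes) != 1:
        raise ValueError(f"0th dims differ ({sorted(sizes)}); specify get_num_samples_in_batch")
    return sizes.pop()


def split_into_microbatches(batch: Any, microbatch_size: int) -> List[Any]:
    # Slice every leaf along its 0th dim, preserving the container structure.
    def take(node: Any, start: int, stop: int) -> Any:
        if isinstance(node, list):
            return node[start:stop]
        if isinstance(node, Mapping):
            return {key: take(value, start, stop) for key, value in node.items()}
        return tuple(take(value, start, stop) for value in node)

    n = num_samples_in_batch(batch)
    return [take(batch, i, i + microbatch_size) for i in range(0, n, microbatch_size)]
```

For instance, ``split_into_microbatches(([1.0, 2.0, 3.0, 4.0, 5.0], [0, 1, 0, 1, 1]), 2)`` returns three microbatches holding 2, 2, and 1 samples. A batch such as ``{"a": [1, 2], "b": [1]}`` raises, which corresponds to the case where a custom ``get_num_samples_in_batch`` is required.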
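The documented fallback order for counting tokens could look like the following sketch. It makes two assumptions: every position in ``input_ids`` counts as a token (padding included), and the dataset and the per-batch sample count are passed in explicitly. ``count_tokens`` and its parameters are hypothetical.

```python
from collections.abc import Mapping
from typing import Any, Optional


def count_tokens(batch: Any, dataset: Optional[Any] = None, num_samples: Optional[int] = None) -> int:
    # 1) HuggingFace-style dict batch: count every position of input_ids
    #    (assumption: padding positions are counted too).
    if isinstance(batch, Mapping) and "input_ids" in batch:
        return sum(len(row) for row in batch["input_ids"])
    # 2) Otherwise, use a fixed per-sample length advertised by the dataset.
    max_seq_len = getattr(dataset, "max_seq_len", None)
    if max_seq_len is not None and num_samples is not None:
        return max_seq_len * num_samples
    # 3) Nothing to go on: 0 means tokens are not tracked.
    return 0
```

For example, ``count_tokens({"input_ids": [[101, 7592, 102], [101, 2088, 0]]})`` returns 6. A tuple batch whose dataset object exposes ``max_seq_len = 128`` and contains 4 samples yields 512. With neither source available, the result is 0.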