github.com/mosaicml/composer

Class DataSpec

composer/core/data_spec.py:126–318

class DataSpec:
    """Specifications for operating and training on data.

    An example of constructing a :class:`DataSpec` object with a ``batch_transforms``
    callable and then using it with :class:`~.Trainer`:

    .. doctest::

        >>> # Construct DataSpec and subtract mean from the batch
        >>> batch_transform_fn = lambda xs, ys: (xs.sub_(xs.mean()), ys)
        >>> train_dspec = DataSpec(train_dataloader, batch_transforms=batch_transform_fn)
        >>> # The same function can be used for eval dataloader as well
        >>> eval_dspec = DataSpec(eval_dataloader, batch_transforms=batch_transform_fn)
        >>> # Use this DataSpec object to construct trainer
        >>> trainer = Trainer(
        ...     model=model,
        ...     train_dataloader=train_dspec,
        ...     eval_dataloader=eval_dspec,
        ...     optimizers=optimizer,
        ...     max_duration="1ep",
        ... )

    Args:
        dataloader (Union[Iterable, torch.utils.data.DataLoader]): The dataloader, which can be any iterable that yields batches.

        num_samples (int, optional): The total number of samples in an epoch, across all ranks. This field is used by
            the :class:`.Timestamp` (training progress tracker). If not specified, then ``len(dataloader.dataset)`` is
            used (if this property is available). Otherwise, the dataset is assumed to be unsized.

        num_tokens (int, optional): The total number of tokens in an epoch. This field is used by the
            :class:`.Timestamp` (training progress tracker).

        batch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
            batch before it is moved onto the device. For example, this function can be used for CPU-based
            normalization. It can modify the batch in-place, and it should return the modified batch. If not specified,
            the batch is not modified.

        microbatch_transforms ((Batch) -> Batch, optional): Function called by the :class:`.Trainer` to modify the
            microbatch after it has been moved onto the device. For example, this function can be used for GPU-based
            normalization. It can modify the microbatch in-place, and it should return the modified microbatch. If not
            specified, the microbatch is not modified.

        split_batch ((Batch, (int | float)) -> Sequence[Batch], optional): Function called by the :class:`.Trainer` to
            split a batch (the first parameter) into microbatches of a given size (the second parameter). If
            the ``dataloader`` yields batches not of type :class:`torch.Tensor`, Mapping, tuple, or list, then
            this function must be specified.

        get_num_samples_in_batch ((Batch) -> Union[int, float], optional): Function that is called by the :class:`.Trainer`
            to get the number of samples in the provided batch.

            By default, if the batch contains tensors that all have the same 0th dim, then the value of the 0th dim will
            be returned. If the tensors in the batch differ in their 0th dim, then this function must be specified.

        get_num_tokens_in_batch ((Batch) -> int, optional): Function that is called by the :class:`.Trainer` to
            get the number of tokens in the provided batch.

            By default, it first checks for HuggingFace-style dictionary batches with ``input_ids``, then checks
            ``dataset.max_seq_len``, and returns 0 if both of those fail, meaning that the number of tokens processed
            will not be tracked as part of training progress.

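To make the division of labor between `batch_transforms` and `microbatch_transforms` concrete, here is a toy, framework-free sketch. It assumes (the docstring does not spell this out) that the batch-level hook runs once on each full batch before splitting, and that the microbatch hook runs once on each piece returned by `split_batch`. Device transfer is deliberately left out. `run_transforms`, `center`, `split`, and `scale` are hypothetical names, not part of Composer.

```python
from typing import Any, Callable, List, Sequence

Batch = Any


def run_transforms(
    batch: Batch,
    microbatch_size: int,
    batch_transforms: Callable[[Batch], Batch],
    split_batch: Callable[[Batch, int], Sequence[Batch]],
    microbatch_transforms: Callable[[Batch], Batch],
) -> List[Batch]:
    """Hypothetical sketch of the order in which the two transform hooks run."""
    batch = batch_transforms(batch)                             # whole batch, once
    microbatches = split_batch(batch, microbatch_size)          # cut into pieces
    return [microbatch_transforms(mb) for mb in microbatches]   # once per microbatch


# Toy (inputs, targets) batch built from plain lists.
xs, ys = [1.0, 2.0, 3.0, 4.0], [0, 1, 0, 1]


def center(batch):
    # Subtract the batch mean, similar in spirit to the doctest's batch_transform_fn.
    inputs, targets = batch
    mean = sum(inputs) / len(inputs)
    return [x - mean for x in inputs], targets


def split(batch, size):
    # Slice inputs and targets together along the sample dimension.
    inputs, targets = batch
    return [(inputs[i:i + size], targets[i:i + size]) for i in range(0, len(inputs), size)]


def scale(mb):
    # A toy per-microbatch transform.
    inputs, targets = mb
    return [2 * x for x in inputs], targets


print(run_transforms((xs, ys), 2, center, split, scale))
# [([-3.0, -1.0], [0, 1]), ([1.0, 3.0], [0, 1])]
```

Because the mean is taken before splitting, both microbatches are centered on the statistics of the full batch; moving `center` into the microbatch hook would instead center each microbatch on its own mean.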
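The `split_batch` description implies a default splitter for common batch types, with a custom function required for anything else. The sketch below is a hypothetical pure-Python analogue, not Composer's default: `default_split_batch` is an illustrative name, sliceable sequences stand in for tensors and are cut along their 0th dim, tuples and `Mapping`s are treated as containers of per-field values, lists are treated as leaves (which may differ from Composer's list handling), and only integer microbatch sizes are supported although the documented signature also allows a float.

```python
from collections.abc import Mapping
from typing import Any, List


def default_split_batch(batch: Any, microbatch_size: int) -> List[Any]:
    """Hypothetical analogue of a default split_batch for nested batches."""
    if not isinstance(microbatch_size, int) or microbatch_size < 1:
        raise ValueError("this sketch only supports positive integer microbatch sizes")
    if isinstance(batch, Mapping):
        keys = list(batch)
        # Split every field, then regroup the i-th chunk of each field into one dict.
        parts = [default_split_batch(batch[k], microbatch_size) for k in keys]
        return [dict(zip(keys, chunk)) for chunk in zip(*parts)]
    if isinstance(batch, tuple):
        # Same idea for positional fields such as (inputs, targets).
        parts = [default_split_batch(field, microbatch_size) for field in batch]
        return [tuple(chunk) for chunk in zip(*parts)]
    if hasattr(batch, "__len__") and hasattr(batch, "__getitem__"):
        # Leaf (list, or a tensor in real code): slice along the 0th dim.
        return [batch[i:i + microbatch_size] for i in range(0, len(batch), microbatch_size)]
    # Unsupported type: this is where a user-supplied split_batch would be needed.
    raise TypeError(f"Cannot split batch of type {type(batch).__name__}; pass split_batch")


batch = {"input_ids": [[1, 2], [3, 4], [5, 6]], "labels": [0, 1, 1]}
print(default_split_batch(batch, 2))
# [{'input_ids': [[1, 2], [3, 4]], 'labels': [0, 1]}, {'input_ids': [[5, 6]], 'labels': [1]}]
print(default_split_batch(([1, 2, 3, 4], [0, 1, 0, 1]), 3))
# [([1, 2, 3], [0, 1, 0]), ([4], [1])]
```

Note that `zip` silently truncates if fields have different lengths; a production splitter would more likely validate that every field shares the same 0th dim first.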
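The default rule for `get_num_samples_in_batch` (return the shared 0th dim, fail if the dims disagree) can be sketched as follows. `default_num_samples_in_batch` and `_leaf_sizes` are hypothetical names. The sketch walks `Mapping`s and tuples, reads `.shape[0]` from tensor-like leaves, and, as an assumption that keeps the example runnable without PyTorch, uses `len()` for plain lists.

```python
from collections.abc import Mapping
from typing import Any, List


def _leaf_sizes(batch: Any) -> List[int]:
    """Collect the 0th-dim size of every tensor-like leaf in a (possibly nested) batch."""
    if isinstance(batch, Mapping):
        return [n for value in batch.values() for n in _leaf_sizes(value)]
    if isinstance(batch, tuple):
        return [n for field in batch for n in _leaf_sizes(field)]
    if hasattr(batch, "shape"):  # e.g. torch.Tensor or numpy.ndarray
        return [int(batch.shape[0])]
    return [len(batch)]  # plain lists stand in for tensors in this sketch


def default_num_samples_in_batch(batch: Any) -> int:
    """Hypothetical analogue of the documented default for get_num_samples_in_batch."""
    sizes = set(_leaf_sizes(batch))
    if len(sizes) != 1:
        # Mismatched (or missing) 0th dims: the caller must supply get_num_samples_in_batch.
        raise ValueError(f"Cannot infer a single batch size from 0th dims {sorted(sizes)}")
    return sizes.pop()


print(default_num_samples_in_batch(([[1, 2], [3, 4], [5, 6]], [0, 1, 1])))  # 3
```

When the leaves legitimately disagree (for example, a per-sample field next to a per-batch field), a custom function such as `lambda batch: len(batch["input_ids"])` would tell the trainer which field defines the sample count.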
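The fallback chain for `get_num_tokens_in_batch` can be sketched the same way. The docstring names the checks but not the arithmetic, so two assumptions are made here: a HuggingFace-style batch contributes every position of `input_ids` (padding included), and a dataset with `max_seq_len` contributes `max_seq_len` tokens per sample. `default_num_tokens_in_batch`, its `num_samples` and `dataset` parameters, and `count_unpadded_tokens` are all hypothetical.

```python
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any, Optional


def default_num_tokens_in_batch(batch: Any, num_samples: int, dataset: Optional[Any] = None) -> int:
    """Hypothetical sketch of the documented fallback chain for get_num_tokens_in_batch."""
    # 1) HuggingFace-style dict batch: count every position in input_ids.
    #    Assumption: padding positions are counted too (samples * sequence length).
    if isinstance(batch, Mapping) and "input_ids" in batch:
        return sum(len(row) for row in batch["input_ids"])
    # 2) Dataset advertises max_seq_len: assume each sample contributes that many tokens.
    max_seq_len = getattr(dataset, "max_seq_len", None)
    if max_seq_len is not None:
        return max_seq_len * num_samples
    # 3) Neither check applies: 0 means token counts are not tracked.
    return 0


def count_unpadded_tokens(batch: Mapping) -> int:
    """Hypothetical custom get_num_tokens_in_batch that skips padding via attention_mask."""
    return sum(sum(row) for row in batch["attention_mask"])


hf_batch = {"input_ids": [[101, 7, 0], [101, 9, 12]], "attention_mask": [[1, 1, 0], [1, 1, 1]]}
plain_batch = ([[0.0]] * 4, [0, 1, 0, 1])

print(default_num_tokens_in_batch(hf_batch, num_samples=2))                                   # 6
print(default_num_tokens_in_batch(plain_batch, 4, dataset=SimpleNamespace(max_seq_len=128)))  # 512
print(default_num_tokens_in_batch(plain_batch, 4))                                            # 0
print(count_unpadded_tokens(hf_batch))                                                        # 5
```

The gap between 6 and 5 for the same batch shows why a custom function can matter: if token-based durations or throughput should reflect real tokens rather than padded positions, counting the attention mask is one option.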
Callers 9

predict (method) · 0.90
dataspec (method) · 0.90
test_transforms (function) · 0.90
ensure_data_spec (function) · 0.85

Calls

no outgoing calls

Tested by 7

dataspec (method) · 0.72
test_transforms (function) · 0.72