A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict (https://pytorch.org/tensordict/). TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the same batch size should be put inside batch.
| 171 | |
| 172 | @dataclass |
| 173 | class DataProto: |
| 174 | """ |
| 175 | A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. |
| 176 | It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/. |
| 177 | TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the |
| 178 | same batch size should be put inside batch. |
| 179 | """ |
| 180 | batch: TensorDict = None |
| 181 | non_tensor_batch: Dict = field(default_factory=dict) |
| 182 | meta_info: Dict = field(default_factory=dict) |
| 183 | |
| 184 | def __post_init__(self): |
| 185 | # perform necessary checking |
| 186 | self.check_consistency() |
| 187 | |
| 188 | def __len__(self): |
| 189 | if self.batch is not None: |
| 190 | return self.batch.batch_size[0] |
| 191 | elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0: |
| 192 | random_key = list(self.non_tensor_batch.keys())[0] |
| 193 | return self.non_tensor_batch[random_key].shape[0] |
| 194 | else: |
| 195 | return 0 |
| 196 | |
def __getitem__(self, item):
    """Return the selection ``item`` of this object as a ``DataProtoItem``.

    Args:
        item: Index applied to ``batch`` and to every value stored in
            ``non_tensor_batch``.

    Returns:
        A ``DataProtoItem`` built from the indexed tensor batch (``None`` when
        this object has no tensor batch), the indexed non-tensor values, and
        the same ``meta_info`` object (passed through, not copied).
    """
    # ``batch`` defaults to None and __len__/__getstate__ already tolerate
    # that, so index it only when it is present instead of raising TypeError.
    tensor_data = self.batch[item] if self.batch is not None else None
    non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
    return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
| 201 | |
| 202 | def __getstate__(self): |
| 203 | import io |
| 204 | buffer = io.BytesIO() |
| 205 | if tensordict.__version__ >= '0.5.0' and self.batch is not None: |
| 206 | self.batch = self.batch.contiguous() |
| 207 | self.batch = self.batch.consolidate() |
| 208 | torch.save(self.batch, buffer) |
| 209 | buffer_bytes = buffer.getvalue() |
| 210 | return buffer_bytes, self.non_tensor_batch, self.meta_info |
| 211 | |
| 212 | def __setstate__(self, data): |
| 213 | import io |
| 214 | batch_deserialized_bytes, non_tensor_batch, meta_info = data |
| 215 | batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes) |
| 216 | batch = torch.load(batch_deserialized, |
| 217 | weights_only=False, |
| 218 | map_location='cpu' if not torch.cuda.is_available() else None) |
| 219 | self.batch = batch |
| 220 | self.non_tensor_batch = non_tensor_batch |
| 221 | self.meta_info = meta_info |
| 222 | |
def save_to_disk(self, filepath):
    """Pickle this object into the file at ``filepath``.

    Args:
        filepath: Destination path. It is opened in binary write mode, so an
            existing file at that path is overwritten.
    """
    with open(filepath, 'wb') as sink:
        pickle.dump(self, sink)
| 226 | |
| 227 | @staticmethod |
| 228 | def load_from_disk(filepath) -> 'DataProto': |
| 229 | with open(filepath, 'rb') as f: |
| 230 | data = pickle.load(f) |
no outgoing calls