MCPcopy
hub / github.com/hkust-nlp/simpleRL-reason / DataProto

Class DataProto

verl/protocol.py:173–597  ·  view source on GitHub ↗

A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/. TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, tensors with the same batch size should be put inside batch.

Source from the content-addressed store, hash-verified

171
172@dataclass
173class DataProto:
174 """
175 A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
176 It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
177 TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
178 same batch size should be put inside batch.
179 """
180 batch: TensorDict = None
181 non_tensor_batch: Dict = field(default_factory=dict)
182 meta_info: Dict = field(default_factory=dict)
183
184 def __post_init__(self):
185 # perform necessary checking
186 self.check_consistency()
187
188 def __len__(self):
189 if self.batch is not None:
190 return self.batch.batch_size[0]
191 elif self.non_tensor_batch is not None and len(self.non_tensor_batch) > 0:
192 random_key = list(self.non_tensor_batch.keys())[0]
193 return self.non_tensor_batch[random_key].shape[0]
194 else:
195 return 0
196
197 def __getitem__(self, item):
198 tensor_data = self.batch[item]
199 non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
200 return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
201
202 def __getstate__(self):
203 import io
204 buffer = io.BytesIO()
205 if tensordict.__version__ >= '0.5.0' and self.batch is not None:
206 self.batch = self.batch.contiguous()
207 self.batch = self.batch.consolidate()
208 torch.save(self.batch, buffer)
209 buffer_bytes = buffer.getvalue()
210 return buffer_bytes, self.non_tensor_batch, self.meta_info
211
212 def __setstate__(self, data):
213 import io
214 batch_deserialized_bytes, non_tensor_batch, meta_info = data
215 batch_deserialized = io.BytesIO(initial_bytes=batch_deserialized_bytes)
216 batch = torch.load(batch_deserialized,
217 weights_only=False,
218 map_location='cpu' if not torch.cuda.is_available() else None)
219 self.batch = batch
220 self.non_tensor_batch = non_tensor_batch
221 self.meta_info = meta_info
222
223 def save_to_disk(self, filepath):
224 with open(filepath, 'wb') as f:
225 pickle.dump(self, f)
226
227 @staticmethod
228 def load_from_disk(filepath) -> 'DataProto':
229 with open(filepath, 'rb') as f:
230 data = pickle.load(f)

Callers 15

get_aux_metrics · Function · 0.90
test · Function · 0.90
client.py · File · 0.90
train_model · Method · 0.90
test_len · Function · 0.90
_filter_batch · Method · 0.90
update_actor · Method · 0.90
update_critic · Method · 0.90
update_actor · Method · 0.90
update_critic · Method · 0.90
_generate_minibatch · Method · 0.90
generate_sequences · Method · 0.90

Calls

no outgoing calls

Tested by 3

get_aux_metrics · Function · 0.72
test · Function · 0.72
test_len · Function · 0.72