MCPcopy
hub / github.com/PRIME-RL/PRIME / DataProto

Class DataProto

training/verl/protocol.py:99–450  ·  view source on GitHub ↗

A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions. It contains a batch (TensorDict), a non_tensor_batch (Dict), and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/. TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the same batch size should be put inside batch.

Source from the content-addressed store, hash-verified

97
98@dataclass
99class DataProto:
100 """
101 A DataProto is a data structure that aims to provide a standard protocol for data exchange between functions.
102 It contains a batch (TensorDict) and a meta_info (Dict). The batch is a TensorDict https://pytorch.org/tensordict/.
103 TensorDict allows you to manipulate a dictionary of Tensors like a single Tensor. Ideally, the tensors with the
104 same batch size should be put inside batch.
105 """
106 batch: TensorDict = None
107 non_tensor_batch: Dict = field(default_factory=dict)
108 meta_info: Dict = field(default_factory=dict)
109
def __post_init__(self):
    """Dataclass hook: validate the instance right after construction.

    Delegates to ``check_consistency`` so every ``DataProto(...)`` that is
    built is checked (mainly ``batch`` against ``non_tensor_batch``).
    """
    self.check_consistency()
113
def __len__(self):
    """Return the number of samples held by this DataProto.

    Normally this is dimension 0 of ``batch.batch_size``. The ``batch`` field
    defaults to ``None``, so a DataProto that carries only
    ``non_tensor_batch`` entries falls back to the length of its first entry.
    An instance with neither returns 0 instead of raising ``AttributeError``.
    """
    if self.batch is not None:
        return self.batch.batch_size[0]
    if not self.non_tensor_batch:
        return 0
    # NOTE(review): assumes all non-tensor entries share the same leading
    # length (presumably enforced by check_consistency) - confirm.
    first_value = next(iter(self.non_tensor_batch.values()))
    return len(first_value)
116
def __getitem__(self, item):
    """Index every field with ``item`` and wrap the result in a DataProtoItem.

    ``item`` is applied to ``batch`` and to each value of
    ``non_tensor_batch``. ``meta_info`` is passed through by reference (it is
    not copied). When ``batch`` is ``None`` (its dataclass default), the
    item's batch is ``None`` as well instead of raising ``TypeError``.
    """
    # NOTE(review): assumes DataProtoItem (defined elsewhere) accepts
    # batch=None - confirm.
    tensor_data = None if self.batch is None else self.batch[item]
    non_tensor_data = {key: val[item] for key, val in self.non_tensor_batch.items()}
    return DataProtoItem(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
121
def slice(self, index):
    """Return a new DataProto restricted to ``index``.

    ``index`` is applied to ``batch`` and to each value of
    ``non_tensor_batch``. ``meta_info`` is shared by reference with the
    source object. The result is a full DataProto, so its construction runs
    ``check_consistency`` via ``__post_init__``. A ``None`` batch (the
    dataclass default) stays ``None`` instead of raising ``TypeError``.
    """
    tensor_data = None if self.batch is None else self.batch[index]
    non_tensor_data = {key: val[index] for key, val in self.non_tensor_batch.items()}
    return DataProto(batch=tensor_data, non_tensor_batch=non_tensor_data, meta_info=self.meta_info)
126
def slice_batch(self, start, length, dim=0):
    """Keep ``[start, start + length)`` along ``dim``.

    Note that this operation is in-place: ``self.batch`` (and, for
    ``dim == 0``, ``self.non_tensor_batch``) are replaced. Nothing is
    returned.

    If ``dim`` is one of the TensorDict's batch dimensions, the TensorDict is
    indexed directly so its ``batch_size`` shrinks together with the tensors.
    Narrowing each tensor and assigning it back cannot do that, because
    TensorDict rejects values whose leading dims no longer match
    ``batch_size``. With that approach the default ``dim=0`` failed for any
    ``length`` smaller than the batch.

    For any other ``dim`` (including negative ones), each tensor is narrowed
    individually with ``narrow``, as before.

    Args:
        start: first index kept along ``dim``; negative values count from
            the end, as with ``torch.narrow``.
        length: number of elements kept.
        dim: dimension to narrow.

    Raises:
        IndexError: on the batch-dim path, if the requested window does not
            fit inside ``batch_size[dim]``.
    """
    if 0 <= dim < self.batch.batch_dims:
        size = self.batch.batch_size[dim]
        begin = start + size if start < 0 else start
        if begin < 0 or length < 0 or begin + length > size:
            raise IndexError(
                f"slice_batch window start={start}, length={length} is out of range "
                f"for batch dimension {dim} of size {size}"
            )
        window = slice(begin, begin + length)
        self.batch = self.batch[(slice(None),) * dim + (window,)]
        if dim == 0:
            # Keep per-sample non-tensor data aligned with the narrowed batch.
            self.non_tensor_batch = {key: val[window] for key, val in self.non_tensor_batch.items()}
        return
    for key, val in self.batch.items():
        self.batch[key] = val.narrow(start=start, length=length, dim=dim)
133
134
def __getstate__(self):
    """Pickle support: return ``(buffer, non_tensor_batch, meta_info)``.

    ``batch`` is written into an in-memory ``io.BytesIO`` with ``torch.save``.
    With tensordict >= 0.5.0, ``batch`` is first made contiguous and
    consolidated. This replaces ``self.batch`` on the live object as a side
    effect of pickling.

    The version gate compares numeric ``(major, minor)``. A plain string
    comparison (``__version__ >= '0.5.0'``) is lexicographic, so a release
    such as ``'0.10.0'`` would compare smaller than ``'0.5.0'`` and skip
    consolidation.
    """
    import io
    import re

    version_match = re.match(r'(\d+)\.(\d+)', tensordict.__version__)
    supports_consolidate = (
        version_match is not None
        and (int(version_match.group(1)), int(version_match.group(2))) >= (0, 5)
    )
    buffer = io.BytesIO()
    if supports_consolidate and self.batch is not None:
        self.batch = self.batch.contiguous()
        self.batch = self.batch.consolidate()
    torch.save(self.batch, buffer)
    return buffer, self.non_tensor_batch, self.meta_info
143
def __setstate__(self, data):
    """Pickle support: rebuild the fields from the tuple made by ``__getstate__``.

    ``data`` is ``(buffer, non_tensor_batch, meta_info)``. ``buffer`` is
    rewound and loaded with ``torch.load``. Tensors are mapped to CPU when
    CUDA is unavailable; otherwise they keep their saved location.
    """
    buffer, non_tensor_batch, meta_info = data
    buffer.seek(0)
    if torch.cuda.is_available():
        location = None
    else:
        location = 'cpu'
    # weights_only=False: the payload is a pickled TensorDict, not a plain
    # state dict.
    self.batch = torch.load(buffer, weights_only=False, map_location=location)
    self.non_tensor_batch = non_tensor_batch
    self.meta_info = meta_info
153
154 def check_consistency(self):
155 """Check the consistency of the DataProto. Mainly for batch and non_tensor_batch
156 We expose this function as a public one so that user can call themselves directly

Callers 13

update_actor · Method · 0.90
update_critic · Method · 0.90
update_actor · Method · 0.90
update_critic · Method · 0.90
_generate_minibatch · Method · 0.90
generate_sequences · Method · 0.90
generate_sequences · Method · 0.90
compute_reward · Method · 0.90
collate_fn · Function · 0.85
slice · Method · 0.85
select · Method · 0.85
chunk · Method · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected