MCPcopy
hub / github.com/PRIME-RL/PRIME / select

Method select

training/verl/protocol.py:238–271  ·  view source on GitHub ↗

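Select a subset of the DataProto via batch_keys, non_tensor_batch_keys and meta_info_keys. Args: batch_keys (list, optional): a list of strings indicating the keys in batch to select; non_tensor_batch_keys (list, optional): a list of keys indicating the non-tensor batch entries to select; meta_info_keys (list, optional): a list of keys indicating the meta info to select; deepcopy (bool, optional): whether to deep-copy the selected non-tensor batch and meta info. Returns: DataProto: the DataProto with the selected batch_keys, non_tensor_batch_keys and meta_info_keys.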

(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False)

Source from the content-addressed store, hash-verified

236 return self
237
def select(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None, deepcopy=False) -> 'DataProto':
    """Select a subset of the DataProto via batch_keys, non_tensor_batch_keys and meta_info_keys.

    Each ``*_keys`` argument is independent. Passing ``None`` for one of them
    keeps the corresponding container whole and, unless ``deepcopy`` is set,
    hands the very same object to the returned DataProto (no copy is made).

    Args:
        batch_keys (iterable, optional): keys of ``self.batch`` to select.
            They are forwarded to ``self.batch.select(*batch_keys)``.
        non_tensor_batch_keys (iterable, optional): keys of
            ``self.non_tensor_batch`` to select. Requested keys that are not
            present are silently ignored.
        meta_info_keys (iterable, optional): keys of ``self.meta_info`` to
            select. Requested keys that are not present are silently ignored.
        deepcopy (bool): when True, the selected non-tensor batch and meta
            info are deep-copied. The tensor batch is NOT copied by this flag.

    Returns:
        DataProto: a new DataProto holding the selected batch, non-tensor
        batch and meta info.
    """
    # TODO (zhangchi.usc1992) whether to copy
    if batch_keys is not None:
        batch_keys = tuple(batch_keys)
        sub_batch = self.batch.select(*batch_keys)
    else:
        sub_batch = self.batch

    if non_tensor_batch_keys is not None:
        # Materialise once: a one-shot iterable would otherwise be consumed by
        # the first membership tests, and a set gives O(1) lookups.
        wanted_non_tensor = frozenset(non_tensor_batch_keys)
        non_tensor_batch = {
            key: val for key, val in self.non_tensor_batch.items() if key in wanted_non_tensor
        }
    else:
        non_tensor_batch = self.non_tensor_batch

    if meta_info_keys is not None:
        wanted_meta = frozenset(meta_info_keys)
        sub_meta_info = {key: val for key, val in self.meta_info.items() if key in wanted_meta}
    else:
        sub_meta_info = self.meta_info

    if deepcopy:
        # Only the Python-side containers are detached from the source object;
        # sub_batch is passed through as produced above.
        non_tensor_batch = copy.deepcopy(non_tensor_batch)
        sub_meta_info = copy.deepcopy(sub_meta_info)

    return DataProto(batch=sub_batch, non_tensor_batch=non_tensor_batch, meta_info=sub_meta_info)
272
273 def pop(self, batch_keys=None, non_tensor_batch_keys=None, meta_info_keys=None) -> 'DataProto':
274 """Pop a subset of the DataProto via `batch_keys` and `meta_info_keys`

Callers 10

fitMethod · 0.80
add_to_bufferMethod · 0.80
compute_valuesMethod · 0.80
compute_log_probMethod · 0.80
compute_log_probMethod · 0.80

Calls 1

DataProtoClass · 0.85

Tested by

no test coverage detected