MCPcopy
hub / github.com/huggingface/datasets / ExamplesIterable

Class ExamplesIterable

src/datasets/iterable_dataset.py:289–363  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

287
288
289class ExamplesIterable(_BaseExamplesIterable):
290 def __init__(
291 self,
292 generate_examples_fn: Callable[..., Iterator[tuple[Key, dict]]],
293 kwargs: dict,
294 generate_more_kwargs_fn: Optional[Callable[..., Iterator[dict]]] = None,
295 sleep_on_threads_shutdown: bool = False,
296 ):
297 super().__init__()
298 self.generate_examples_fn = generate_examples_fn
299 self.kwargs = kwargs
300
301 # for resharding
302 self.generate_more_kwargs_fn = generate_more_kwargs_fn
303
304 # for threads shutdowns
305 self._sleep_on_threads_shutdown = sleep_on_threads_shutdown
306
307 def _init_state_dict(self) -> dict:
308 self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
309 return self._state_dict
310
311 def __iter__(self):
312 shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0
313 for gen_kwargs in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None):
314 shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0
315 for key_example in islice(self.generate_examples_fn(**gen_kwargs), shard_example_idx_start, None):
316 if self._state_dict:
317 self._state_dict["shard_example_idx"] += 1
318 yield key_example
319 if self._state_dict:
320 self._state_dict["shard_idx"] += 1
321 self._state_dict["shard_example_idx"] = 0
322
323 def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable":
324 return ExamplesIterable(
325 self.generate_examples_fn,
326 _shuffle_gen_kwargs(deepcopy(generator), self.kwargs),
327 self.generate_more_kwargs_fn,
328 self.sleep_on_threads_shutdown,
329 )
330
331 def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable":
332 """Keep only the requested shard."""
333 gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards)
334 shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous)
335 requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices])
336 return ExamplesIterable(
337 self.generate_examples_fn,
338 requested_gen_kwargs,
339 self.generate_more_kwargs_fn,
340 self.sleep_on_threads_shutdown,
341 )
342
343 def reshard_data_sources(self) -> "ExamplesIterable":
344 """Split shars into more shards if possible."""
345 if not self.generate_more_kwargs_fn:
346 return ExamplesIterable(

Calls

no outgoing calls