| 287 | |
| 288 | |
| 289 | class ExamplesIterable(_BaseExamplesIterable): |
def __init__(
    self,
    generate_examples_fn: Callable[..., Iterator[tuple[Key, dict]]],
    kwargs: dict,
    generate_more_kwargs_fn: Optional[Callable[..., Iterator[dict]]] = None,
    sleep_on_threads_shutdown: bool = False,
):
    """Wrap an example-generating function together with the kwargs that drive it.

    Args:
        generate_examples_fn: Called as ``generate_examples_fn(**gen_kwargs)`` for each
            shard's kwargs; expected to yield ``(key, example)`` tuples.
        kwargs: Generation keyword arguments; ``__iter__`` hands them to
            ``_split_gen_kwargs`` to obtain per-shard kwargs.
        generate_more_kwargs_fn: Optional hook kept for resharding; defaults to ``None``.
        sleep_on_threads_shutdown: Flag kept for thread-shutdown handling; defaults to
            ``False``. Stored privately as ``_sleep_on_threads_shutdown``.
    """
    super().__init__()
    # Where examples come from, and the arguments that parametrize that source.
    self.generate_examples_fn, self.kwargs = generate_examples_fn, kwargs

    # Optional hook used when resharding into more shards.
    self.generate_more_kwargs_fn = generate_more_kwargs_fn

    # Thread-shutdown flag (note the leading underscore on the stored attribute).
    self._sleep_on_threads_shutdown = sleep_on_threads_shutdown
| 306 | |
def _init_state_dict(self) -> dict:
    """Reset the iteration checkpoint to the very beginning and return it.

    The checkpoint records the shard currently being read (``shard_idx``), how many
    examples of that shard have already been yielded (``shard_example_idx``), and the
    class name of this iterable (``type``). It is stored on ``self._state_dict``.
    """
    self._state_dict = dict(
        shard_idx=0,
        shard_example_idx=0,
        type=self.__class__.__name__,
    )
    return self._state_dict
| 310 | |
def __iter__(self):
    """Yield ``(key, example)`` tuples shard by shard, resuming from ``self._state_dict`` if set.

    When a state dict is present, the already-consumed shards (``shard_idx``) are skipped,
    and so are the already-yielded examples within the current shard
    (``shard_example_idx``). Both counters are kept up to date while yielding, so the
    state can be checkpointed mid-iteration. Without a state dict, iteration starts from
    the first shard and nothing is tracked.
    """
    # Resume at the recorded shard; without a state dict, start from the first shard.
    shard_idx_start = self._state_dict["shard_idx"] if self._state_dict else 0
    for gen_kwargs in islice(_split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards), shard_idx_start, None):
        # Re-read for every shard: it is the resume offset for the first resumed shard,
        # and it is reset to 0 below once a shard has been fully consumed.
        shard_example_idx_start = self._state_dict["shard_example_idx"] if self._state_dict else 0
        # NOTE: islice skips examples by still generating and discarding them.
        for key_example in islice(self.generate_examples_fn(**gen_kwargs), shard_example_idx_start, None):
            # Count the example *before* yielding it. A state snapshot taken while the
            # consumer holds this example then already includes it, so resuming skips it.
            if self._state_dict:
                self._state_dict["shard_example_idx"] += 1
            yield key_example
        # Shard exhausted: move on to the next shard and restart its example counter.
        if self._state_dict:
            self._state_dict["shard_idx"] += 1
            self._state_dict["shard_example_idx"] = 0
| 322 | |
def shuffle_data_sources(self, generator: np.random.Generator) -> "ExamplesIterable":
    """Return a copy of this iterable whose generation kwargs are shuffled.

    Args:
        generator: Random generator used by ``_shuffle_gen_kwargs``. It is deep-copied
            first, so shuffling does not advance the caller's generator.

    Returns:
        A new ``ExamplesIterable`` with the same generation function, resharding hook
        and thread-shutdown flag, but with shuffled kwargs.
    """
    return ExamplesIterable(
        self.generate_examples_fn,
        _shuffle_gen_kwargs(deepcopy(generator), self.kwargs),
        generate_more_kwargs_fn=self.generate_more_kwargs_fn,
        # __init__ stores this flag as ``_sleep_on_threads_shutdown``; the unprefixed
        # attribute name is never assigned there, so read the stored attribute directly.
        sleep_on_threads_shutdown=self._sleep_on_threads_shutdown,
    )
| 330 | |
def shard_data_sources(self, num_shards: int, index: int, contiguous=True) -> "ExamplesIterable":
    """Keep only the requested shard.

    Args:
        num_shards: Total number of splits requested by the caller; forwarded to
            ``split_shard_indices_by_worker``.
        index: Which of the ``num_shards`` splits to keep.
        contiguous: Forwarded to ``split_shard_indices_by_worker``.

    Returns:
        A new ``ExamplesIterable`` whose kwargs are the merged per-shard kwargs of the
        selected shard indices. The generation function, resharding hook and
        thread-shutdown flag are unchanged.
    """
    gen_kwargs_list = _split_gen_kwargs(self.kwargs, max_num_jobs=self.num_shards)
    shard_indices = self.split_shard_indices_by_worker(num_shards, index, contiguous=contiguous)
    requested_gen_kwargs = _merge_gen_kwargs([gen_kwargs_list[i] for i in shard_indices])
    return ExamplesIterable(
        self.generate_examples_fn,
        requested_gen_kwargs,
        generate_more_kwargs_fn=self.generate_more_kwargs_fn,
        # __init__ stores this flag as ``_sleep_on_threads_shutdown``; the unprefixed
        # attribute name is never assigned there, so read the stored attribute directly.
        sleep_on_threads_shutdown=self._sleep_on_threads_shutdown,
    )
| 342 | |
| 343 | def reshard_data_sources(self) -> "ExamplesIterable": |
| 344 | """Split shars into more shards if possible.""" |
| 345 | if not self.generate_more_kwargs_fn: |
| 346 | return ExamplesIterable( |
no outgoing calls