Get the default patterns from a directory or repository by testing all the supported patterns. The first pattern (or set of patterns) that returns a non-empty list of data files is returned. In order, it first tests the split patterns in ALL_SPLIT_PATTERNS (e.g. SPLIT_PATTERN_SHARDED), otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
(pattern_resolver: Callable[[str], list[str]])
| 255 | |
| 256 | |
def _get_data_files_patterns(pattern_resolver: Callable[[str], list[str]]) -> dict[str, list[str]]:
    """
    Get the default patterns from a directory or repository by testing all the supported patterns.
    The first pattern (or set of patterns) that resolves to a non-empty list of data files is returned.

    In order, it first tests the split patterns in ALL_SPLIT_PATTERNS (like
    ``data/{split}-00000-of-00001.parquet``), otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.

    Args:
        pattern_resolver: Callable that takes a pattern and returns the list of matching data files.
            If it raises ``FileNotFoundError`` for a pattern, that pattern is skipped.

    Returns:
        A mapping from split name to the list of patterns to use for that split.

    Raises:
        ValueError: If a matched file name cannot be parsed back with its split pattern, or if an
            inferred split name does not match ``_split_re``.
        FileNotFoundError: If none of the supported patterns resolves to any data file.
    """
    # first check the split patterns like data/{split}-00000-of-00001.parquet
    for split_pattern in ALL_SPLIT_PATTERNS:
        pattern = split_pattern.replace("{split}", "*")
        try:
            data_files = pattern_resolver(pattern)
        except FileNotFoundError:
            continue
        if not data_files:
            continue

        splits: set[str] = set()
        for p in data_files:
            # Recover the concrete split name by matching the file's basename against the templated basename.
            p_parts = string_to_dict(xbasename(p), xbasename(split_pattern))
            if p_parts is None:
                # Explicit check rather than `assert`, which is stripped when running with `python -O`.
                raise ValueError(f"Couldn't infer the split name of '{p}' from the pattern '{split_pattern}'.")
            splits.add(p_parts["split"])

        if any(not re.match(_split_re, split) for split in splits):
            raise ValueError(f"Split name should match '{_split_re}' but got '{splits}'.")

        # Known default splits come first (in DEFAULT_SPLITS order), then any other splits alphabetically.
        default_split_names = {str(split) for split in DEFAULT_SPLITS}
        sorted_splits = [str(split) for split in DEFAULT_SPLITS if split in splits] + sorted(
            splits - default_split_names
        )
        return {split: [split_pattern.format(split=split)] for split in sorted_splits}

    # then check the default patterns based on train/valid/test splits
    for patterns_dict in ALL_DEFAULT_PATTERNS:
        non_empty_splits = []
        for split, patterns in patterns_dict.items():
            for pattern in patterns:
                try:
                    data_files = pattern_resolver(pattern)
                except FileNotFoundError:
                    continue
                if len(data_files) > 0:
                    # One matching pattern is enough to keep this split.
                    non_empty_splits.append(split)
                    break
        if non_empty_splits:
            return {split: patterns_dict[split] for split in non_empty_splits}
    # NOTE: `pattern` here is the last pattern that was tried by the loops above.
    raise FileNotFoundError(f"Couldn't resolve pattern {pattern} with resolver {pattern_resolver}")
| 299 | |
| 300 | |
| 301 | def resolve_pattern( |