(data, autodownload=True)
| 496 | |
| 497 | |
| 498 | def check_dataset(data, autodownload=True): |
| 499 | # Download, check and/or unzip dataset if not found locally |
| 500 | |
| 501 | # Download (optional) |
| 502 | extract_dir = '' |
| 503 | if isinstance(data, (str, Path)) and (is_zipfile(data) or is_tarfile(data)): |
| 504 | download(data, dir=f'{DATASETS_DIR}/{Path(data).stem}', unzip=True, delete=False, curl=False, threads=1) |
| 505 | data = next((DATASETS_DIR / Path(data).stem).rglob('*.yaml')) |
| 506 | extract_dir, autodownload = data.parent, False |
| 507 | |
| 508 | # Read yaml (optional) |
| 509 | if isinstance(data, (str, Path)): |
| 510 | data = yaml_load(data) # dictionary |
| 511 | |
| 512 | # Checks |
| 513 | for k in 'train', 'val', 'names': |
| 514 | assert k in data, emojis(f"data.yaml '{k}:' field missing ❌") |
| 515 | if isinstance(data['names'], (list, tuple)): # old array format |
| 516 | data['names'] = dict(enumerate(data['names'])) # convert to dict |
| 517 | assert all(isinstance(k, int) for k in data['names'].keys()), 'data.yaml names keys must be integers, i.e. 2: car' |
| 518 | data['nc'] = len(data['names']) |
| 519 | |
| 520 | # Resolve paths |
| 521 | path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.' |
| 522 | if not path.is_absolute(): |
| 523 | path = (ROOT / path).resolve() |
| 524 | data['path'] = path # download scripts |
| 525 | for k in 'train', 'val', 'test': |
| 526 | if data.get(k): # prepend path |
| 527 | if isinstance(data[k], str): |
| 528 | x = (path / data[k]).resolve() |
| 529 | if not x.exists() and data[k].startswith('../'): |
| 530 | x = (path / data[k][3:]).resolve() |
| 531 | data[k] = str(x) |
| 532 | else: |
| 533 | data[k] = [str((path / x).resolve()) for x in data[k]] |
| 534 | |
| 535 | # Parse yaml |
| 536 | train, val, test, s = (data.get(x) for x in ('train', 'val', 'test', 'download')) |
| 537 | if val: |
| 538 | val = [Path(x).resolve() for x in (val if isinstance(val, list) else [val])] # val path |
| 539 | if not all(x.exists() for x in val): |
| 540 | LOGGER.info('\nDataset not found ⚠️, missing paths %s' % [str(x) for x in val if not x.exists()]) |
| 541 | if not s or not autodownload: |
| 542 | raise Exception('Dataset not found ❌') |
| 543 | t = time.time() |
| 544 | if s.startswith('http') and s.endswith('.zip'): # URL |
| 545 | f = Path(s).name # filename |
| 546 | LOGGER.info(f'Downloading {s} to {f}...') |
| 547 | torch.hub.download_url_to_file(s, f) |
| 548 | Path(DATASETS_DIR).mkdir(parents=True, exist_ok=True) # create root |
| 549 | unzip_file(f, path=DATASETS_DIR) # unzip |
| 550 | Path(f).unlink() # remove zip |
| 551 | r = None # success |
| 552 | elif s.startswith('bash '): # bash script |
| 553 | LOGGER.info(f'Running {s} ...') |
| 554 | r = os.system(s) |
| 555 | else: # python script |
no test coverage detected