| 788 | |
| 789 | |
def _sha256_file(path: str, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA256 digest of the file at ``path``.

    The file is read in ``chunk_size`` pieces so large checkpoints are never
    held in memory all at once, and the file handle is always closed.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_pretrained_from_url(
        url: str,
        cache_dir: Optional[str] = None,
) -> str:
    """Download the file at ``url`` into ``cache_dir`` and return its local path.

    A file already present in the cache is reused. If a SHA256 value can be
    derived from the URL, the cached file is reused only when its hex digest
    starts with that value. Otherwise a warning is emitted and the file is
    downloaded again. If no value can be derived, the cached file is trusted
    as-is.

    The expected checksum is taken from the URL:
      * URLs containing ``openaipublic``: the second-to-last path component.
      * URLs containing ``mlfoundations``: the last ``-``-separated token of
        the file name with its extension removed.
      * Anything else: no checksum verification.

    The download is streamed to ``<target>.part``. It is moved to the final
    cache path only after it is complete and, when a checksum is known,
    verified. An interrupted or corrupt download therefore never leaves a
    file that a later call would trust.

    Args:
        url: Remote location of the file. Its basename becomes the cached
            file name.
        cache_dir: Directory to store the file in. Defaults to
            ``~/.cache/clip`` when None or empty. Created if it does not exist.

    Returns:
        Path of the cached file, ``os.path.join(cache_dir, basename(url))``.

    Raises:
        RuntimeError: If the cache path exists but is not a regular file, or
            if the freshly downloaded file fails the checksum check.
        Errors raised while opening or reading ``url`` propagate unchanged.
    """
    if not cache_dir:
        cache_dir = os.path.expanduser("~/.cache/clip")
    os.makedirs(cache_dir, exist_ok=True)
    filename = os.path.basename(url)

    if 'openaipublic' in url:
        expected_sha256 = url.split("/")[-2]
    elif 'mlfoundations' in url:
        expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
    else:
        expected_sha256 = ''

    download_target = os.path.join(cache_dir, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if not expected_sha256:
            # The URL provides no checksum, so the cached copy is trusted.
            return download_target
        # Prefix match: the mlfoundations convention embeds only a short hash prefix.
        if _sha256_file(download_target).startswith(expected_sha256):
            return download_target
        warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")

    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            # Content-Length may be missing (e.g. chunked transfer encoding).
            # tqdm accepts total=None in that case instead of crashing on int(None).
            content_length = source.headers.get("Content-Length")
            total = int(content_length) if content_length else None
            with tqdm(total=total, ncols=80, unit='iB', unit_scale=True) as loop:
                for buffer in iter(lambda: source.read(8192), b""):
                    output.write(buffer)
                    loop.update(len(buffer))

        if expected_sha256 and not _sha256_file(partial_target).startswith(expected_sha256):
            raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")

        # Same-directory rename: readers see either the old file or the complete new one.
        os.replace(partial_target, download_target)
    except BaseException:
        # Also covers KeyboardInterrupt during a long download. The half-written
        # or unverified file is removed so no later call returns it. Cleanup is
        # best-effort so that the original error is the one that propagates.
        try:
            os.remove(partial_target)
        except OSError:
            pass
        raise

    return download_target
| 834 | |
| 835 | |
| 836 | def has_hf_hub(necessary=False): |