Given a URL, look for the corresponding dataset in the local cache. If it's not there, download it. Then return the path to the cached file.
(url, cache_dir=None)
| 172 | |
| 173 | |
def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Args:
        url: Location of the resource. URLs starting with ``s3://`` are resolved
            with ``s3_etag`` / ``s3_get``; every other URL is checked with an HTTP
            HEAD request and fetched with ``http_get``.
        cache_dir: Directory holding the cached files. Defaults to
            ``PYTORCH_PRETRAINED_BERT_CACHE``. On Python 3 a ``Path`` is accepted
            and converted to ``str``.

    Returns:
        The path of the cached file inside ``cache_dir``.

    Raises:
        IOError: If the HEAD request for a non-S3 URL does not answer with
            status code 200.
        OSError: If ``cache_dir`` cannot be created.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    # Create the directory and tolerate "already exists" instead of checking
    # first: with check-then-create, two processes warming the cache at the
    # same time can both pass the check and one of them crashes in makedirs.
    try:
        os.makedirs(cache_dir)
    except OSError:
        if not os.path.isdir(cache_dir):
            raise

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then move into the cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            # Copy under a partial name first: the copy itself can be interrupted,
            # and a truncated file sitting at cache_path would be served as a
            # valid cache hit by every later call.
            partial_path = "{}.{}.part".format(cache_path, os.getpid())
            # os.replace (Python 3.3+) overwrites atomically on every platform;
            # os.rename is the closest equivalent on older interpreters.
            atomic_move = getattr(os, "replace", os.rename)
            try:
                logger.info("copying %s to cache at %s", temp_file.name, cache_path)
                with open(partial_path, 'wb') as cache_file:
                    shutil.copyfileobj(temp_file, cache_file)

                # Write the metadata before the data file becomes visible, so an
                # entry that exists at cache_path always has its .json sidecar.
                logger.info("creating metadata file for %s", cache_path)
                meta = {'url': url, 'etag': etag}
                meta_path = cache_path + '.json'
                with open(meta_path, 'w', encoding="utf-8") as meta_file:
                    json.dump(meta, meta_file)

                atomic_move(partial_path, cache_path)
            finally:
                # After a successful move the partial file is gone; anything
                # still present here is debris from a failed attempt.
                if os.path.exists(partial_path):
                    os.remove(partial_path)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path
no test coverage detected