Given a URL, look for the corresponding dataset in the local cache. If it's not there, download it. Then return the path to the cached file.
(url, cache_dir=None)
| 176 | |
| 177 | |
def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding file in the local cache.
    If it's not there, download it. Then return the path to the cached file.

    Args:
        url: an ``s3://`` URL, or any URL that ``requests.head`` can reach.
        cache_dir: directory holding cached files. Defaults to
            ``PYTORCH_PRETRAINED_BERT_CACHE``. On Python 3 a ``Path`` is
            accepted and converted to ``str``.

    Returns:
        The path (``str``) of the cached file inside ``cache_dir``. The file
        name is derived from the URL and its ETag via ``url_to_filename``.
        A ``<cache_path>.json`` sidecar with ``{"url": ..., "etag": ...}`` is
        written next to it when the file is downloaded.

    Raises:
        IOError: if the HEAD request for a non-S3 URL returns a status
            other than 200.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    # Create the cache directory. A separate exists() check followed by
    # makedirs() races when several processes start at once, so just try to
    # create it and accept the case where someone else already did.
    try:
        os.makedirs(cache_dir)
    except OSError:
        if not os.path.isdir(cache_dir):
            raise

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        etag = s3_etag(url)
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then move it into the cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                s3_get(url, temp_file)
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            # Write the metadata before the data file becomes visible, so an
            # existing cache_path always has its .json sidecar.
            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            meta_string = json.dumps(meta)
            if isinstance(meta_string, bytes):
                # Python 2: json.dumps returns a byte str, which a text-mode
                # file opened with an encoding refuses to write.
                meta_string = meta_string.decode('utf-8')
            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                meta_file.write(meta_string)

            # Copy into a sibling file in cache_dir and rename it into place.
            # Writing cache_path directly would leave a truncated file behind
            # if the copy is interrupted, and later calls would return it as a
            # valid cache hit.
            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            fd, partial_path = tempfile.mkstemp(
                dir=cache_dir, prefix=filename + '.', suffix='.incomplete')
            try:
                with os.fdopen(fd, 'wb') as cache_file:
                    shutil.copyfileobj(temp_file, cache_file)
                # os.replace (Python 3.3+) also overwrites an existing target
                # on Windows; fall back to os.rename where it is unavailable.
                getattr(os, 'replace', os.rename)(partial_path, cache_path)
            finally:
                # After a successful rename partial_path is gone; otherwise
                # clean up the incomplete copy.
                if os.path.exists(partial_path):
                    os.remove(partial_path)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path
no test coverage detected