Download a file from Google Drive. See ``tl.files.load_celebA_dataset`` for an example. Parameters -------------- ID : str The Google Drive file ID. destination : str The destination path for the saved file.
(ID, destination)
| 1265 | |
| 1266 | |
def _get_confirm_token(response):
    """Return the download-confirmation token carried in ``response``'s cookies.

    Looks for the first cookie whose name starts with ``download_warning``;
    its value is sent back to Google Drive as the ``confirm`` query parameter.

    Returns
    -------
    str or None
        The cookie value, or ``None`` when no such cookie is present.
    """
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def _save_response_content(response, destination, chunk_size=32 * 1024):
    """Stream the body of ``response`` into ``destination`` with a progress bar.

    The bar counts bytes: it advances by the length of each written chunk so
    that it agrees with the byte total taken from ``content-length``.
    ``destination`` is opened in ``"wb"`` mode, so an existing file is overwritten.
    """
    from tqdm import tqdm  # availability is checked by the public entry point

    # A missing/zero content-length becomes None so tqdm shows an open-ended
    # byte counter instead of a meaningless 0-byte total.
    total_size = int(response.headers.get('content-length', 0)) or None
    with open(destination, "wb") as f, tqdm(
        total=total_size, unit='B', unit_scale=True, desc=destination
    ) as progress:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)
                progress.update(len(chunk))


def download_file_from_google_drive(ID, destination):
    """Download file from Google Drive.

    See ``tl.files.load_celebA_dataset`` for example.

    Parameters
    --------------
    ID : str
        The Google Drive file ID.
    destination : str
        Path of the local file to write the download to (overwritten if it exists).

    Raises
    ------
    ImportError
        If ``tqdm`` or ``requests`` is not installed.
    requests.HTTPError
        If Google Drive answers the final request with an HTTP error status,
        instead of silently saving the error page as the downloaded file.

    """
    # Check optional dependencies up front so we fail before any network I/O.
    try:
        from tqdm import tqdm  # noqa: F401
    except ImportError as e:
        raise ImportError("Module tqdm not found. Please install tqdm via pip or other package managers.") from e

    try:
        import requests
    except ImportError as e:
        raise ImportError("Module requests not found. Please install requests via pip or other package managers.") from e

    url = "https://docs.google.com/uc?export=download"
    # The session (and its connection pool) is released even if the download fails.
    with requests.Session() as session:
        response = session.get(url, params={'id': ID}, stream=True)
        token = _get_confirm_token(response)

        if token:
            # The first streamed response is only the warning page: close it so
            # its pooled connection is released before re-requesting with the token.
            response.close()
            response = session.get(url, params={'id': ID, 'confirm': token}, stream=True)

        try:
            response.raise_for_status()
            _save_response_content(response, destination)
        finally:
            response.close()
| 1317 | |
| 1318 | |
| 1319 | def load_celebA_dataset(path='data'): |
no test coverage detected
searching dependent graphs…