| 401 | ) |
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
    """Check that ``data.utils.extract_archive`` unpacks ``.gz`` and ``.tar`` files.

    Each case builds a small archive in a temporary source directory,
    extracts it into a separate temporary destination directory, and checks
    that the expected file exists there and holds the original payload.
    """
    # gzip: ``<name>.gz`` is expected to be extracted to ``<name>``.
    with tempfile.TemporaryDirectory() as src_dir:
        gz_file = "gz_archive"
        gz_path = os.path.join(src_dir, gz_file + ".gz")
        content = b"test extract archive gzip"
        with gzip.open(gz_path, "wb") as f:
            f.write(content)
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(gz_path, dst_dir, overwrite=True)
            extracted = os.path.join(dst_dir, gz_file)
            assert os.path.exists(extracted)
            # Existence alone would also pass for an empty or corrupt file.
            with open(extracted, "rb") as f:
                assert f.read() == content

    # tar: the single member named ``tar_file`` is expected under ``dst_dir``.
    with tempfile.TemporaryDirectory() as src_dir:
        tar_file = "tar_archive"
        tar_path = os.path.join(src_dir, tar_file + ".tar")
        # str.encode() defaults to UTF-8.
        content = "test extract archive tar\n".encode()
        # Reuse ``tar_file`` so the member name and the asserted path cannot drift apart.
        info = tarfile.TarInfo(name=tar_file)
        info.size = len(content)
        with tarfile.open(tar_path, "w") as f:
            f.addfile(info, io.BytesIO(content))
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(tar_path, dst_dir, overwrite=True)
            extracted = os.path.join(dst_dir, tar_file)
            assert os.path.exists(extracted)
            with open(extracted, "rb") as f:
                assert f.read() == content
| 429 | |
| 430 | def _test_construct_graphs_node_ids(): |