()
| 41 | ) |
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
    """Check that ``data.utils.extract_archive`` unpacks gzip and tar files.

    Each case writes a small archive into a temporary source directory,
    extracts it into a separate temporary destination directory, and then
    verifies that the expected output file exists and that its bytes match
    the payload that was archived. Checking the bytes, not just existence,
    catches extractions that produce an empty or corrupted file.
    """
    # gzip: the test expects a bare ``.gz`` file to be decompressed into a
    # file named after the archive with the ``.gz`` suffix stripped.
    with tempfile.TemporaryDirectory() as src_dir:
        gz_file = "gz_archive"
        gz_path = os.path.join(src_dir, gz_file + ".gz")
        content = b"test extract archive gzip"
        with gzip.open(gz_path, "wb") as f:
            f.write(content)
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(gz_path, dst_dir, overwrite=True)
            out_path = os.path.join(dst_dir, gz_file)
            assert os.path.exists(out_path)
            with open(out_path, "rb") as f:
                assert f.read() == content

    # tar: a single member is added from an in-memory buffer. Extraction is
    # expected to recreate it under the member name inside dst_dir.
    with tempfile.TemporaryDirectory() as src_dir:
        tar_file = "tar_archive"
        tar_path = os.path.join(src_dir, tar_file + ".tar")
        # default encode to utf8
        content = "test extract archive tar\n".encode()
        # Use the same name the assertion below checks, so the two cannot
        # silently drift apart.
        info = tarfile.TarInfo(name=tar_file)
        # TarInfo.size must be set explicitly; addfile reads exactly this many
        # bytes from the file object.
        info.size = len(content)
        with tarfile.open(tar_path, "w") as f:
            f.addfile(info, io.BytesIO(content))
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(tar_path, dst_dir, overwrite=True)
            out_path = os.path.join(dst_dir, tar_file)
            assert os.path.exists(out_path)
            with open(out_path, "rb") as f:
                assert f.read() == content
| 68 | |
| 69 | |
| 70 | @unittest.skipIf( |
no test coverage detected