| 39 | |
| 40 | |
def build_dataset_from_name(dataset, split=0):
    """Instantiate a dataset by registry name, falling back to a path lookup.

    Args:
        dataset: Dataset identifier, or a list whose first element is the
            identifier. Identifiers present in ``SUPPORTED_DATASETS`` are
            resolved through that registry; any other value is handed to
            ``build_dataset_from_path``.
        split: Value forwarded as ``split=`` to the dataset class constructor
            when that constructor declares a ``split`` parameter. A list is
            reduced to its first element.

    Returns:
        The constructed dataset object (or whatever non-``None`` value
        ``build_dataset_from_path`` returned for unregistered identifiers).

    Raises:
        NotImplementedError: If ``dataset`` is not a registered identifier
            and ``build_dataset_from_path`` returns ``None``.
    """
    # Only the first entry of a list argument is used.
    if isinstance(dataset, list):
        dataset = dataset[0]
    if isinstance(split, list):
        split = split[0]

    if dataset not in SUPPORTED_DATASETS:
        # Store the fallback result separately so the error message still
        # names the requested dataset instead of reporting "None".
        built = build_dataset_from_path(dataset)
        if built is not None:
            return built
        raise NotImplementedError(f"Failed to import {dataset} dataset.")

    # Registry values are dotted paths: everything before the last "." is the
    # module to import, the final component is the class name inside it.
    module_path, _, class_name = SUPPORTED_DATASETS[dataset].rpartition(".")
    module = importlib.import_module(module_path)
    dataset_class = getattr(module, class_name)

    # Pass ``split`` only to constructors that accept it.
    if "split" in inspect.signature(dataset_class.__init__).parameters:
        return dataset_class(split=split)
    return dataset_class()
| 61 | |
| 62 | |
| 63 | def build_dataset_pretrain(args): |