| 10 | |
| 11 | |
| 12 | def setup_imports(): |
| 13 | # Automatically load all of the modules, so that |
| 14 | # they register with registry |
| 15 | root_folder = registry.get("pythia_root", no_warning=True) |
| 16 | |
| 17 | if root_folder is None: |
| 18 | root_folder = os.path.dirname(os.path.abspath(__file__)) |
| 19 | root_folder = os.path.join(root_folder, "..") |
| 20 | |
| 21 | environment_pythia_path = os.environ.get("PYTHIA_PATH") |
| 22 | |
| 23 | if environment_pythia_path is not None: |
| 24 | root_folder = environment_pythia_path |
| 25 | |
| 26 | root_folder = os.path.join(root_folder, "pythia") |
| 27 | registry.register("pythia_path", root_folder) |
| 28 | |
| 29 | trainer_folder = os.path.join(root_folder, "trainers") |
| 30 | trainer_pattern = os.path.join(trainer_folder, "**", "*.py") |
| 31 | tasks_folder = os.path.join(root_folder, "tasks") |
| 32 | tasks_pattern = os.path.join(tasks_folder, "**", "*.py") |
| 33 | model_folder = os.path.join(root_folder, "models") |
| 34 | model_pattern = os.path.join(model_folder, "**", "*.py") |
| 35 | |
| 36 | importlib.import_module("pythia.common.meter") |
| 37 | |
| 38 | files = glob.glob(tasks_pattern, recursive=True) + \ |
| 39 | glob.glob(model_pattern, recursive=True) + \ |
| 40 | glob.glob(trainer_pattern, recursive=True) |
| 41 | |
| 42 | for f in files: |
| 43 | if f.endswith("task.py"): |
| 44 | splits = f.split(os.sep) |
| 45 | task_name = splits[-2] |
| 46 | if task_name == "tasks": |
| 47 | continue |
| 48 | file_name = splits[-1] |
| 49 | module_name = file_name[: file_name.find(".py")] |
| 50 | importlib.import_module("pythia.tasks." + task_name + "." + module_name) |
| 51 | elif f.find("models") != -1: |
| 52 | splits = f.split(os.sep) |
| 53 | file_name = splits[-1] |
| 54 | module_name = file_name[: file_name.find(".py")] |
| 55 | importlib.import_module("pythia.models." + module_name) |
| 56 | elif f.find("trainer") != -1: |
| 57 | splits = f.split(os.sep) |
| 58 | file_name = splits[-1] |
| 59 | module_name = file_name[: file_name.find(".py")] |
| 60 | importlib.import_module("pythia.trainers." + module_name) |
| 61 | elif f.endswith("builder.py"): |
| 62 | splits = f.split(os.sep) |
| 63 | task_name = splits[-3] |
| 64 | dataset_name = splits[-2] |
| 65 | if task_name == "tasks" or dataset_name == "tasks": |
| 66 | continue |
| 67 | file_name = splits[-1] |
| 68 | module_name = file_name[: file_name.find(".py")] |
| 69 | importlib.import_module( |