MCPcopy
hub / github.com/facebookresearch/mmf / setup_imports

Function setup_imports

tools/run.py:12–71  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

10
11
12def setup_imports():
13 # Automatically load all of the modules, so that
14 # they register with registry
15 root_folder = registry.get("pythia_root", no_warning=True)
16
17 if root_folder is None:
18 root_folder = os.path.dirname(os.path.abspath(__file__))
19 root_folder = os.path.join(root_folder, "..")
20
21 environment_pythia_path = os.environ.get("PYTHIA_PATH")
22
23 if environment_pythia_path is not None:
24 root_folder = environment_pythia_path
25
26 root_folder = os.path.join(root_folder, "pythia")
27 registry.register("pythia_path", root_folder)
28
29 trainer_folder = os.path.join(root_folder, "trainers")
30 trainer_pattern = os.path.join(trainer_folder, "**", "*.py")
31 tasks_folder = os.path.join(root_folder, "tasks")
32 tasks_pattern = os.path.join(tasks_folder, "**", "*.py")
33 model_folder = os.path.join(root_folder, "models")
34 model_pattern = os.path.join(model_folder, "**", "*.py")
35
36 importlib.import_module("pythia.common.meter")
37
38 files = glob.glob(tasks_pattern, recursive=True) + \
39 glob.glob(model_pattern, recursive=True) + \
40 glob.glob(trainer_pattern, recursive=True)
41
42 for f in files:
43 if f.endswith("task.py"):
44 splits = f.split(os.sep)
45 task_name = splits[-2]
46 if task_name == "tasks":
47 continue
48 file_name = splits[-1]
49 module_name = file_name[: file_name.find(".py")]
50 importlib.import_module("pythia.tasks." + task_name + "." + module_name)
51 elif f.find("models") != -1:
52 splits = f.split(os.sep)
53 file_name = splits[-1]
54 module_name = file_name[: file_name.find(".py")]
55 importlib.import_module("pythia.models." + module_name)
56 elif f.find("trainer") != -1:
57 splits = f.split(os.sep)
58 file_name = splits[-1]
59 module_name = file_name[: file_name.find(".py")]
60 importlib.import_module("pythia.trainers." + module_name)
61 elif f.endswith("builder.py"):
62 splits = f.split(os.sep)
63 task_name = splits[-3]
64 dataset_name = splits[-2]
65 if task_name == "tasks" or dataset_name == "tasks":
66 continue
67 file_name = splits[-1]
68 module_name = file_name[: file_name.find(".py")]
69 importlib.import_module(

Callers 1

runFunction · 0.85

Calls 2

getMethod · 0.80
registerMethod · 0.80

Tested by

no test coverage detected