hub / github.com/facebookresearch/mmf / load

Method load

pythia/tasks/base_task.py:79–137
(self, **opts)
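
The method body reads three keys from `opts`: `dataset_attributes` (indexed directly, keyed by dataset name), `dataset_type` (defaults to `"train"`), and `dataset_size_proportional_sampling` (defaults to `None`). A minimal sketch of a matching set of keyword arguments follows. The dataset name `"vqa2"`, the attribute values, and the helper `read_load_options` are hypothetical and are not part of the repository.

```python
# Hypothetical options; the dataset name and attribute values are illustrative only.
opts = {
    "dataset_attributes": {"vqa2": {"data_root_dir": "../data"}},
    "dataset_type": "val",
    "dataset_size_proportional_sampling": True,
}


def read_load_options(opts, dataset):
    """Mirror the lookups that load() performs on its keyword arguments."""
    # load() indexes opts["dataset_attributes"] directly, so that key must exist.
    # The real method logs an error and calls sys.exit(1) when the dataset is
    # missing; this sketch raises KeyError instead so it is easy to test.
    if dataset not in opts["dataset_attributes"]:
        raise KeyError("Dataset %s is missing from dataset_attributes" % dataset)
    attributes = opts["dataset_attributes"][dataset]
    dataset_type = opts.get("dataset_type", "train")
    sampling = opts.get("dataset_size_proportional_sampling", None)
    return attributes, dataset_type, sampling


attributes, dataset_type, sampling = read_load_options(opts, "vqa2")
```

Because these values arrive through `**opts`, a caller would pass them as keyword arguments, for example `task.load(**opts)` on an already constructed task object.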

Source from the content-addressed store, hash-verified

        self.given_datasets = datasets

    def load(self, **opts):
        self.opts = opts
        self._process_datasets()

        self.datasets = []
        self.builders = []
        available_datasets = self._get_available_datasets()

        self.total_length = 0
        self.per_dataset_lengths = []
        self.num_datasets = 0

        for dataset in self.given_datasets:
            if dataset in available_datasets:
                builder_class = registry.get_builder_class(dataset)

                if builder_class is None:
                    print("No builder class found for %s." % dataset)
                    continue
                builder_instance = builder_class()

                if dataset in self.opts["dataset_attributes"]:
                    attributes = self.opts["dataset_attributes"][dataset]
                else:
                    self.writer.write(
                        "Dataset %s is missing from "
                        "dataset_attributes in config." % dataset,
                        "error",
                    )
                    sys.exit(1)

                dataset_type = self.opts.get("dataset_type", "train")
                builder_instance.build(dataset_type, attributes)
                dataset_instance = builder_instance.load(dataset_type, attributes)

                if dataset_instance is None:
                    continue

                self.builders.append(builder_instance)
                self.datasets.append(dataset_instance)
                self.per_dataset_lengths.append(len(dataset_instance))
                self.total_length += len(dataset_instance)
            else:
                print(
                    "Dataset %s is not a valid dataset for task %s. Skipping"
                    % (dataset, self.task_name)
                )

        self.num_datasets = len(self.datasets)
        self.dataset_probablities = [1 for _ in range(self.num_datasets)]
        sampling = self.opts.get("dataset_size_proportional_sampling", None)

        if sampling is True:
            self.dataset_probablities = self.per_dataset_lengths[:]
            self.dataset_probablities = [
                prob / self.total_length for prob in self.dataset_probablities
            ]
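
The last step of the listing assigns one sampling weight per loaded dataset: a constant 1 by default, or each dataset's share of the total length when `dataset_size_proportional_sampling` is exactly `True`. The sketch below isolates that arithmetic in a standalone function. The name `dataset_weights` is hypothetical, and drawing an index with `random.choices` is an assumption for illustration, since this excerpt does not show how the weights are consumed.

```python
import random


def dataset_weights(per_dataset_lengths, sampling=None):
    """Return one sampling weight per dataset, following load()'s two cases."""
    # Default case: every dataset gets weight 1 (the list is not normalized).
    weights = [1 for _ in per_dataset_lengths]
    # Like the listing, compare with `is True`, so truthy values such as 1 or
    # "true" do not enable proportional sampling.
    if sampling is True:
        total_length = sum(per_dataset_lengths)
        weights = [length / total_length for length in per_dataset_lengths]
    return weights


lengths = [300, 100]
uniform = dataset_weights(lengths)
proportional = dataset_weights(lengths, sampling=True)

# Assumption: pick a dataset index from the weights. random.choices accepts
# unnormalized weights, so it works for both cases.
rng = random.Random(0)
picked = rng.choices(range(len(lengths)), weights=proportional, k=1)[0]
```

With lengths 300 and 100, the proportional weights are 0.75 and 0.25, so the larger dataset would be drawn about three times as often under this scheme.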

Callers 15

run · Function · 0.45
extract_bert · Function · 0.45
test_caption_bleu4 · Method · 0.45
_get_config · Method · 0.45
__init__ · Method · 0.45
get_item · Method · 0.45
__init__ · Method · 0.45
process_answers.py · File · 0.45
_torch_load · Method · 0.45
__init__ · Method · 0.45
print_eval · Function · 0.45
train.py · File · 0.45

Calls 7

_process_datasets · Method · 0.95
change_dataset · Method · 0.95
get_builder_class · Method · 0.80
write · Method · 0.80
get · Method · 0.80
build · Method · 0.45

Tested by 2

test_caption_bleu4 · Method · 0.36
_get_config · Method · 0.36