MCPcopy
hub / github.com/OpenPPL/ppq / prepare_input

Method prepare_input

ppq/executor/base.py:125–152  ·  view source on GitHub ↗
(self, inputs: Union[dict, list, torch.Tensor])

Source from the content-addressed store, hash-verified

123 self._executing_order = self._graph.topological_sort()
124
def prepare_input(self, inputs: Union[dict, list, torch.Tensor]):
    """Normalize user-supplied inputs into a ``{input_name: value}`` dict.

    Accepted formats for ``inputs``:
      * ``None`` -- only when the graph declares no inputs; ``None`` is returned.
      * a single ``torch.Tensor`` -- only when the graph declares exactly one
        input; it is keyed by that input's name.
      * a ``list`` or ``tuple`` -- values are paired positionally with the
        names of ``self._graph.inputs``, in that mapping's iteration order.
      * a ``dict`` -- returned unchanged once its length matches the number
        of graph inputs (its keys are not validated here).

    Args:
        inputs: the user inputs in one of the formats above.

    Returns:
        A dict mapping graph input names to the given values, or ``None``
        for a graph without inputs.

    Raises:
        AssertionError: if ``inputs`` has an unsupported type, is not ``None``
            for an input-less graph, or its element count does not match the
            number of graph inputs.
    """
    inputs_dictionary = self._graph.inputs

    # Handle the input-less graph before the type check: that check rejects
    # None, which previously made this branch impossible to satisfy.
    if len(inputs_dictionary) == 0:
        assert inputs is None, 'Graph do not need any inputs. please set your inputs to be None.'
        return None

    # isinstance (rather than an exact type() match) also admits subclasses
    # such as collections.OrderedDict or torch.nn.Parameter.
    assert isinstance(inputs, (dict, list, tuple, torch.Tensor)), \
        f'Input format misunderstood. Expect either dict, list, tuple or tensor; while {type(inputs)} was given.'

    if isinstance(inputs, torch.Tensor):
        assert len(inputs_dictionary) == 1, \
            'Graph needs more than one input, while only one tensor was given.'
        return {next(iter(inputs_dictionary)): inputs}

    # Remaining formats (dict / list / tuple) must supply one value per graph input.
    assert len(inputs_dictionary) == len(inputs), \
        f'Inputs format misunderstood. Given inputs has ' \
        f'{len(inputs)} elements, while graph needs {len(inputs_dictionary)}'

    if isinstance(inputs, dict):
        return inputs
    # Positional formats: pair each value with the graph input at the same index.
    return dict(zip(inputs_dictionary, inputs))
153
154 @ abstractmethod
155 def forward(

Callers 1

__forward · Method · 0.80

Calls

no outgoing calls

Tested by

no test coverage detected