| 123 | self._executing_order = self._graph.topological_sort() |
| 124 | |
def prepare_input(self, inputs: Union[dict, list, tuple, torch.Tensor, None]):
    """Normalize user-supplied inputs into a dict keyed by the graph's inputs.

    The output keys are the keys of ``self._graph.inputs`` (presumably the
    graph's input variable names -- NOTE(review): confirm against the graph
    class).

    Args:
        inputs: One of
            * ``None`` -- only valid when the graph takes no inputs;
            * a single ``torch.Tensor`` -- only valid when the graph takes
              exactly one input;
            * a list or tuple -- matched positionally against the iteration
              order of ``self._graph.inputs``;
            * a dict -- must have exactly the same keys as
              ``self._graph.inputs``; it is returned unchanged.

    Returns:
        ``None`` when the graph takes no inputs. Otherwise, a dict mapping
        each key of ``self._graph.inputs`` to the corresponding input value.

    Raises:
        AssertionError: if ``inputs`` has an unsupported type, or does not
            match the number or names of the graph's inputs.
        TypeError: for an unsupported type when assertions are disabled
            (``python -O``).
    """
    inputs_dictionary = self._graph.inputs

    # Handle the no-input graph first. The type assertion below rejects None,
    # which previously made the ``inputs is None`` path unreachable.
    if len(inputs_dictionary) == 0:
        assert inputs is None, 'Graph do not need any inputs. please set your inputs to be None.'
        return None

    # isinstance (rather than an exact type() match) also accepts subclasses
    # such as OrderedDict or torch.nn.Parameter.
    assert isinstance(inputs, (dict, list, tuple, torch.Tensor)), \
        f'Input format misunderstood. Expect either dict, list, tuple or tensor; while {type(inputs)} was given.'

    if isinstance(inputs, torch.Tensor):
        assert len(inputs_dictionary) == 1, \
            'Graph needs more than one input, while only one tensor was given.'
        return {next(iter(inputs_dictionary)): inputs}

    if isinstance(inputs, (list, tuple)):
        assert len(inputs_dictionary) == len(inputs), \
            f'Inputs format misunderstood. Given inputs has '\
            f'{len(inputs)} elements, while graph needs {len(inputs_dictionary)}'
        # Positional matching: the i-th value goes to the i-th graph input key.
        return dict(zip(inputs_dictionary, inputs))

    if isinstance(inputs, dict):
        assert len(inputs_dictionary) == len(inputs), \
            f'Inputs format misunderstood. Given inputs has '\
            f'{len(inputs)} elements, while graph needs {len(inputs_dictionary)}'
        # A same-sized dict with the wrong names would otherwise pass silently
        # and only fail later, when a graph input is looked up.
        missing = [key for key in inputs_dictionary if key not in inputs]
        unexpected = [key for key in inputs if key not in inputs_dictionary]
        assert not missing and not unexpected, \
            f'Inputs format misunderstood. Missing inputs: {missing}; unexpected inputs: {unexpected}.'
        return inputs

    # Only reachable when assertions are stripped (python -O).
    raise TypeError(f'Unsupported input type {type(inputs)}; expect dict, list, tuple or tensor.')
| 153 | |
| 154 | @ abstractmethod |
| 155 | def forward( |