| 27 | |
| 28 | |
def flatten_inputs(model, inputs):
    """Arrange the values of ``inputs`` in ``model.forward``'s positional order.

    Args:
        model: Object exposing a ``forward`` callable whose signature is
            inspected with ``inspect.getfullargspec``.
        inputs: Mapping (anything with ``.get``) from parameter name to value.

    Returns:
        A tuple with one entry per positional parameter of ``model.forward``
        (a leading ``self``/``cls`` excluded), in declaration order. For each
        parameter the entry is ``inputs[name]`` when that value is present and
        not ``None``; otherwise the parameter's declared default, or ``None``
        when it has no default. Keys of ``inputs`` that are not positional
        parameters of ``forward`` are ignored, as are ``*args``, ``**kwargs``
        and keyword-only parameters.
    """
    spec = inspect.getfullargspec(model.forward)

    arg_names = list(spec.args or ())
    # getfullargspec still reports the already-bound first parameter of a
    # method, so drop it explicitly.
    if arg_names and arg_names[0] in ('self', 'cls'):
        del arg_names[0]

    # Defaults always belong to the trailing positional parameters.
    defaults = spec.defaults or ()
    names_with_default = arg_names[len(arg_names) - len(defaults):]
    fallback = dict(zip(names_with_default, defaults))

    flat_inputs = []
    for name in arg_names:
        value = inputs.get(name)
        # Identity check (not truthiness): values may be array-like objects,
        # and an explicit None is treated the same as a missing key.
        flat_inputs.append(value if value is not None else fallback.get(name))

    return tuple(flat_inputs)
| 47 | |
| 48 | |
| 49 | def count_flop(): |