(
model: nn.Module, inputs: list, mode: str, **kwargs
)
| 126 | |
| 127 | |
def _wrapper_count_operators(
    model: nn.Module, inputs: list, mode: str, **kwargs
) -> typing.DefaultDict[str, float]:
    """Count operators of ``model`` on a single input, in the requested ``mode``.

    Args:
        model: the model to analyze. A ``DistributedDataParallel`` or
            ``DataParallel`` wrapper is unwrapped to its ``.module`` first.
        inputs: a list that must contain exactly one dict. Only its
            ``"image"`` entry is used; any other keys are dropped.
        mode: ``FLOPS_MODE`` to count with ``flop_count``, or
            ``ACTIVATIONS_MODE`` to count with ``activation_count``.
        **kwargs: forwarded to the counting function. A ``supported_ops``
            entry, if given, is merged over handlers that ignore every op
            listed in ``_IGNORED_OPS`` (caller-supplied handlers win).

    Returns:
        The counts produced by the counting function. If it returns a
        tuple, only the first element is returned.

    Raises:
        AssertionError: if ``inputs`` does not have exactly one element.
        NotImplementedError: if ``mode`` is not a supported mode.

    The model's original train/eval mode is restored before returning,
    including when counting raises.
    """
    # Ignore some ops: each one maps to a handler that contributes nothing.
    # Caller-provided handlers are applied afterwards and therefore override.
    supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
    supported_ops.update(kwargs.pop("supported_ops", {}))
    kwargs["supported_ops"] = supported_ops

    assert len(inputs) == 1, "Please use batch size=1"
    tensor_input = inputs[0]["image"]
    inputs = [{"image": tensor_input}]  # remove other keys, in case there are any

    old_train = model.training
    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
        model = model.module
    wrapper = TracingAdapter(model, inputs)
    # wrapper.eval() is expected to switch the wrapped model to eval mode,
    # which is why the caller's mode is restored in the finally clause below.
    wrapper.eval()
    try:
        if mode == FLOPS_MODE:
            ret = flop_count(wrapper, (tensor_input,), **kwargs)
        elif mode == ACTIVATIONS_MODE:
            ret = activation_count(wrapper, (tensor_input,), **kwargs)
        else:
            raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
    finally:
        # Restore even on failure so a counting error (or an unsupported
        # mode) does not leave the caller's model stuck in eval mode.
        model.train(old_train)
    # compatible with change in fvcore
    if isinstance(ret, tuple):
        ret = ret[0]
    return ret
| 156 | |
| 157 | |
| 158 | def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]: |
no test coverage detected