Forward features that return intermediates. Args: x: input image tensor; indices: take the last n blocks if an int, all if None, or the matching indices if a sequence; norm: apply norm layer to compatible intermediates; stop_early: stop iterating over blocks when the last desired intermediate is hit.
(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
)
| 785 | self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity() |
| 786 | |
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Forward features that returns intermediates.

    Runs ``self.stem`` followed by ``self.stages`` in order, collecting the output of
    every stage whose index is selected by ``feature_take_indices``.

    Args:
        x: Input image tensor
        indices: Take last n blocks if int, all if None, select matching indices if sequence
        norm: Apply norm layer to compatible intermediates. NOTE: this flag is not read
            by this implementation; presumably accepted for API consistency with other
            models' ``forward_intermediates``.
        stop_early: Stop iterating over blocks when last desired intermediate hit
            (ignored under TorchScript, where the stage container cannot be sliced)
        output_fmt: Shape of intermediate feature outputs (only 'NCHW' is supported)
        intermediates_only: Only return intermediate features
    Returns:
        If ``intermediates_only`` is True, the list of intermediate features, one per
        selected stage index that was reached, each in NCHW layout.
        Otherwise a tuple ``(x, intermediates)``, where ``x`` is the output of the last
        stage executed, left in the layout that stage produced (the flattened token
        layout when ``self.use_conv`` is False).
    Raises:
        AssertionError: if ``output_fmt`` is not 'NCHW'.
    """
    assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
    intermediates = []
    take_indices, max_index = feature_take_indices(len(self.stages), indices)

    # forward pass
    x = self.stem(x)
    B, _, H, W = x.shape
    if not self.use_conv:
        # (B, C, H, W) -> (B, H*W, C): non-conv stages operate on a token sequence
        x = x.flatten(2).transpose(1, 2)

    if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
        stages = self.stages
    else:
        stages = self.stages[:max_index + 1]
    for feat_idx, stage in enumerate(stages):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint(stage, x)
        else:
            x = stage(x)
        if feat_idx in take_indices:
            if self.use_conv:
                intermediates.append(x)
            else:
                # Token sequence -> NCHW, using the grid size tracked for this stage.
                # Assumes the stage output is (B, H*W, C) -- NOTE(review): confirm.
                intermediates.append(x.reshape(B, H, W, -1).permute(0, 3, 1, 2))
        # Advance the tracked grid size after EVERY stage (not only the taken ones),
        # so the reshape above stays correct for any `indices` selection.
        # Ceil-halving assumes the following stage downsamples by stride 2 --
        # NOTE(review): confirm against the stage configuration.
        H = (H + 2 - 1) // 2
        W = (W + 2 - 1) // 2

    if intermediates_only:
        return intermediates

    return x, intermediates
| 839 | |
| 840 | def prune_intermediate_layers( |
| 841 | self, |
nothing calls this directly
no test coverage detected