MCPcopy
hub / github.com/huggingface/pytorch-image-models / forward_intermediates

Method forward_intermediates

timm/models/levit.py:787–838  ·  view source on GitHub ↗

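Forward features that returns intermediates. Args: x: Input image tensor; indices: Take last n blocks if int, all if None, select matching indices if sequence; norm: Apply norm layer to compatible intermediates; stop_early: Stop iterating over blocks when last desired intermediate hit; output_fmt: Shape of intermediate feature outputs; intermediates_only: Only return intermediate features.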

(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    )

Source from the content-addressed store, hash-verified

785 self.num_features, num_classes, drop=self.drop_rate) if num_classes > 0 else nn.Identity()
786
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Run the stem and the stages, collecting selected per-stage outputs along the way.

    Args:
        x: Input image tensor; passed directly to ``self.stem``.
        indices: Which stage outputs to collect. An int keeps the last n stages, None keeps
            every stage, and a sequence keeps the listed stage indices (resolution is
            delegated to ``feature_take_indices``).
        norm: Accepted for interface compatibility with other models; this method never
            reads it, so no norm layer is applied to the collected features.
        stop_early: When True (and not running under TorchScript), only the stages up to
            the last requested index are executed.
        output_fmt: Layout of the collected features; only 'NCHW' is accepted.
        intermediates_only: When True, return just the list of collected features.

    Returns:
        If ``intermediates_only`` is True, the list of collected stage outputs. Otherwise a
        tuple ``(x, collected)`` where ``x`` is the output of the last executed stage,
        returned as-is (in non-conv mode it is NOT reshaped back to a spatial map), and
        ``collected`` is the list of selected stage outputs in (B, C, H, W) layout.

    Raises:
        AssertionError: If ``output_fmt`` is not 'NCHW'.
    """
    assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
    collected = []
    take_indices, max_index = feature_take_indices(len(self.stages), indices)

    # Stem first; in non-conv mode the (B, C, H, W) map is flattened into a
    # (B, H*W, C) token sequence, so remember the grid size to undo it later.
    x = self.stem(x)
    batch, _, grid_h, grid_w = x.shape
    if not self.use_conv:
        x = x.flatten(2).transpose(1, 2)

    # Slicing the stage container is not possible under TorchScript, so the
    # early-stop truncation is only applied in eager mode.
    if torch.jit.is_scripting() or not stop_early:
        stages = self.stages
    else:
        stages = self.stages[:max_index + 1]

    for stage_idx, stage in enumerate(stages):
        # The scripting check stays inline so TorchScript can drop the checkpoint branch.
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint(stage, x)
        else:
            x = stage(x)

        if stage_idx in take_indices:
            # Conv mode already yields spatial maps; token mode is reshaped
            # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W) with the tracked grid size.
            feat = x if self.use_conv else (
                x.reshape(batch, grid_h, grid_w, -1).permute(0, 3, 1, 2))
            collected.append(feat)

        # Ceil-halve the tracked grid after every stage: this is the size used to
        # reshape the NEXT stage's output. NOTE(review): assumes every stage after
        # the first downsamples by 2 (stride-2 with ceil rounding) — confirm against
        # how the stages are built.
        grid_h = (grid_h + 1) // 2
        grid_w = (grid_w + 1) // 2

    if not intermediates_only:
        return x, collected
    return collected
839
840 def prune_intermediate_layers(
841 self,

Callers

nothing calls this directly

Calls 2

feature_take_indices · Function · 0.85
checkpoint · Function · 0.85

Tested by

no test coverage detected