MCPcopy
hub / github.com/huggingface/pytorch-image-models / forward_intermediates

Method forward_intermediates

timm/models/csatv2.py:648–696  ·  view source on GitHub ↗

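Forward pass returning intermediate features. Args: x: Input image tensor. indices: Indices of features to return (0=stem_dct, 1-4=stages). None returns all. norm: Apply norm layer to final intermediate (unused, for API compat). stop_early: Stop iterating when last desired intermediate is reached. output_fmt: Output format, must be 'NCHW'. intermediates_only: Only return intermediate features.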

(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    )

Source from the content-addressed store, hash-verified

646 return x
647
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """Run the network and collect features from selected levels.

    Feature levels are numbered 0 for the ``stem_dct`` output and
    1..len(self.stages) for the output of each entry in ``self.stages``.

    Args:
        x: Input image tensor.
        indices: Feature levels to collect, resolved by ``feature_take_indices``
            over ``len(self.stages) + 1`` levels. ``None`` selects all levels.
        norm: Accepted for interface compatibility only; it is not read here.
        stop_early: If True (and not running under TorchScript), stages past
            the last requested level are not executed.
        output_fmt: Only 'NCHW' is accepted.
        intermediates_only: If True, return only the collected features.

    Returns:
        The list of collected features when ``intermediates_only`` is True,
        otherwise a tuple ``(x, features)`` where ``x`` is the output of the
        last module that was run (``stem_dct`` or the last executed stage).
    """
    assert output_fmt == 'NCHW', 'Output format must be NCHW.'
    # Level 0 is the DCT stem; levels 1..N map onto self.stages[0..N-1].
    level_count = len(self.stages) + 1
    take_indices, max_index = feature_take_indices(level_count, indices)

    features = []
    x = self.stem_dct(x)
    if 0 in take_indices:
        features.append(x)

    # Keep this condition shape: TorchScript must always see the full stage list.
    if torch.jit.is_scripting() or not stop_early:
        stages_to_run = self.stages
    else:
        # Level max_index is produced by stage max_index - 1, so run the first
        # max_index stages (none when only the stem level was requested).
        if max_index > 0:
            stages_to_run = self.stages[:max_index]
        else:
            stages_to_run = []

    for stage_idx, stage in enumerate(stages_to_run):
        level = stage_idx + 1  # shift by one because the stem occupies level 0
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint(stage, x)
        else:
            x = stage(x)
        if level in take_indices:
            features.append(x)

    if not intermediates_only:
        return x, features
    return features
697
698 def prune_intermediate_layers(
699 self,

Callers

nothing calls this directly

Calls 2

feature_take_indices · Function · 0.90
checkpoint · Function · 0.90

Tested by

no test coverage detected