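Forward pass returning intermediate features. Args: x: Input image tensor. indices: Indices of features to return (0=stem_dct, 1-4=stages). None returns all. norm: Apply norm layer to final intermediate (unused, for API compat). stop_early: Stop iterating when last desired intermediate is reached. output_fmt: Output format, must be 'NCHW'. intermediates_only: Only return intermediate features. Returns: List of intermediate features or tuple of (final features, intermediates).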
(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
)
| 646 | return x |
| 647 | |
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """Run the network and collect selected intermediate feature maps.

    Feature levels are numbered from 0: level 0 is the output of
    ``self.stem_dct`` and level ``i`` (``i >= 1``) is the output of
    ``self.stages[i - 1]``, giving ``len(self.stages) + 1`` levels in total.

    Args:
        x: Input image tensor.
        indices: Which feature levels to collect (0=stem_dct, 1-4=stages),
            resolved via ``feature_take_indices``. ``None`` collects all.
        norm: Accepted for API compatibility only; this method never reads it.
        stop_early: When true (and not running under TorchScript), skip the
            stages that come after the deepest requested level.
        output_fmt: Must be ``'NCHW'``; any other value fails the assertion.
        intermediates_only: When true, return only the collected features.

    Returns:
        The list of collected features if ``intermediates_only`` is set,
        otherwise a ``(x, features)`` tuple where ``x`` is the tensor produced
        by the last module that ran (the stem if no stage ran).
    """
    assert output_fmt == 'NCHW', 'Output format must be NCHW.'
    collected = []
    num_levels = len(self.stages) + 1  # stem + one level per stage
    take_indices, max_index = feature_take_indices(num_levels, indices)

    # Level 0: the DCT stem.
    x = self.stem_dct(x)
    if 0 in take_indices:
        collected.append(x)

    # Truncate the stage list only in eager mode, and only when asked to.
    if stop_early and not torch.jit.is_scripting():
        # Levels 1..max_index correspond to stages[0:max_index].
        stages_to_run = [] if max_index <= 0 else self.stages[:max_index]
    else:
        stages_to_run = self.stages

    for stage_pos, stage in enumerate(stages_to_run):
        # Gradient checkpointing is bypassed when scripting.
        use_ckpt = self.grad_checkpointing and not torch.jit.is_scripting()
        x = checkpoint(stage, x) if use_ckpt else stage(x)
        level = stage_pos + 1  # shift past the stem, which is level 0
        if level in take_indices:
            collected.append(x)

    return collected if intermediates_only else (x, collected)
| 697 | |
| 698 | def prune_intermediate_layers( |
| 699 | self, |
nothing calls this directly
no test coverage detected