MCPcopy
hub / github.com/huggingface/pytorch-image-models / forward_intermediates

Method forward_intermediates

timm/models/hgnet.py:568–608  ·  view source on GitHub ↗

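Forward pass that returns intermediate features. Args: x: Input image tensor; indices: Take last n blocks if int, all if None, select matching indices if sequence; norm: Apply norm layer to compatible intermediates; stop_early: Stop iterating over blocks when last desired intermediate hit; output_fmt: Shape of intermediate feature outputs; intermediates_only: Only return intermediate features.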

(
            self,
            x: torch.Tensor,
            indices: Optional[Union[int, List[int]]] = None,
            norm: bool = False,
            stop_early: bool = False,
            output_fmt: str = 'NCHW',
            intermediates_only: bool = False,
    )

Source from the content-addressed store, hash-verified

566 self.head.reset(num_classes, global_pool, device=device, dtype=dtype)
567
def forward_intermediates(
        self,
        x: torch.Tensor,
        indices: Optional[Union[int, List[int]]] = None,
        norm: bool = False,
        stop_early: bool = False,
        output_fmt: str = 'NCHW',
        intermediates_only: bool = False,
) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
    """ Run the stem and stages, collecting the outputs of the selected stages.

    Args:
        x: Input image tensor
        indices: Selects which stage outputs to collect; resolved by
            ``feature_take_indices`` against the number of stages. Take last n
            blocks if int, all if None, select matching indices if sequence
        norm: Not read by this method (presumably kept so the signature matches
            other models' ``forward_intermediates`` -- confirm)
        stop_early: When True (and not running under TorchScript), only stages up to
            and including the last selected index are executed
        output_fmt: Shape of intermediate feature outputs; only 'NCHW' is accepted
        intermediates_only: Only return intermediate features
    Returns:
        The list of collected stage outputs when ``intermediates_only`` is True,
        otherwise a tuple of (output of the last executed stage, collected list).
    """
    assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
    take_indices, max_index = feature_take_indices(len(self.stages), indices)

    # forward pass
    x = self.stem(x)

    # Condition kept in its original form: the scripting check comes first so the
    # slicing branch is only reached in eager mode.
    if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
        active_stages = self.stages
    else:
        # Drop every stage after the last one whose output was requested.
        active_stages = self.stages[:max_index + 1]

    collected = []
    for stage_idx, stage_module in enumerate(active_stages):
        x = stage_module(x)
        if stage_idx in take_indices:
            collected.append(x)

    if intermediates_only:
        return collected

    return x, collected
609
610 def prune_intermediate_layers(
611 self,

Callers

nothing calls this directly

Calls 1

feature_take_indicesFunction · 0.85

Tested by

no test coverage detected