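Forward features that returns intermediates. Args: x: Input image tensor; indices: Take last n blocks if int, all if None, select matching indices if sequence; norm: Apply norm layer to compatible intermediates; stop_early: Stop iterating over blocks when last desired intermediate hit; output_fmt: Shape of intermediate feature outputs; intermediates_only: Only return intermediate features.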
(
self,
x: torch.Tensor,
indices: Optional[Union[int, List[int]]] = None,
norm: bool = False,
stop_early: bool = False,
output_fmt: str = 'NCHW',
intermediates_only: bool = False,
)
| 566 | self.head.reset(num_classes, global_pool, device=device, dtype=dtype) |
| 567 | |
| 568 | def forward_intermediates( |
| 569 | self, |
| 570 | x: torch.Tensor, |
| 571 | indices: Optional[Union[int, List[int]]] = None, |
| 572 | norm: bool = False, |
| 573 | stop_early: bool = False, |
| 574 | output_fmt: str = 'NCHW', |
| 575 | intermediates_only: bool = False, |
| 576 | ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]: |
| 577 | """ Forward features through the stem and stages, collecting selected stage outputs. |
| 578 | |
| 579 | Args: |
| 580 | x: Input image tensor |
| 581 | indices: Take last n blocks if int, all if None, select matching indices if sequence |
| 582 | norm: Apply norm layer to compatible intermediates (NOTE(review): not referenced anywhere in this method body) |
| 583 | stop_early: Stop iterating over stages after the last desired intermediate; has no effect when scripting |
| 584 | output_fmt: Shape of intermediate feature outputs; only 'NCHW' passes the assert below |
| 585 | intermediates_only: Only return intermediate features |
| 586 | Returns: |
| 587 | The list of collected stage outputs if intermediates_only is True, otherwise a tuple of (output of the last executed stage, list of collected stage outputs). When stop_early slices the stages, the first tuple element is the output of stage max_index, not of the final stage. |
| 588 | """ |
| 589 | assert output_fmt in ('NCHW',), 'Output shape must be NCHW.' |
| 590 | intermediates = [] |
| 591 | take_indices, max_index = feature_take_indices(len(self.stages), indices) |
| 592 | |
| 593 | # forward pass |
| 594 | x = self.stem(x) |
| 595 | if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript |
| 596 | stages = self.stages |
| 597 | else: |
| 598 | stages = self.stages[:max_index + 1] |
| 599 | |
| 600 | for feat_idx, stage in enumerate(stages): |
| 601 | x = stage(x) |
| 602 | if feat_idx in take_indices: |
| 603 | intermediates.append(x) |
| 604 | |
| 605 | if intermediates_only: |
| 606 | return intermediates |
| 607 | |
| 608 | return x, intermediates |
| 609 | |
| 610 | def prune_intermediate_layers( |
| 611 | self, |
nothing calls this directly
no test coverage detected