Prune layers not required for specified intermediates.
(
self,
indices: Union[int, List[int]] = 1,
prune_norm: bool = False,
prune_head: bool = True,
)
| 1381 | return x, intermediates |
| 1382 | |
def prune_intermediate_layers(
        self,
        indices: Union[int, List[int]] = 1,
        prune_norm: bool = False,
        prune_head: bool = True,
):
    """Drop every layer that is not needed to produce the requested intermediates.

    Args:
        indices: Selection of intermediate outputs to keep. It is passed unchanged,
            together with ``len(self.stages)``, to ``feature_take_indices``. The exact
            int-vs-list semantics live in that helper (presumably an int means "last N"
            and a list means explicit indices — confirm against ``feature_take_indices``).
        prune_norm: Accepted for interface compatibility; this implementation does
            not read it.
        prune_head: If true, the classifier is removed via
            ``self.reset_classifier(0, '')``.

    Returns:
        The first value returned by ``feature_take_indices`` (the indices to take).
    """
    selected, last_needed = feature_take_indices(len(self.stages), indices)

    # Keep stages 0..last_needed inclusive; everything after is unused. The original
    # note says the stem counts as index 0 here — confirm against the model layout.
    keep_count = last_needed + 1
    self.stages = self.stages[:keep_count]

    if prune_head:
        self.reset_classifier(0, '')

    return selected
| 1396 | |
| 1397 | def forward_features(self, x: torch.Tensor) -> torch.Tensor: |
| 1398 | # input embedding |
nothing calls this directly
no test coverage detected