feature_map: (1, h, w, C) is the feature map of a single image.
(
feature_map: torch.Tensor,
img_size,
interpolation="bicubic",
return_pca_stats=False,
pca_stats=None,
skip_components: int = 0,
)
| 595 | |
| 596 | |
def get_pca_map(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
    skip_components: int = 0,
):
    """Project one image's feature map through a PCA matrix and resize it.

    Args:
        feature_map: Feature map of a single image, shaped ``(1, h, w, C)``.
            An unbatched ``(h, w, C)`` tensor is also accepted and gets a
            leading batch axis added.
        img_size: Output spatial size, passed unchanged as ``size`` to
            ``F.interpolate``.
        interpolation: Resampling mode, passed as ``mode`` to
            ``F.interpolate``.
        return_pca_stats: When true, also return the
            ``(reduct_mat, color_min, color_max)`` triple that was used, so
            it can be fed back in through ``pca_stats`` on a later call.
        pca_stats: Optional precomputed ``(reduct_mat, color_min, color_max)``
            triple. When ``None`` the triple is obtained from
            ``get_robust_pca`` on the flattened ``(h * w, C)`` features.
        skip_components: Forwarded to ``get_robust_pca`` as ``skip``; unused
            when ``pca_stats`` is supplied.

    Returns:
        A NumPy array shaped ``(H, W, K)`` with values in ``[0, 1]``, where
        ``(H, W)`` is the resized spatial size and ``K`` is the number of
        columns of ``reduct_mat``. When ``return_pca_stats`` is true, a tuple
        ``(array, (reduct_mat, color_min, color_max))`` is returned instead.

    Raises:
        ValueError: If ``feature_map`` is not a single-image feature map
            (rank other than 3 or 4, or a batch axis whose size is not 1).
    """
    # Decide on rank, not on shape[0]: an unbatched (h, w, C) map whose height
    # happens to be 1 must still receive its batch axis.
    if feature_map.dim() == 3:
        feature_map = feature_map[None]
    if feature_map.dim() != 4 or feature_map.shape[0] != 1:
        raise ValueError(
            "feature_map must have shape (1, h, w, C) or (h, w, C), "
            f"got {tuple(feature_map.shape)}"
        )

    if pca_stats is None:
        reduct_mat, color_min, color_max = get_robust_pca(
            feature_map.reshape(-1, feature_map.shape[-1]),
            skip=skip_components,
        )
    else:
        reduct_mat, color_min, color_max = pca_stats

    # Project the channels, then rescale to [0, 1] using the provided bounds.
    pca_color = feature_map @ reduct_mat
    pca_color = (pca_color - color_min) / (color_max - color_min)
    pca_color = pca_color.clamp(0, 1)

    # F.interpolate expects channels-first; move back to channels-last after.
    pca_color = F.interpolate(
        pca_color.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    # Bicubic resampling can overshoot the input range, so clamp again to keep
    # the documented [0, 1] range in the returned map.
    pca_color = pca_color.clamp(0, 1)

    # detach() so that inputs still attached to the autograd graph convert.
    pca_color = pca_color.detach().cpu().numpy().squeeze(0)
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color
| 629 | |
| 630 | |
| 631 | def get_scale_map( |
no test coverage detected
searching dependent graphs…