(self, embed_dim, vision_cfg, quick_gelu,
l0_module_image=False,
mask_cfg=None)
| 596 | |
| 597 | class ImageEncoder(nn.Module): |
def __init__(self, embed_dim, vision_cfg, quick_gelu,
             l0_module_image=False,
             mask_cfg=None):
    """Build the vision tower and, optionally, an L0 pruning module for it.

    The tower type is selected from ``vision_cfg``:

    * ``vision_cfg.timm_model_name`` is truthy -> ``TimmModel``;
    * ``vision_cfg.layers`` is a tuple/list    -> ``ModifiedResNet``;
    * otherwise                                -> ``VisualTransformer``.

    Args:
        embed_dim: Output embedding size handed to the tower (as
            ``embed_dim`` for timm, ``output_dim`` otherwise).
        vision_cfg: Vision config object. Attributes read here:
            ``timm_model_name``, ``timm_model_pretrained``, ``timm_pool``,
            ``timm_proj``, ``image_size``, ``layers``, ``width``,
            ``head_width``, ``patch_size``, ``mlp_ratio``, ``teacher_width``.
        quick_gelu: If true, the ``VisualTransformer`` tower is built with
            ``QuickGELU``; otherwise with ``nn.GELU``. Not used by the timm
            or ResNet towers.
        l0_module_image: If true, ``self.l0_module`` is set to an
            ``L0Module`` sized from ``vision_cfg``; otherwise it is ``None``.
        mask_cfg: Required when ``l0_module_image`` is true. Must provide
            ``sparsity_warmup``, ``sparsity`` and ``start_sparsity``.

    Raises:
        ValueError: If ``l0_module_image`` is true and either ``mask_cfg``
            is ``None`` or a timm tower is requested (no attention-head
            count is derived for timm towers, so the mask config cannot be
            built).
    """
    super().__init__()

    # Validate up front so a misconfiguration fails with a clear message
    # before any tower is constructed (previously: AttributeError on None /
    # NameError on the unbound vision_heads, after building the tower).
    if l0_module_image:
        if mask_cfg is None:
            raise ValueError('l0_module_image=True requires mask_cfg')
        if vision_cfg.timm_model_name:
            raise ValueError(
                'l0_module_image=True is not supported with a timm vision '
                'tower (vision_cfg.timm_model_name=%r)'
                % (vision_cfg.timm_model_name,))

    if vision_cfg.timm_model_name:
        self.visual = TimmModel(
            vision_cfg.timm_model_name,
            pretrained=vision_cfg.timm_model_pretrained,
            pool=vision_cfg.timm_pool,
            proj=vision_cfg.timm_proj,
            embed_dim=embed_dim,
            image_size=vision_cfg.image_size
        )
    elif isinstance(vision_cfg.layers, (tuple, list)):
        vision_heads = vision_cfg.width * 32 // vision_cfg.head_width
        self.visual = ModifiedResNet(
            layers=vision_cfg.layers,
            output_dim=embed_dim,
            heads=vision_heads,
            image_size=vision_cfg.image_size,
            width=vision_cfg.width
        )
    else:
        vision_heads = vision_cfg.width // vision_cfg.head_width
        # The activation choice only matters for the ViT tower.
        act_layer = QuickGELU if quick_gelu else nn.GELU
        self.visual = VisualTransformer(
            image_size=vision_cfg.image_size,
            patch_size=vision_cfg.patch_size,
            width=vision_cfg.width,
            layers=vision_cfg.layers,
            heads=vision_heads,
            mlp_ratio=vision_cfg.mlp_ratio,
            output_dim=embed_dim,
            act_layer=act_layer,
            teacher_width=vision_cfg.teacher_width,
        )
    # Runs once self.visual exists, whichever tower was built.
    self.init_parameters()

    if l0_module_image:
        logging.info('use l0_module_vision')
        # vision_heads is bound here: the timm case was rejected above.
        config_mask = Namespace(
            hidden_size=vision_cfg.width,
            # NOTE(review): hard-codes a 4x MLP expansion, while the ViT
            # tower is built with vision_cfg.mlp_ratio -- confirm the two
            # agree for every config used with L0 pruning.
            intermediate_size=4 * vision_cfg.width,
            num_attention_heads=vision_heads,
            # NOTE(review): for a ModifiedResNet tower this is a tuple/list,
            # not a layer count -- confirm L0Module supports that case.
            num_hidden_layers=vision_cfg.layers,
            sparsity_warmup=mask_cfg.sparsity_warmup,
            sparsity=mask_cfg.sparsity,
            start_sparsity=mask_cfg.start_sparsity,
        )
        self.l0_module = L0Module(
            config_mask,
            lagrangian_warmup=config_mask.sparsity_warmup,
            start_sparsity=config_mask.start_sparsity,
            target_sparsity=config_mask.sparsity,
            pruning_type=["hidden", "heads", "intermediate"])
    else:
        self.l0_module = None

    self.mask = None
| 654 | |
| 655 | def init_parameters(self): |
nothing calls this directly
no test coverage detected