(self, embed_dim, text_cfg, quick_gelu,
l0_module_text, mask_cfg=None)
| 681 | |
| 682 | class TextEncoder(nn.Module): |
def __init__(self, embed_dim, text_cfg, quick_gelu,
             l0_module_text, mask_cfg=None):
    """Build the text tower and, optionally, its L0 pruning module.

    Args:
        embed_dim: Number of output columns of ``text_projection``.
        text_cfg: Text-tower config. Reads ``context_length``, ``layers``,
            ``width``, ``heads``, ``vocab_size`` and ``teacher_width``.
        quick_gelu: If truthy, the transformer uses ``QuickGELU`` as its
            activation; otherwise ``nn.GELU``.
        l0_module_text: If truthy, attach an ``L0Module`` configured to
            prune the ``"hidden"``, ``"heads"`` and ``"intermediate"``
            dimensions of the text transformer.
        mask_cfg: Pruning config. Reads ``sparsity_warmup``, ``sparsity``
            and ``start_sparsity``. Required when ``l0_module_text`` is
            truthy; ignored otherwise.

    Raises:
        ValueError: If ``l0_module_text`` is truthy but ``mask_cfg`` is
            None.

    When ``text_cfg.layers`` is 0 the tower is disabled. ``transformer``,
    ``token_embedding`` and ``text_projection`` are set to None, and no
    ``vocab_size``, ``positional_embedding``, ``ln_final`` or
    ``attn_mask`` attributes are created.
    """
    super().__init__()

    # Fail fast with a clear message. Without this check, a missing
    # mask_cfg surfaces as an AttributeError on ``None.sparsity_warmup``,
    # and only after the whole transformer has already been built.
    if l0_module_text and mask_cfg is None:
        raise ValueError(
            'mask_cfg is required when l0_module_text is enabled')

    act_layer = QuickGELU if quick_gelu else nn.GELU
    self.context_length = text_cfg.context_length

    if text_cfg.layers > 0:
        # Construction order (transformer, embeddings, LayerNorm,
        # projection, mask) matches the original. This keeps module and
        # parameter registration order, and RNG consumption, unchanged.
        self.transformer = Transformer(
            width=text_cfg.width,
            layers=text_cfg.layers,
            heads=text_cfg.heads,
            act_layer=act_layer,
        )
        self.vocab_size = text_cfg.vocab_size
        self.token_embedding = nn.Embedding(
            text_cfg.vocab_size, text_cfg.width)
        self.positional_embedding = nn.Parameter(
            torch.empty(self.context_length, text_cfg.width))
        self.ln_final = LayerNorm(text_cfg.width)
        # With a positive teacher_width the projection is frozen
        # (requires_grad=False). Presumably it is meant to come from a
        # teacher model rather than be trained; confirm with the caller.
        freeze_projection = text_cfg.teacher_width > 0
        self.text_projection = nn.Parameter(
            torch.empty(text_cfg.width, embed_dim),
            requires_grad=not freeze_projection)
        # Non-persistent: the mask is rebuilt here and never saved in
        # the state_dict.
        self.register_buffer(
            'attn_mask', self.build_attention_mask(), persistent=False)
    else:
        self.transformer = None
        self.token_embedding = None
        self.text_projection = None

    # NOTE(review): presumably fills the tensors allocated with torch.empty
    # above. The body of init_parameters is not fully visible here; confirm.
    self.init_parameters()

    if l0_module_text:
        logging.info('use l0_module_text')
        config_mask = Namespace(
            hidden_size=text_cfg.width,
            # Hard-coded 4x MLP expansion. Assumes Transformer's default
            # MLP ratio is 4 (no mlp_ratio is passed above); TODO confirm.
            intermediate_size=4 * text_cfg.width,
            num_attention_heads=text_cfg.heads,
            num_hidden_layers=text_cfg.layers,
            sparsity_warmup=mask_cfg.sparsity_warmup,
            sparsity=mask_cfg.sparsity,
            start_sparsity=mask_cfg.start_sparsity,
        )
        self.l0_module = L0Module(
            config_mask,
            lagrangian_warmup=config_mask.sparsity_warmup,
            start_sparsity=config_mask.start_sparsity,
            target_sparsity=config_mask.sparsity,
            pruning_type=["hidden", "heads", "intermediate"],
        )
    else:
        self.l0_module = None

    # NOTE(review): not read in this method; presumably assigned elsewhere.
    self.mask = None
| 736 | |
| 737 | def init_parameters(self): |
| 738 | if self.transformer is not None: |
nothing calls this directly
no test coverage detected