| 9 | |
| 10 | |
| 11 | class L0Module(nn.Module): |
| 12 | limit_a, limit_b, epsilon = -.1, 1.1, 1e-6 |
| 13 | all_types = ["hidden_z", "heads_z", "mha_z", "intermediate_z", "ffn_z"] |
| 14 | |
def __init__(self, config,
             start_sparsity=0.0,
             target_sparsity=0.0,
             lagrangian_warmup=0,
             init_loga=0.5,
             temperature=2. / 3.,
             pruning_type=None,
             magical_number=0.8,  # from Wang et al. 2020
             ):
    """Set up the bookkeeping state for L0-regularized structured pruning.

    Args:
        config: Model config object. Only ``hidden_size``,
            ``intermediate_size``, ``num_attention_heads`` and
            ``num_hidden_layers`` are read here.
        start_sparsity: Stored unchanged as ``self.start_sparsity``.
        target_sparsity: Stored unchanged as ``self.target_sparsity``.
        lagrangian_warmup: Stored unchanged as ``self.lagrangian_warmup``.
        init_loga: Value from which ``self.loga_mean`` is derived. Anything
            that is not a ``float`` (e.g. the int ``0`` or ``None``) is
            silently replaced by 0.5. Float values ``>= 1 - epsilon`` or
            ``<= -epsilon`` make ``math.log`` raise ``ValueError``.
        temperature: Stored unchanged as ``self.temperature``.
        pruning_type: Sequence of type names. ``self.initialize_one_module``
            is called once per entry, in order. ``None`` (the default) means
            ``["hidden", "heads", "intermediate", "layer"]``, built as a
            fresh list for every instance.
        magical_number: Stored unchanged as ``self.magical_number``.
    """
    super(L0Module, self).__init__()

    # A list literal used directly as the default would be evaluated once and
    # shared by every instance (it is kept as ``self.pruning_type``), so a
    # mutation through one instance would leak into all later ones. Build a
    # fresh list per call instead. A caller-supplied list is used as-is.
    if pruning_type is None:
        pruning_type = ["hidden", "heads", "intermediate", "layer"]

    self.magical_number = magical_number
    self.lagrangian_warmup = lagrangian_warmup

    self.pruning_type = pruning_type
    self.start_sparsity = start_sparsity
    self.target_sparsity = target_sparsity
    self.temperature = temperature

    self.hidden_size = config.hidden_size
    self.intermediate_size = config.intermediate_size
    self.num_attention_heads = config.num_attention_heads
    self.dim_per_head = self.hidden_size // self.num_attention_heads
    self.num_hidden_layers = config.num_hidden_layers

    # Attention block: 4 * hidden^2 weights + 4 * hidden bias terms
    # (presumably the Q, K, V and output projections -- confirm against model).
    self.params_per_head_layer = self.hidden_size * \
        self.hidden_size * 4 + self.hidden_size * 4
    self.params_per_head = self.params_per_head_layer // self.num_attention_heads

    # FFN block: 2 * hidden * intermediate weights + hidden + intermediate
    # terms (presumably the two bias vectors).
    self.params_per_mlp_layer = self.hidden_size * self.intermediate_size * \
        2 + self.hidden_size + self.intermediate_size
    self.params_per_intermediate_dim = self.params_per_mlp_layer // self.intermediate_size

    # We ignore the parameters in normalization layers (they are a very
    # small amount).
    self.full_model_size = (
        self.params_per_head_layer + self.params_per_mlp_layer) * self.num_hidden_layers
    # Starts at 0 here. Presumably accumulated by initialize_one_module (verify).
    self.prunable_model_size = 0

    # Non-float values fall back to 0.5 (see docstring).
    init_loga = init_loga if isinstance(init_loga, float) else 0.5
    # log((1 - epsilon - p) / (p + epsilon)) with p = init_loga.
    self.loga_mean = math.log(
        1.0 - self.epsilon - init_loga) - math.log(init_loga + self.epsilon)

    # Per-type registries, filled in (presumably) by initialize_one_module.
    self.types = []
    self.z_logas = {}
    self.parameters_per_dim = {}
    self.sizes = {}
    self.shapes = {}

    self.hidden_loga = None
    self.hidden_type = None

    for t in pruning_type:
        self.initialize_one_module(t)
| 68 |
no outgoing calls
no test coverage detected