(self, zs)
| 264 | return numpified_zs |
| 265 | |
def calculate_model_size(self, zs):
    """Summarise how many prunable parameters remain under the masks ``zs``.

    Args:
        zs: Mask collection passed to ``self.get_z_from_zs``; ``None`` means
            "no masks available".

    Returns:
        ``{"pruned_sparsity": 0.0}`` when ``zs`` is ``None``. Otherwise a
        dict with per-layer ``mha`` / ``ffn`` flags (cast to int), the
        remaining hidden dimensions, the per-layer remaining intermediate
        units and heads, and the pruned / remaining parameter counts plus
        the resulting sparsity relative to ``self.prunable_model_size``.

    NOTE(review): assumes ``get_z_from_zs`` returns a mapping of NumPy
    arrays keyed by "hidden", "heads", "mha", "intermediate" and "ffn",
    and that "mha" / "ffn" hold one entry per layer -- confirm against
    that helper. Any missing key is treated as "nothing pruned".
    """
    if zs is None:
        return {"pruned_sparsity": 0.0}

    num_layers = self.num_hidden_layers
    num_heads = self.num_attention_heads
    intermediate_size = self.intermediate_size

    masks = self.get_z_from_zs(zs)

    hidden_z = (
        masks["hidden"] if "hidden" in masks else np.ones([self.hidden_size])
    )
    # Every mask is brought to a (layer, unit) or (layer, 1) layout so the
    # structure-level masks broadcast over their units regardless of how
    # many singleton axes the incoming arrays carry.
    heads_z = (
        masks["heads"].reshape(num_layers, num_heads)
        if "heads" in masks
        else np.ones([num_layers, num_heads])
    )
    mha_z = (
        masks["mha"].reshape(-1, 1)
        if "mha" in masks
        else np.ones([num_layers, 1])
    )
    # The fallback uses the configured intermediate size (not a hard-coded
    # 4 * hidden_size) so it always agrees with the per-layer layout used
    # for the summary below.
    intermediate_z = (
        masks["intermediate"].reshape(num_layers, intermediate_size)
        if "intermediate" in masks
        else np.ones([num_layers, intermediate_size])
    )
    ffn_z = (
        masks["ffn"].reshape(-1, 1)
        if "ffn" in masks
        else np.ones([num_layers, 1])
    )

    hidden_kept = hidden_z.sum()
    remain_hidden = hidden_kept.item()
    remain_intermediate = intermediate_z.sum(-1).tolist()
    remain_heads = heads_z.sum(-1).tolist()

    # sum(outer(a, b)) == sum(a) * sum(b); multiplying the two sums avoids
    # materialising the (layers * units) x hidden outer product.
    head_units = ((heads_z * mha_z).sum() * hidden_kept).item()
    intermediate_units = ((intermediate_z * ffn_z).sum() * hidden_kept).item()

    # Each (head, hidden dim) pair is weighted by dim_per_head * 4 and each
    # (intermediate unit, hidden dim) pair by 2.
    remain_model_size = head_units * self.dim_per_head * 4 + intermediate_units * 2
    pruned_model_size = self.prunable_model_size - remain_model_size

    return {
        'mha': mha_z.reshape(-1).astype(int).tolist(),
        'ffn': ffn_z.reshape(-1).astype(int).tolist(),
        'remain_hidden': remain_hidden,
        'remain_intermediate': remain_intermediate,
        'remain_heads': remain_heads,
        'pruned_params': pruned_model_size,
        'remain_params': remain_model_size,
        'pruned_sparsity': pruned_model_size / self.prunable_model_size,
    }
| 312 | |
| 313 | def forward(self, soft=True): |
| 314 | zs = {f"{t}_z": [] for t in self.types} |
no test coverage detected