| 430 | |
| 431 | |
class BertLayer(nn.Module):
    """One encoder layer: attention block, intermediate block, output block.

    When asked, the layer also registers gradient hooks that record the
    gradients of a fixed subset of its parameters into a caller-supplied list.
    """

    def __init__(self, i, config, weights, biases):
        """Build the layer's sub-modules.

        Args:
            i: Forwarded unchanged to ``BertAttention`` (presumably the layer
                index -- TODO confirm against ``BertAttention``).
            config: Configuration forwarded to every sub-module.
            weights: Indexable collection of weight tensors, also passed to the
                sub-modules; entries 2, 3, 5 and 6 are hooked by ``forward``.
            biases: Indexable collection of bias tensors, indexed like
                ``weights``.
        """
        super(BertLayer, self).__init__()
        self.attention = BertAttention(i, config, weights, biases)
        self.intermediate = BertIntermediate(config, weights, biases)
        self.output = BertOutput(config, weights, biases)
        # Attribute names (``weight`` vs ``biases``) kept as-is for callers.
        self.weight = weights
        self.biases = biases
        # Handles of the gradient hooks armed by the most recent
        # ``forward(..., collect_all_grads=True)`` call.
        self._grad_hook_handles = []

    def forward(self, hidden_states, attention_mask, grads, collect_all_grads=False):
        """Run attention -> intermediate -> output, optionally arming grad capture.

        Args:
            hidden_states: Input passed to the attention block.
            attention_mask: Mask passed to the attention block.
            grads: List that receives ``[gradient, label]`` pairs during the
                following backward pass(es); only used when
                ``collect_all_grads`` is true.
            collect_all_grads: If true, (re)register gradient hooks on the
                tracked parameters, replacing hooks from any earlier call.

        Returns:
            ``self.output(intermediate_output, attention_output)``.
        """
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)

        if collect_all_grads:
            self._register_grad_hooks(grads)

        return layer_output

    def _grad_hook_targets(self):
        """Return the ``(tensor, label)`` pairs whose gradients are collected."""
        # Entries 0/1 (labelled Q/K in earlier, commented-out code) are
        # intentionally not hooked; entry 4 is not referenced here either.
        return [
            (self.weight[2], "V_W"),
            (self.biases[2], "V_B"),
            (self.weight[3], "O_W"),
            (self.biases[3], "O_B"),
            (self.attention.output.LayerNorm.weight, "N2_W"),
            (self.attention.output.LayerNorm.bias, "N2_B"),
            (self.weight[5], "int_W"),
            (self.biases[5], "int_B"),
            (self.weight[6], "out_W"),
            (self.biases[6], "out_B"),
            (self.output.LayerNorm.weight, "norm_W"),
            (self.output.LayerNorm.bias, "norm_B"),
        ]

    def _register_grad_hooks(self, grads):
        """Hook each target so its gradient is appended to ``grads`` as ``[grad, label]``.

        Hooks from a previous call are removed first; otherwise every call
        would stack another hook on the same tensors, producing duplicate
        entries and keeping every earlier ``grads`` list alive and growing.
        """
        # getattr default tolerates instances created without __init__
        # (e.g. unpickled from an older version of this class).
        for handle in getattr(self, "_grad_hook_handles", ()):
            handle.remove()
        # ``label=label`` binds the current label (avoids late binding in the
        # loop); ``list.append`` returns None, so the gradient is not modified.
        self._grad_hook_handles = [
            tensor.register_hook(lambda grad, label=label: grads.append([grad, label]))
            for tensor, label in self._grad_hook_targets()
        ]

    def get_w(self):
        """Return whatever the attention block's ``get_w`` returns."""
        return self.attention.get_w()
| 469 | |
| 470 | |
| 471 | class BertEncoder(nn.Module): |
no outgoing calls
no test coverage detected
searching dependent graphs…