MCPcopy
hub / github.com/deepspeedai/DeepSpeed / BertLayer

Class BertLayer

tests/unit/modeling.py:432–468  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

430
431
class BertLayer(nn.Module):
    """Single layer that chains attention, intermediate and output sub-modules.

    The ``weights`` and ``biases`` collections passed to the constructor are
    forwarded to every sub-module and also kept on the layer itself so that
    ``forward`` can attach gradient-recording hooks to selected entries.
    """

    def __init__(self, i, config, weights, biases):
        """Build the three sub-modules and keep references to the parameters.

        Args:
            i: Passed straight through to ``BertAttention`` (presumably the
                layer index -- confirm against ``BertAttention``).
            config: Configuration object handed to every sub-module.
            weights: Indexable collection of weight tensors; ``forward``
                reads entries 2, 3, 5 and 6.
            biases: Indexable collection of bias tensors, indexed the same
                way as ``weights``.
        """
        super().__init__()
        # Construction/assignment order is kept as-is: attention,
        # intermediate, output, then the raw parameter collections.
        self.attention = BertAttention(i, config, weights, biases)
        self.intermediate = BertIntermediate(config, weights, biases)
        self.output = BertOutput(config, weights, biases)
        self.weight = weights
        self.biases = biases

    def forward(self, hidden_states, attention_mask, grads, collect_all_grads=False):
        """Run the layer and optionally register gradient-recording hooks.

        Args:
            hidden_states: Input handed to the attention sub-module.
            attention_mask: Mask handed to the attention sub-module.
            grads: List that the registered hooks append ``[grad, tag]``
                entries to; untouched when ``collect_all_grads`` is false.
            collect_all_grads: When true, register one hook per tracked
                tensor after the forward computation.

        Returns:
            The result of the output sub-module.
        """
        attn_out = self.attention(hidden_states, attention_mask)
        result = self.output(self.intermediate(attn_out), attn_out)

        if not collect_all_grads:
            return result

        def record(tensor, tag):
            # The hook itself returns None; it only logs [grad, tag].
            def hook(grad):
                grads.append([grad, tag])

            tensor.register_hook(hook)

        # NOTE(review): hooks are added on every call that sets
        # collect_all_grads and the returned handles are discarded, so
        # repeated calls look like they would stack hooks -- confirm that
        # callers only do this once per layer.
        # Hooks for entries 0 and 1 (tags "Q_W"/"Q_B" and "K_W"/"K_B") are
        # currently disabled.
        for idx, w_tag, b_tag in ((2, "V_W", "V_B"), (3, "O_W", "O_B")):
            record(self.weight[idx], w_tag)
            record(self.biases[idx], b_tag)

        attn_norm = self.attention.output.LayerNorm
        record(attn_norm.weight, "N2_W")
        record(attn_norm.bias, "N2_B")

        for idx, w_tag, b_tag in ((5, "int_W", "int_B"), (6, "out_W", "out_B")):
            record(self.weight[idx], w_tag)
            record(self.biases[idx], b_tag)

        out_norm = self.output.LayerNorm
        record(out_norm.weight, "norm_W")
        record(out_norm.bias, "norm_B")

        return result

    def get_w(self):
        """Return whatever the attention sub-module's ``get_w`` returns."""
        attention = self.attention
        return attention.get_w()
469
470
471class BertEncoder(nn.Module):

Callers 1

__init__Method · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…