| 43 | |
| 44 | |
class BertModel(torch.nn.Module):
    """Wrapper module that owns a ``BertForPreTraining`` network as ``self.model``.

    The wrapped network is either loaded with
    ``BertForPreTraining.from_pretrained`` (when ``args.pretrained_bert`` is
    truthy) or built from scratch from a ``BertConfig`` assembled from ``args``.
    ``forward``, ``state_dict`` and ``load_state_dict`` all delegate straight
    to the wrapped network, so checkpoint keys are the wrapped network's own
    keys, without a ``model.`` prefix.
    """

    def __init__(self, args):
        """Load or construct the wrapped BERT pre-training network.

        Args:
            args: configuration namespace. Attributes read on the pretrained
                path: ``pretrained_bert``, ``tokenizer_model_type``,
                ``cache_dir``, ``fp32_layernorm``, ``fp32_embedding``,
                ``layernorm_epsilon``. The from-scratch path additionally reads
                ``intermediate_size``, ``hidden_size``, ``num_layers``,
                ``num_attention_heads``, ``hidden_dropout``,
                ``attention_dropout``, ``max_position_embeddings``,
                ``tokenizer_num_tokens``, ``tokenizer_num_type_tokens``,
                ``fp32_tokentypes`` and ``deep_init``.

        Note:
            ``self.config`` is only assigned on the from-scratch path; the
            pretrained path sets ``self.model`` alone.
        """
        super().__init__()

        if args.pretrained_bert:
            # tokenizer_model_type is passed as the first from_pretrained
            # argument (presumably the model name/path — confirm upstream).
            self.model = BertForPreTraining.from_pretrained(
                args.tokenizer_model_type,
                cache_dir=args.cache_dir,
                fp32_layernorm=args.fp32_layernorm,
                fp32_embedding=args.fp32_embedding,
                layernorm_epsilon=args.layernorm_epsilon)
            return

        # Feed-forward width defaults to four times the hidden size when unset.
        ffn_width = args.intermediate_size
        if ffn_width is None:
            ffn_width = 4 * args.hidden_size

        config_kwargs = dict(
            hidden_size=args.hidden_size,
            num_hidden_layers=args.num_layers,
            num_attention_heads=args.num_attention_heads,
            intermediate_size=ffn_width,
            hidden_dropout_prob=args.hidden_dropout,
            attention_probs_dropout_prob=args.attention_dropout,
            max_position_embeddings=args.max_position_embeddings,
            type_vocab_size=args.tokenizer_num_type_tokens,
            fp32_layernorm=args.fp32_layernorm,
            fp32_embedding=args.fp32_embedding,
            fp32_tokentypes=args.fp32_tokentypes,
            layernorm_epsilon=args.layernorm_epsilon,
            deep_init=args.deep_init,
        )
        # The vocabulary size is the only positional BertConfig argument.
        self.config = BertConfig(args.tokenizer_num_tokens, **config_kwargs)
        self.model = BertForPreTraining(self.config)

    def forward(self, input_tokens, token_type_ids=None,
                attention_mask=None, checkpoint_activations=False):
        """Run the wrapped network and return whatever it returns.

        The first three arguments are forwarded positionally;
        ``checkpoint_activations`` is forwarded as a keyword.
        """
        inner = self.model
        return inner(input_tokens, token_type_ids, attention_mask,
                     checkpoint_activations=checkpoint_activations)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the wrapped network's state dict (no ``model.`` key prefix)."""
        inner = self.model
        return inner.state_dict(destination=destination,
                                prefix=prefix,
                                keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        """Load ``state_dict`` directly into the wrapped network."""
        inner = self.model
        return inner.load_state_dict(state_dict, strict=strict)
| 90 | |