| 1147 | ``` |
| 1148 | """ |
def __init__(self, config, num_choices):
    """Assemble the encoder, dropout layer and scoring head.

    Args:
        config: configuration object. It is forwarded as-is to the parent
            constructor and to ``BertModel``; this method itself reads
            ``hidden_dropout_prob`` and ``hidden_size`` from it.
        num_choices: stored unchanged on ``self.num_choices``
            (presumably the number of candidate answers per example --
            confirm against ``forward``).
    """
    super(BertForMultipleChoice, self).__init__(config)
    self.num_choices = num_choices

    # Encoder built from the same config object.
    self.bert = BertModel(config)

    dropout_rate = config.hidden_dropout_prob
    self.dropout = nn.Dropout(dropout_rate)

    # The linear head maps a hidden_size vector to a single output unit.
    score_dim = 1
    self.classifier = nn.Linear(config.hidden_size, score_dim)

    # Must come last: every submodule has to exist before
    # init_bert_weights (defined outside this view) is applied.
    self.apply(self.init_bert_weights)
| 1156 | |
| 1157 | def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None, checkpoint_activations=False): |
| 1158 | flat_input_ids = input_ids.view(-1, input_ids.size(-1)) |