Process batch and produce inputs for the model.
process_batch(batch, args)
| 232 | |
| 233 | |
def process_batch(batch, args):
    """Process batch and produce inputs for the model.

    Args:
        batch: Mapping that holds a ``'text'`` entry plus either the
            ``'attention_mask'`` / ``'position_id'`` entries or the
            alternative ``'mask'`` / ``'position'`` entries. The values are
            presumably torch tensors (they must provide ``long()``,
            ``cuda()``, ``dim()`` and ``squeeze()``) -- confirm against the
            data loader. When ``'mask'`` is present the mapping is modified
            in place: ``'mask'`` and ``'position'`` are removed and stored
            again under ``'attention_mask'`` and ``'position_id'``.
        args: Not used by this function; kept so the call signature stays
            the same for existing callers.

    Returns:
        A ``(tokens, attention_mask, position_ids)`` tuple built from the
        ``'text'``, ``'attention_mask'`` and ``'position_id'`` entries, each
        passed through ``.long().cuda()``. If ``tokens`` reports three
        dimensions, ``squeeze(1)`` is applied to all three values.

    Raises:
        KeyError: If ``'text'`` is missing, if ``'mask'`` is present without
            ``'position'``, or if neither key naming scheme is present.
    """
    if 'mask' in batch:
        # finetune SQuAD
        # These batches use 'mask'/'position'; rename them (in place) to the
        # key names read below.
        batch['attention_mask'] = batch.pop('mask')
        batch['position_id'] = batch.pop('position')
    tokens = batch['text'].long().cuda()
    attention_mask = batch['attention_mask'].long().cuda()
    position_ids = batch['position_id'].long().cuda()
    # Only the rank of ``tokens`` is checked; the mask and position ids are
    # assumed to carry the same extra axis -- TODO confirm.
    if tokens.dim() == 3:
        tokens = tokens.squeeze(1)
        attention_mask = attention_mask.squeeze(1)
        position_ids = position_ids.squeeze(1)
    return tokens, attention_mask, position_ids
| 248 | |
| 249 | |
| 250 | class DecoderEvaluater: |