Process batch and produce inputs for the model.
Signature: `process_batch(batch, args)`


def process_batch(batch, args):
    """Process batch and produce inputs for the model.

    Selects which fields of ``batch`` to broadcast from flags on ``args``,
    broadcasts them as ``torch.int64`` through ``mpu.broadcast_data`` and,
    when a ``"padding_mask"`` entry is present, converts it to a float CUDA
    tensor (half precision when ``args.fp16`` is set).

    Args:
        batch: Passed unchanged to ``mpu.broadcast_data``; presumably a
            mapping from field name to tensor -- confirm against callers.
        args: Namespace read for ``pretrained_bert``, ``cloze_eval``,
            ``fast_decode``, ``segment_length``, ``continuous_prompt``,
            ``variable_num_choices`` and ``fp16``.

    Returns:
        The mapping returned by ``mpu.broadcast_data``, with its
        ``"padding_mask"`` entry (if any) replaced by the converted tensor.
    """
    field_names = ["text", "label"]

    if args.pretrained_bert:
        field_names.extend(("padding_mask", "types"))
    else:
        field_names.extend(("mask", "position"))
        if args.cloze_eval:
            if args.fast_decode:
                field_names.extend(
                    ("dec_text", "dec_position", "dec_mask", "dec_target", "dec_logit_mask")
                )
            else:
                field_names.extend(("target", "logit_mask"))
                if args.segment_length > 0:
                    field_names.append("segment_id")
                if args.continuous_prompt:
                    field_names.append("prompt_pos")

    if args.variable_num_choices:
        field_names.append("loss_mask")

    # Every selected field is broadcast with a single integer dtype.
    broadcast = mpu.broadcast_data(field_names, batch, torch.int64)

    # Only the pretrained-BERT branch above requests "padding_mask"; it is
    # converted from the broadcast int64 to float (half under fp16) on CUDA.
    if "padding_mask" in broadcast:
        mask = broadcast["padding_mask"].float().cuda().contiguous()
        broadcast["padding_mask"] = mask.half() if args.fp16 else mask

    return broadcast


# Module-level tokenizer handle, initialised to None at import time.
# NOTE(review): presumably assigned later in this file (outside the visible
# chunk) -- confirm where it is set before it is used.
tokenizer = None
No test coverage detected.