Repository: github.com/THUDM/GLM

Function process_batch

finetune_glm.py, lines 32–59

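Process a batch and produce inputs for the model. The function selects the tensor keys required by the current configuration (BERT or GLM backbone, cloze evaluation, fast decoding, segment ids, continuous prompts, variable numbers of choices), broadcasts those tensors across the model-parallel group as torch.int64 with mpu.broadcast_data, and, if a padding_mask is present, converts it to a floating-point CUDA attention mask (half precision when args.fp16 is set).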

Signature: process_batch(batch, args)


```python
import torch

import mpu  # GLM's model-parallel utilities package


def process_batch(batch, args):
    """Process batch and produce inputs for the model."""
    keys = ["text", "label"]
    if args.pretrained_bert:
        # BERT-style backbone: padding mask and token-type ids.
        keys += ["padding_mask", "types"]
    else:
        # GLM backbone: GLM-specific mask and position inputs.
        keys += ["mask", "position"]
        if args.cloze_eval:
            if args.fast_decode:
                # Fast decoding feeds separate decoder-side tensors.
                keys += ["dec_text", "dec_position", "dec_mask", "dec_target", "dec_logit_mask"]
            else:
                keys += ["target", "logit_mask"]
                if args.segment_length > 0:
                    keys += ["segment_id"]
                if args.continuous_prompt:
                    keys += ["prompt_pos"]
    if args.variable_num_choices:
        # Only requested when the number of choices varies between samples.
        keys.append("loss_mask")
    # Broadcast data: send the tensors for `keys` from the source rank of the
    # model-parallel group to the other ranks; every tensor is sent as int64.
    datatype = torch.int64
    data_b = mpu.broadcast_data(keys, batch, datatype)

    if "padding_mask" in data_b:
        # The mask arrives as int64: turn it into a floating-point CUDA
        # attention mask, in half precision when training with fp16.
        attention_mask = data_b['padding_mask'].float().cuda().contiguous()
        if args.fp16:
            attention_mask = attention_mask.half()
        data_b["padding_mask"] = attention_mask
    return data_b
```
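process_batch relies on mpu.broadcast_data to share one batch across the model-parallel group. Assuming GLM's `mpu` follows Megatron-LM's implementation of that function (its source is not shown on this page), the source rank records each key's shape, concatenates the tensors for all keys into a single flat buffer of the given dtype, and broadcasts it; every rank then slices the buffer back into per-key tensors in key order. Because the whole buffer shares one dtype, `padding_mask` arrives as int64, which is why process_batch converts it to float afterwards. The stdlib-only sketch below imitates only that pack/unpack bookkeeping, with nested lists standing in for tensors and no actual communication; `pack`, `unpack`, `flatten`, `reshape` and `shape_of` are hypothetical names.

```python
import math
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def shape_of(value) -> Shape:
    """Shape of a rectangular, non-empty nested list (a stand-in for a tensor)."""
    dims = []
    while isinstance(value, list):
        dims.append(len(value))
        value = value[0]
    return tuple(dims)


def flatten(value) -> List[int]:
    """Row-major flattening, analogous to tensor.contiguous().view(-1)."""
    if not isinstance(value, list):
        return [value]
    flat: List[int] = []
    for item in value:
        flat.extend(flatten(item))
    return flat


def reshape(flat: List[int], shape: Shape):
    """Rebuild a nested list of the given shape from a flat row-major list."""
    if not shape:
        return flat[0]
    step = len(flat) // shape[0]
    return [reshape(flat[i * step:(i + 1) * step], shape[1:]) for i in range(shape[0])]


def pack(keys: List[str], data: Dict[str, list]) -> Tuple[Dict[str, Shape], List[int]]:
    """Source-rank side (hypothetical): record shapes, concatenate values into one buffer."""
    sizes: Dict[str, Shape] = {}
    buffer: List[int] = []
    for key in keys:
        flat = flatten(data[key])
        # Every value must share one integer dtype, mirroring datatype=torch.int64.
        if not all(isinstance(v, int) for v in flat):
            raise TypeError(f"{key!r} contains non-integer values")
        sizes[key] = shape_of(data[key])
        buffer.extend(flat)
    return sizes, buffer


def unpack(keys: List[str], sizes: Dict[str, Shape], buffer: List[int]) -> Dict[str, list]:
    """Receiving side (hypothetical): slice the shared buffer back into per-key values."""
    out: Dict[str, list] = {}
    offset = 0
    for key in keys:
        numel = math.prod(sizes[key])
        out[key] = reshape(buffer[offset:offset + numel], sizes[key])
        offset += numel
    return out


# Usage: `buffer` is what a real broadcast would send; here it is reused directly.
keys = ["text", "label", "padding_mask"]
batch = {
    "text": [[101, 7, 102], [101, 9, 102]],
    "label": [0, 1],
    "padding_mask": [[1, 1, 1], [1, 1, 0]],
}
sizes, buffer = pack(keys, batch)
assert unpack(keys, sizes, buffer) == batch
```

Sending one flat buffer costs a single collective call instead of one per key, at the price of forcing every tensor into the same dtype.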

Callers (2)

multichoice_evaluate (function, 0.90)
finetune_forward_step (function, 0.70)

Calls (1)

append (method, 0.80)

Tested by

no test coverage detected
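The key-selection branch is the part of process_batch that can be checked without a GPU or an initialized model-parallel group. A minimal test sketch, assuming the branching shown in the listing above: `required_keys` and `make_args` are hypothetical helpers (not part of the GLM repo) that copy that logic and substitute `types.SimpleNamespace` for the parsed `args`, with every flag off by default.

```python
from types import SimpleNamespace
from typing import List


def required_keys(args) -> List[str]:
    """Hypothetical helper: the key list process_batch builds before broadcasting."""
    keys = ["text", "label"]
    if args.pretrained_bert:
        keys += ["padding_mask", "types"]
    else:
        keys += ["mask", "position"]
        if args.cloze_eval:
            if args.fast_decode:
                keys += ["dec_text", "dec_position", "dec_mask", "dec_target", "dec_logit_mask"]
            else:
                keys += ["target", "logit_mask"]
                if args.segment_length > 0:
                    keys += ["segment_id"]
                if args.continuous_prompt:
                    keys += ["prompt_pos"]
    if args.variable_num_choices:
        keys.append("loss_mask")
    return keys


def make_args(**overrides) -> SimpleNamespace:
    """Hypothetical stand-in for the parsed arguments; every flag defaults to off."""
    flags = dict(
        pretrained_bert=False,
        cloze_eval=False,
        fast_decode=False,
        segment_length=0,
        continuous_prompt=False,
        variable_num_choices=False,
    )
    flags.update(overrides)
    return SimpleNamespace(**flags)


# Usage: GLM backbone, cloze evaluation with continuous prompts.
print(required_keys(make_args(cloze_eval=True, continuous_prompt=True)))
# prints ['text', 'label', 'mask', 'position', 'target', 'logit_mask', 'prompt_pos']
```

Factoring the key list into its own function in this way would let the real selection logic be tested on a CPU; the broadcast and the CUDA mask conversion would still need a GPU test.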