MCP · Copy · Index your code
hub / github.com/THUDM/GLM / get_batch

Function get_batch

generate_samples.py:62–84  ·  view source on GitHub ↗
(context_tokens, device, args)

Source retrieved from the content-addressed store (hash-verified).

60
61
def get_batch(context_tokens, device, args):
    """Reshape context tokens into a batch and build its attention mask and position ids.

    Args:
        context_tokens: token tensor, reshaped here to ``(args.batch_size, -1)``.
        device: device the tokens and any newly created tensors are placed on.
        args: namespace read for ``batch_size`` and ``block_lm``. On the block-LM
            path it is also read for ``no_block_position``. On the other path it is
            read for ``eod_token`` and ``mem_length``.

    Returns:
        A tuple ``(tokens, attention_mask, position_ids)``. On the block-LM path,
        ``attention_mask`` is a 1-element long tensor holding the sequence length.
        ``position_ids`` has shape ``(1, 2, seq_len)`` when block positions are
        enabled, and ``(1, seq_len)`` otherwise.
    """
    tokens = context_tokens.view(args.batch_size, -1).contiguous().to(device)
    seq_len = tokens.size(1)

    if not args.block_lm:
        # The helper also returns a loss mask, which sampling does not need.
        attention_mask, _, position_ids = get_masks_and_position_ids(
            tokens,
            args.eod_token,
            reset_position_ids=False,
            reset_attention_mask=False,
            set_loss_mask=False,
            mem_length=args.mem_length,
        )
        return tokens, attention_mask, position_ids

    # Block-LM path: the "mask" is just the sequence length as a 1-element tensor.
    attention_mask = torch.tensor([seq_len], device=device, dtype=torch.long)
    position_ids = torch.arange(seq_len, device=device, dtype=torch.long)
    if not args.no_block_position:
        # Add a second row of block-position ids, all zero across the context.
        block_position_ids = torch.zeros(seq_len, device=device, dtype=torch.long)
        position_ids = torch.stack((position_ids, block_position_ids), dim=0)
    # Leading singleton batch dimension.
    position_ids = position_ids.unsqueeze(0)
    return tokens, attention_mask, position_ids
85
86
87def sample_sequence(model, tokenizer, context_tokens, context_length, args, device, mems=None, end_tokens=None):

Callers 2

sample_sequenceFunction · 0.70
generate_samplesFunction · 0.70

Calls 1

Tested by

no test coverage detected