hub / github.com/zai-org/CogView / get_batch

Function get_batch

generation/sampling.py:51–62
(context_tokens, device, args)

Source from the content-addressed store, hash-verified

49      return logits
50
51  def get_batch(context_tokens, device, args):
52      tokens = context_tokens
53      if len(tokens.shape) == 1:
54          tokens = tokens.unsqueeze(0).contiguous()
55      else:
56          tokens = tokens.view(tokens.shape[0], -1).contiguous()
57      tokens = tokens.to(device)
58
59      # Get the masks and position ids.
60      attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
61          tokens)
62      return tokens, attention_mask, position_ids
63
64  def filling_sequence(
65          model,
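
Usage sketch (illustrative, not part of the repository). As listed, get_batch normalizes its input to a 2-D (batch, sequence) tensor, moves it to the target device, and returns the tokens with an attention mask and position ids; the computed loss_mask is dropped and the args parameter is not read in the body. The snippet below mirrors that shape handling so it can be run on its own. The names get_batch_sketch and stub_masks_and_position_ids are hypothetical, and the stub's causal mask and 0..n-1 position ids are assumptions that stand in for the real get_masks_and_position_ids, whose behavior is not shown on this page.

```python
import torch


def stub_masks_and_position_ids(tokens):
    """Hypothetical stand-in for get_masks_and_position_ids.

    Assumption: a causal (lower-triangular) attention mask, an all-ones
    loss mask, and position ids 0..seq_len-1 for every row. The real
    CogView helper may differ.
    """
    batch, seq_len = tokens.shape
    attention_mask = torch.tril(
        torch.ones(seq_len, seq_len, device=tokens.device))
    loss_mask = torch.ones(batch, seq_len, device=tokens.device)
    position_ids = torch.arange(
        seq_len, device=tokens.device).unsqueeze(0).expand(batch, -1)
    return attention_mask, loss_mask, position_ids


def get_batch_sketch(context_tokens, device,
                     masks_fn=stub_masks_and_position_ids):
    """Same shape handling as get_batch, with the mask helper injected."""
    tokens = context_tokens
    if len(tokens.shape) == 1:
        # A single 1-D sequence becomes a batch of one.
        tokens = tokens.unsqueeze(0).contiguous()
    else:
        # Keep the batch axis and flatten every remaining axis into one.
        tokens = tokens.view(tokens.shape[0], -1).contiguous()
    tokens = tokens.to(device)
    attention_mask, _loss_mask, position_ids = masks_fn(tokens)
    return tokens, attention_mask, position_ids


device = torch.device("cpu")

# 1-D input: three token ids, result has shape (1, 3).
tokens_1d, mask_1d, pos_1d = get_batch_sketch(torch.tensor([5, 6, 7]), device)

# 3-D input of shape (2, 2, 3): the trailing axes flatten to shape (2, 6).
grid = torch.arange(12).reshape(2, 2, 3)
tokens_nd, mask_nd, pos_nd = get_batch_sketch(grid, device)

print(tokens_1d.shape, tokens_nd.shape)
```

Note that Tensor.view requires a compatible memory layout, so a non-contiguous multi-dimensional input (for example, the result of a transpose) would need .contiguous() or .reshape() before reaching this code path.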

Callers 2

filling_sequence · Function · 0.70
inverse_prompt_score · Function · 0.70

Calls 1

get_masks_and_position_ids · Function

Tested by

no test coverage detected