MCPcopy
hub / github.com/zai-org/CogView / filling_sequence

Function filling_sequence

generation/sampling.py:64–186  ·  view source on GitHub ↗

seq: [2, 3, 5, ..., -1 (to be generated), -N (N beams), -1]. context_length: the number of leading non-negative tokens in seq, i.e. everything before the first negative placeholder.

(
        model, 
        seq, 
        args, 
        mems=None, 
        invalid_slices=[], 
        **kwargs)

Source from the content-addressed store, hash-verified

62 return tokens, attention_mask, position_ids
63
64def filling_sequence(
65 model,
66 seq,
67 args,
68 mems=None,
69 invalid_slices=[],
70 **kwargs):
71 '''
72 seq: [2, 3, 5, ..., -1(to be generated), -N (N beams), -1]
73 context_length: first non(-1)s
74 '''
75 tokenizer = get_tokenizer()
76 device = seq.device
77 assert len(seq.shape) == 1
78 out_seq_length = len(seq)
79 # building the initial tokens, attention_mask, and position_ids
80 context_length = 0
81 offset = 100000
82
83 invalid_slices = [slice(0, tokenizer.img_tokenizer.num_tokens)]
84
85 while seq[context_length] >= 0:
86 # change what to generate
87 if seq[context_length] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]:
88 invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
89 elif seq[context_length] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]:
90 invalid_slices = [
91 slice(0, tokenizer.img_tokenizer.num_tokens),
92 slice(tokenizer.img_tokenizer.num_tokens + tokenizer.txt_tokenizer.num_tokens, None)]
93
94 if seq[context_length] == tokenizer['[ROI2]']:
95 offset = context_length
96 context_length += 1
97 tokens, attention_mask, position_ids = get_batch(seq[:context_length], device, args)
98
99 counter = context_length - 1 # == len(tokens) - 1
100 index = 0 # len(mems)
101 if mems is None:
102 mems = []
103 score = [0] # sum log likelihood for beams
104
105 if args.is_sparse == 2:
106 tokenizer = get_tokenizer()
107 img_txt_sep = tokenizer.img_tokenizer.num_tokens
108 img_indices_bool = (tokens < img_txt_sep)
109 txt_indices_bool = (~img_indices_bool)
110 elif args.is_sparse == 0:
111 txt_indices_bool = img_indices_bool = None
112 else:
113 raise ValueError('set is_sparse==2 for inference.')
114
115 while counter < (out_seq_length - 1):
116 # Now, we want to generate seq[counter + 1]
117 # token[:, index: counter+1] are just added.
118
119 if seq[counter + 1] in [tokenizer['[BOI1]'], tokenizer['[BOI2]']]:
120 invalid_slices = [slice(tokenizer.img_tokenizer.num_tokens, None)]
121 elif seq[counter + 1] in [tokenizer['[EOI1]'], tokenizer['[EOI2]']]:

Callers 2

generate_images_once — Function · 0.90
magnify — Function · 0.85

Calls 5

get_tokenizer — Function · 0.90
shrink_beams — Function · 0.85
top_k_logits — Function · 0.85
log — Method · 0.80
get_batch — Function · 0.70

Tested by

no test coverage detected