MCPcopy
hub / github.com/zai-org/CogView / get_context

Function get_context

generate_samples.py:75–140  ·  view source on GitHub ↗
(args, query_template='{}')

Source from the content-addressed store, hash-verified

73 return seq
74
75def get_context(args, query_template='{}'):
76 tokenizer = get_tokenizer()
77 terminate_runs = 0
78 img_size = 256 if args.generation_task != 'low-level super-resolution' else 128
79 ml = max(args.max_position_embeddings, args.max_position_embeddings_finetune)
80 output_path = args.output_path
81
82 if args.input_source == 'interactive':
83 assert not args.with_id, '--with-id is only used with file inputs.'
84 if args.generation_task == 'post-selection':
85 raise ValueError('post-selection only takes file inputs!')
86 while True:
87 raw_text = input("\nPlease Input Query (stop to exit) >>> ")
88 if not raw_text:
89 print('Query should not be empty!')
90 continue
91 if raw_text == "stop":
92 return
93 try:
94 seq = _parse_and_to_tensor(raw_text, img_size=img_size, query_template=query_template)
95 except (ValueError, FileNotFoundError) as e:
96 print(e)
97 continue
98 if len(seq) > ml:
99 print("\nSeq length", len(seq),
100 f"\nPlease give smaller context than {ml}!")
101 continue
102 yield (raw_text, seq, output_path)
103 else:
104 with open(args.input_source, 'r') as fin:
105 inputs = fin.readlines()
106 for line_no, raw_text in enumerate(inputs):
107 if line_no % dist.get_world_size() != dist.get_rank():
108 continue
109 rk = dist.get_rank()
110 print(f'Working on No. {line_no} on {rk}... ')
111 raw_text = raw_text.strip()
112 if len(raw_text) == 0:
113 continue
114 if args.with_id: # with id
115 parts = raw_text.split('\t')
116 output_path = os.path.join(args.output_path, parts[0])
117 raw_text = '\t'.join(parts[1:])
118
119 if args.generation_task == 'post-selection':
120 parts = raw_text.split('\t')
121 seqs = []
122 for part in parts[1:]:
123 try:
124 seq_single = _parse_and_to_tensor('\t'.join([part, parts[0]]), img_size=img_size, query_template=query_template)
125 seqs.append(seq_single)
126 except (ValueError, FileNotFoundError) as e:
127 print(e)
128 continue
129 seq = torch.stack(seqs)
130 else:
131 try:
132 seq = _parse_and_to_tensor(raw_text, img_size=img_size, query_template=query_template)

Callers 1

Calls 2

get_tokenizer · Function · 0.90
_parse_and_to_tensor · Function · 0.85

Tested by

no test coverage detected