github.com/zai-org/CogView

Function post_selection

generate_samples.py:246–270
(model, args, raw_text, seq, output_path)


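```python
import os
import time

import torch
import torch.distributed as dist

# get_tokenizer and inverse_prompt_score are defined elsewhere in the CogView
# codebase (see "Calls" below); their import paths are not shown on this page.


def post_selection(model, args, raw_text, seq, output_path):
    tokenizer = get_tokenizer()  # fetched but not used in this function body
    model.eval()
    # If several ranks share output_path, they can race between exists() and
    # makedirs(); os.makedirs(output_path, exist_ok=True) would avoid that.
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    with torch.no_grad():
        start_time = time.time()

        # Score the candidate sequences in chunks of max_inference_batch_size.
        num = seq.shape[0]
        mbz = args.max_inference_batch_size
        # Either all candidates fit in one chunk, or they split into equal chunks.
        assert num < mbz or num % mbz == 0
        scores = [inverse_prompt_score(model, seq[tim*mbz:(tim+1)*mbz], args)
                  for tim in range(max(num // mbz, 1))
                  ]
        scores = torch.cat(scores, dim=0)
        # scores = inverse_prompt_score(model, seq, args) # once

        print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
        print("\nContext:", raw_text, flush=True)
        # Each rank appends to its own file: the prompt, then its scores
        # on one tab-separated line.
        rank = dist.get_rank()
        output_file = os.path.join(output_path, f"scores_rank_{rank}.txt")
        with open(output_file, 'a') as fout:
            fout.write(raw_text+'\n')
            fout.write('\t'.join([str(x) for x in scores.tolist()])+'\n')
        print("\nSave to: ", output_file, flush=True)
```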
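The batching step is easy to misread. The minimal sketch below, in plain Python without PyTorch, reproduces the slice bounds that the list comprehension passes to `inverse_prompt_score`; `chunk_bounds` is a hypothetical helper for illustration, not part of CogView. It also shows why the assertion matters: with 12 candidates and a `max_inference_batch_size` of 8, `num // mbz` is 1, so without the check only the first 8 candidates would be scored and the last 4 would be silently skipped.

```python
from typing import List, Tuple


def chunk_bounds(num: int, mbz: int) -> List[Tuple[int, int]]:
    """Hypothetical helper: the (start, stop) rows each chunk covers.

    Mirrors range(max(num // mbz, 1)) and seq[tim*mbz:(tim+1)*mbz].
    """
    # Same precondition as the assert in post_selection.
    if not (num < mbz or num % mbz == 0):
        raise ValueError(f"{num} candidates cannot be split into chunks of {mbz}")
    bounds = []
    for tim in range(max(num // mbz, 1)):
        start = tim * mbz
        # Slicing past the end is harmless, so clamp to num to show
        # the rows that are actually selected.
        stop = min((tim + 1) * mbz, num)
        bounds.append((start, stop))
    return bounds


if __name__ == "__main__":
    print(chunk_bounds(3, 8))   # fewer candidates than the batch size
    print(chunk_bounds(16, 8))  # an exact multiple of the batch size
    try:
        chunk_bounds(12, 8)     # rejected, like the original assert
    except ValueError as err:
        print(err)
```

Running it prints `[(0, 3)]`, then `[(0, 8), (8, 16)]`, then the error message for the 12-candidate case.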
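Each rank appends to its own `scores_rank_{rank}.txt`: one line holding the prompt, then one line of tab-separated scores in the same order as the rows of `seq`. The sketch below reads such a file back, assuming the prompt itself contains no newline (otherwise the line pairing breaks). `read_scores` is a hypothetical helper, not part of CogView, and it deliberately does not decide whether a higher or lower score is better; that depends on how `inverse_prompt_score` is defined, which this page does not show.

```python
import os
import tempfile
from typing import List, Tuple


def read_scores(path: str) -> List[Tuple[str, List[float]]]:
    """Hypothetical parser for a scores_rank_<rank>.txt file.

    Each post_selection call appends two lines: the prompt, then its
    scores joined by tabs.
    """
    with open(path) as fin:
        lines = [line.rstrip("\n") for line in fin]
    if len(lines) % 2 != 0:
        raise ValueError(f"{path}: expected prompt/score line pairs")
    records = []
    for text, score_line in zip(lines[0::2], lines[1::2]):
        # An empty score line would mean no candidates were scored.
        scores = [float(x) for x in score_line.split("\t")] if score_line else []
        records.append((text, scores))
    return records


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "scores_rank_0.txt")
        # Write two records in the same layout post_selection uses.
        for text, scores in [("a cat", [0.5, -1.25]), ("a dog", [2.0])]:
            with open(path, "a") as fout:
                fout.write(text + "\n")
                fout.write("\t".join(str(x) for x in scores) + "\n")
        print(read_scores(path))
```

Since each rank writes a separate file, results from a multi-rank run would need to be read from every `scores_rank_*.txt` in `output_path`.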

Callers: 1

Calls: 2
  get_tokenizer (function)
  inverse_prompt_score (function)

Tested by: no test coverage detected