(model, args, raw_text, seq, output_path)
| 244 | save_image(imgs, output_file, normalize=True) |
| 245 | |
def post_selection(model, args, raw_text, seq, output_path):
    """Score candidate sequences for one prompt and append the scores to a per-rank file.

    Rows of ``seq`` are scored with ``inverse_prompt_score`` in chunks of at most
    ``args.max_inference_batch_size`` rows, under ``torch.no_grad()`` and with
    ``model`` switched to eval mode. The per-chunk scores are concatenated along
    dim 0. ``raw_text`` and the tab-separated scores are then appended as two
    lines to ``<output_path>/scores_rank_<rank>.txt``.

    Args:
        model: model forwarded to ``inverse_prompt_score``; ``.eval()`` is called on it.
        args: namespace providing ``max_inference_batch_size``; also forwarded to
            ``inverse_prompt_score``.
        raw_text: prompt text; printed and written to the score file.
        seq: batch indexed along dim 0 (presumably a tensor of token ids with
            shape ``(num, ...)`` — TODO confirm against callers).
        output_path: directory for the score file; created if it does not exist.

    Raises:
        ValueError: if ``args.max_inference_batch_size`` is not positive.
    """
    # NOTE(review): the returned tokenizer is not used in this function; the call
    # is kept only in case it initialises shared tokenizer state — confirm and remove.
    get_tokenizer()
    model.eval()
    # exist_ok=True avoids a FileExistsError race when several ranks try to
    # create the same directory at the same time (exists() + makedirs() is racy).
    os.makedirs(output_path, exist_ok=True)

    num = seq.shape[0]
    mbz = args.max_inference_batch_size
    if mbz <= 0:
        raise ValueError(f"max_inference_batch_size must be positive, got {mbz}")

    with torch.no_grad():
        start_time = time.time()
        # Walk seq in chunks of at most mbz rows. The last chunk may be shorter,
        # so num no longer has to be a multiple of mbz. max(num, 1) keeps the
        # previous behaviour of issuing one (empty) call when seq has no rows.
        scores = torch.cat(
            [
                inverse_prompt_score(model, seq[start:start + mbz], args)
                for start in range(0, max(num, 1), mbz)
            ],
            dim=0,
        )

        print("\nTaken time {:.2f}\n".format(time.time() - start_time), flush=True)
        print("\nContext:", raw_text, flush=True)
        # Fall back to rank 0 when torch.distributed is not initialised so the
        # function also works in single-process runs.
        if dist.is_available() and dist.is_initialized():
            rank = dist.get_rank()
        else:
            rank = 0
        output_file = os.path.join(output_path, f"scores_rank_{rank}.txt")
        # Explicit UTF-8: raw_text may contain non-ASCII characters, and the
        # platform default encoding is not guaranteed to be able to encode them.
        with open(output_file, 'a', encoding='utf-8') as fout:
            fout.write(raw_text + '\n')
            fout.write('\t'.join(str(x) for x in scores.tolist()) + '\n')
        print("\nSave to: ", output_file, flush=True)
| 271 | |
| 272 | |
| 273 |
no test coverage detected