(cfg_scale, top_k, top_p, temperature, class_label, seed)
| 51 | |
| 52 | |
def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
    """Sample a batch of four images for one class label and return them.

    The function installs a ``Sampler`` built from ``cfg_scale`` on the
    module-level ``llm``, generates one token sequence per prompt, decodes
    the sequences with ``vq_model.decode_code`` and converts the result to
    PIL images.

    Args:
        cfg_scale: Guidance scale. It is passed to ``Sampler`` and stored on
            ``args.cfg_scale``. When it is greater than 1.0, one extra prompt
            holding ``args.num_classes`` is appended per image (presumably
            the unconditional "null" class -- confirm against the model),
            and only the first half of the generated sequences is decoded.
        top_k: Forwarded to ``SamplingParams`` as ``top_k``.
        top_p: Forwarded to ``SamplingParams`` as ``top_p``.
        temperature: Forwarded to ``SamplingParams`` as ``temperature``.
        class_label: Token id used as the single-token prompt of every image.
        seed: Value given to ``torch.manual_seed`` right before generation.

    Returns:
        A list of ``PIL.Image`` objects, one per generated sample.
    """
    # Swap in a sampler configured for this request and mirror the scale on
    # the shared ``args`` namespace.
    guided_sampler = Sampler(cfg_scale)
    llm.llm_engine.model_executor.driver_worker.model_runner.model.sampler = guided_sampler
    args.cfg_scale = cfg_scale

    batch_size = 4
    use_guidance = cfg_scale > 1.0
    latent_size = image_size // args.downsample_size

    # Labels to condition the model with: the same class for every image.
    class_labels = [class_label] * batch_size
    qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]

    # One single-token prompt per image; with guidance enabled, an equal
    # number of ``args.num_classes`` prompts follows the conditional ones.
    prompt_token_ids = [[label] for label in class_labels]
    if use_guidance:
        prompt_token_ids += [[args.num_classes] for _ in class_labels]

    # One token is generated per latent grid cell.
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=latent_size ** 2,
    )

    sampling_start = time.time()
    torch.manual_seed(seed)
    outputs = llm.generate(
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        use_tqdm=False,
    )
    sampling_time = time.time() - sampling_start
    print(f"gpt sampling takes about {sampling_time:.2f} seconds.")

    generated_ids = [request.outputs[0].token_ids for request in outputs]
    index_sample = torch.tensor(generated_ids, device=device)
    if use_guidance:
        # Keep only the sequences that belong to the conditional prompts.
        index_sample = index_sample[:len(class_labels)]

    decoding_start = time.time()
    decoded = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]
    decoder_time = time.time() - decoding_start
    print(f"decoder takes about {decoder_time:.2f} seconds.")

    # Convert to PIL.Image format: rescale to [0, 255], move the second axis
    # last, and cast to uint8 on the CPU.
    pixels = decoded.mul(127.5)
    pixels.add_(128.0)
    pixels.clamp_(0, 255)
    frames = pixels.permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
    return [Image.fromarray(frame) for frame in frames]
| 91 | |
| 92 | |
| 93 | parser = argparse.ArgumentParser() |
nothing calls this directly
no test coverage detected