MCPcopy
hub / github.com/FoundationVision/LlamaGen / infer

Function infer

app.py:53–90  ·  view source on GitHub ↗
(cfg_scale, top_k, top_p, temperature, class_label, seed)

Source from the content-addressed store, hash-verified

51
52
def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
    """Sample a small batch of images for one class label.

    Relies on module-level objects defined elsewhere in this file: ``llm``,
    ``args``, ``image_size``, ``vq_model`` and ``device``.

    Args:
        cfg_scale: Scale handed to ``Sampler`` and stored on ``args.cfg_scale``.
            When it is greater than 1.0, one extra prompt per image is added
            and only the first half of the generated outputs is decoded.
        top_k: Forwarded unchanged to ``SamplingParams``.
        top_p: Forwarded unchanged to ``SamplingParams``.
        temperature: Forwarded unchanged to ``SamplingParams``.
        class_label: Value used as the single-token prompt of every image.
        seed: Passed to ``torch.manual_seed`` right before generation.

    Returns:
        A list of images created with ``Image.fromarray``, one per decoded
        sample (the batch size is fixed at four).
    """
    # Build the sampler first, then install it on the engine's model.
    sampler = Sampler(cfg_scale)
    llm.llm_engine.model_executor.driver_worker.model_runner.model.sampler = sampler
    args.cfg_scale = cfg_scale

    batch_size = 4
    latent_size = image_size // args.downsample_size
    use_guidance = cfg_scale > 1.0

    # Every image in the batch is conditioned on the same label.
    class_labels = [class_label] * batch_size
    qzshape = [batch_size, args.codebook_embed_dim, latent_size, latent_size]

    # One single-token prompt per image.
    prompt_token_ids = [[label] for label in class_labels]
    if use_guidance:
        # Append a matching set of prompts that use ``args.num_classes`` as the
        # token id (presumably the "no class" token for guidance -- confirm).
        prompt_token_ids += [[args.num_classes] for _ in class_labels]

    # One token is generated per latent grid position.
    sampling_params = SamplingParams(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        max_tokens=latent_size ** 2,
    )

    # Seeding happens inside the timed region, as in the reported figure.
    sampling_start = time.time()
    torch.manual_seed(seed)
    outputs = llm.generate(
        prompt_token_ids=prompt_token_ids,
        sampling_params=sampling_params,
        use_tqdm=False,
    )
    sampling_time = time.time() - sampling_start
    print(f"gpt sampling takes about {sampling_time:.2f} seconds.")

    token_rows = [result.outputs[0].token_ids for result in outputs]
    index_sample = torch.tensor(token_rows, device=device)
    if use_guidance:
        # Keep only the rows that belong to the class-conditioned prompts.
        index_sample = index_sample[:batch_size]

    decode_start = time.time()
    decoded = vq_model.decode_code(index_sample, qzshape)  # output value is between [-1, 1]
    decoder_time = time.time() - decode_start
    print(f"decoder takes about {decoder_time:.2f} seconds.")

    # Convert to PIL.Image format: rescale to [0, 255], move channels last,
    # and hand uint8 arrays on the CPU to Image.fromarray.
    pixels = decoded.mul(127.5)
    pixels.add_(128.0).clamp_(0, 255)
    arrays = pixels.permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
    return list(map(Image.fromarray, arrays))
91
92
93parser = argparse.ArgumentParser()

Callers

nothing calls this directly

Calls 4

SamplerClass · 0.90
printFunction · 0.85
generateMethod · 0.80
decode_codeMethod · 0.45

Tested by

no test coverage detected