MCPcopy
hub / github.com/ytongbai/LVM / generate_once

Method generate_once

evaluation/vqlm_demo/inference.py:103–159  ·  view source on GitHub ↗
(self, input_images, n_new_frames, temperature=1.0, top_p=1.0)

Source from the content-addressed store, hash-verified

101
@torch.no_grad()
def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0):
    """Autoregressively sample ``n_new_frames`` images that continue ``input_images``.

    The prompt images are tokenized, trimmed to the most recent
    ``self.context_frames - 1`` frames of tokens, and extended with
    ``self.model.generate``. The first call fills whatever room is left in
    the context window; every later call produces exactly one frame, with
    the context re-trimmed after each call.

    Args:
        input_images: ``np.ndarray`` laid out as (batch, height, width,
            channels). It is cast to float32 before tokenization.
        n_new_frames: Number of frames to generate; must be at least 1.
        temperature: Sampling temperature forwarded to ``model.generate``.
        top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        ``np.ndarray`` laid out as (frames, height, width, channels) holding
        the decoded new frames, clamped to the range [0, 1].

    Raises:
        TypeError: If ``input_images`` is not a NumPy array.
        ValueError: If ``n_new_frames`` is smaller than 1.
    """
    if not isinstance(input_images, np.ndarray):
        raise TypeError(
            f"input_images must be a numpy.ndarray, got {type(input_images).__name__}"
        )
    if n_new_frames < 1:
        # With zero frames the slice ``[:, -0:]`` would select the whole
        # context, so the prompt itself would be decoded and returned.
        raise ValueError(f"n_new_frames must be >= 1, got {n_new_frames}")

    # Each frame is treated as a run of 256 token ids.
    tokens_per_frame = 256
    # Ids from 8192 upward are never sampled; 8192 also serves as the pad id.
    # NOTE(review): presumably 8192 is the visual codebook size -- confirm.
    first_reserved_id = 8192
    # NOTE(review): if context_frames == 1 this is 0 and ``[:, -0:]`` keeps
    # everything instead of nothing -- confirm that configuration is unused.
    max_context_tokens = (self.context_frames - 1) * tokens_per_frame

    # Hold the lock across tokenization, generation and decoding.
    with self.lock:
        frames = np.asarray(input_images, dtype=np.float32)
        # (b, h, w, c) -> (b, c, h, w)
        frames = torch.tensor(np.transpose(frames, (0, 3, 1, 2))).to(self.torch_device)

        _, input_ids = self.tokenizer.encode(frames)
        # Flatten all frames into one sequence and keep only the newest ones.
        input_ids = input_ids.view(1, -1)[:, -max_context_tokens:]

        # Loop-invariant sampling options, built once.
        sampling_kwargs = dict(
            pad_token_id=first_reserved_id,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            suppress_tokens=list(range(first_reserved_id, self.model.vocab_size)),
        )

        # The first call may produce several frames (up to the free room in
        # the context window); every following call produces one frame.
        current_context_frames = input_ids.shape[1] // tokens_per_frame
        first_generation_left = self.context_frames - current_context_frames
        first_new_frames = min(first_generation_left, n_new_frames)
        frames_per_call = [first_new_frames] + [1] * (n_new_frames - first_new_frames)

        new_tokens = []
        for n_frames in frames_per_call:
            n_tokens = tokens_per_frame * n_frames
            input_ids = self.model.generate(
                input_ids=input_ids,
                attention_mask=torch.ones_like(input_ids),
                max_new_tokens=n_tokens,
                **sampling_kwargs,
            )
            new_tokens.append(input_ids[:, -n_tokens:])
            # Slide the window so the next call again has room for one frame.
            input_ids = input_ids[:, -max_context_tokens:]

        new_tokens = torch.cat(new_tokens, dim=1).view(-1, tokens_per_frame)
        decoded = torch.clamp(self.tokenizer.decode_code(new_tokens), 0.0, 1.0)
        # (b, c, h, w) -> (b, h, w, c)
        new_images = decoded.permute(0, 2, 3, 1).detach().cpu().numpy()
    return new_images
160

Callers 2

__call__Method · 0.95
generate_onceMethod · 0.45

Calls 3

encodeMethod · 0.80
decode_codeMethod · 0.80
generateMethod · 0.45

Tested by

no test coverage detected