(self, input_images, n_new_frames, temperature=1.0, top_p=1.0)
| 101 | |
@torch.no_grad()
def generate_once(self, input_images, n_new_frames, temperature=1.0, top_p=1.0):
    """Sample ``n_new_frames`` frames that continue ``input_images``.

    The input frames are tokenized and the token sequence is truncated to
    the last ``self.context_frames - 1`` frames. New frames are then sampled
    from ``self.model`` in two phases. First, one ``generate`` call fills
    whatever room is left in the context window. After that, each call adds
    one more frame, using a sliding window of the most recent tokens.

    Args:
        input_images: ``np.ndarray`` laid out as (frames, height, width,
            channels). It is converted to float32 before tokenizing.
        n_new_frames: number of frames to generate; must be >= 1.
        temperature: sampling temperature forwarded to ``model.generate``.
        top_p: nucleus-sampling threshold forwarded to ``model.generate``.

    Returns:
        ``np.ndarray`` laid out as (frames, height, width, channels), with
        one entry per generated frame and values clamped to [0, 1]. The
        spatial size is whatever ``tokenizer.decode_code`` produces.

    Raises:
        TypeError: if ``input_images`` is not a numpy array.
        ValueError: if ``n_new_frames < 1`` or ``self.context_frames < 2``.
    """
    if not isinstance(input_images, np.ndarray):
        raise TypeError(
            f'input_images must be a numpy.ndarray, got {type(input_images).__name__}'
        )
    if n_new_frames < 1:
        # A zero-frame request previously sliced with `-0`, which returns the
        # whole prompt as "new" tokens.
        raise ValueError(f'n_new_frames must be >= 1, got {n_new_frames}')
    if self.context_frames < 2:
        # With fewer than 2 context frames there is no room for both a
        # prompt frame and a generated frame.
        raise ValueError(
            f'context_frames must be >= 2, got {self.context_frames}'
        )

    # Tokens per frame: generated tokens are reshaped into rows of this size
    # before decoding.
    tokens_per_frame = 256
    # Ids >= this value are suppressed during sampling; this value is also
    # used as the pad id.
    image_vocab_size = 8192
    # Prompt length kept in the window, leaving room for at least one new frame.
    context_tokens = (self.context_frames - 1) * tokens_per_frame

    with self.lock:
        frames = np.asarray(input_images, dtype=np.float32)
        # (b, h, w, c) -> (b, c, h, w); torch.tensor copies into a contiguous tensor.
        pixels = torch.tensor(frames.transpose(0, 3, 1, 2)).to(self.torch_device)

        _, input_ids = self.tokenizer.encode(pixels)
        input_ids = input_ids.view(1, -1)[:, -context_tokens:]

        # Build once: sampling is restricted to ids below image_vocab_size.
        suppress_tokens = list(range(image_vocab_size, self.model.vocab_size))

        def sample(prompt, n_frames):
            """Extend `prompt` by `n_frames` sampled frames; return the full sequence."""
            return self.model.generate(
                input_ids=prompt,
                attention_mask=torch.ones_like(prompt),
                pad_token_id=image_vocab_size,
                max_new_tokens=tokens_per_frame * n_frames,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                suppress_tokens=suppress_tokens,
            )

        # Phase 1: fill the remaining context window in a single call.
        # The prompt holds at most context_frames - 1 frames, so the
        # remaining room is always >= 1.
        room = self.context_frames - input_ids.shape[1] // tokens_per_frame
        first_batch = min(room, n_new_frames)
        sequence = sample(input_ids, first_batch)
        new_tokens = [sequence[:, -tokens_per_frame * first_batch:]]

        # Phase 2: slide the window forward one frame per call.
        for _ in range(n_new_frames - first_batch):
            sequence = sample(sequence[:, -context_tokens:], 1)
            new_tokens.append(sequence[:, -tokens_per_frame:])

        codes = torch.cat(new_tokens, dim=1).view(-1, tokens_per_frame)
        decoded = torch.clamp(self.tokenizer.decode_code(codes), 0.0, 1.0)
        # (b, c, h, w) -> (b, h, w, c) for the caller.
        return decoded.permute(0, 2, 3, 1).detach().cpu().numpy()
| 160 |
no test coverage detected