(self,
texts: List[str],
max_length: int = 256,
num_beams: int = 1,
do_sample: bool = True,
bad_words_ids: List[int] = None)
| 230 | return outputs |
| 231 | |
def generate_direct(self,
                    texts: List[str],
                    max_length: int = 256,
                    num_beams: int = 1,
                    do_sample: bool = True,
                    bad_words_ids: Optional[List[int]] = None):
    """Generate text for ``texts`` and convert the result to motion tokens.

    The prompts are tokenized (padded/truncated to ``self.max_length``),
    passed to ``self.language_model.generate``, decoded back to strings
    and finally handed to ``self.motion_string_to_token``.

    Args:
        texts: Prompt strings, one per batch element.
        max_length: Generation budget. Forwarded as ``max_length`` for
            ``lm_type == 'encdec'`` and as ``max_new_tokens`` for
            ``lm_type == 'dec'``. This is distinct from
            ``self.max_length``, which bounds the tokenized prompt.
        num_beams: Beam count; only forwarded in the ``'encdec'`` branch.
        do_sample: Whether to sample; forwarded in both branches.
        bad_words_ids: Forwarded unchanged to ``generate``; only used in
            the ``'encdec'`` branch.

    Returns:
        The ``(outputs_tokens, cleaned_text)`` pair produced by
        ``self.motion_string_to_token`` for the decoded strings.

    Raises:
        ValueError: If ``self.lm_type`` is neither ``'encdec'`` nor
            ``'dec'``.

    Side effects:
        Sets ``self.device``; for ``'dec'`` sets
        ``self.tokenizer.padding_side`` to ``'left'``; prints the first
        two prompts and decoded outputs.
    """
    # Fail early with a clear message instead of an UnboundLocalError on
    # ``outputs`` further down.
    if self.lm_type not in ('encdec', 'dec'):
        raise ValueError(
            f"Unsupported lm_type {self.lm_type!r}; "
            "expected 'encdec' or 'dec'.")

    # Device
    self.device = self.language_model.device

    # Tokenize
    if self.lm_type == 'dec':
        texts = [text + " \n " for text in texts]
        # Decoder-only generation continues from the last position of
        # each row, so padding must be on the left. This has to be set
        # BEFORE tokenizing, otherwise the first call encodes the batch
        # with whatever padding side the tokenizer had before.
        self.tokenizer.padding_side = 'left'

    source_encoding = self.tokenizer(texts,
                                     padding='max_length',
                                     max_length=self.max_length,
                                     truncation=True,
                                     return_attention_mask=True,
                                     add_special_tokens=True,
                                     return_tensors="pt")

    source_input_ids = source_encoding.input_ids.to(self.device)
    source_attention_mask = source_encoding.attention_mask.to(self.device)

    if self.lm_type == 'encdec':
        outputs = self.language_model.generate(
            source_input_ids,
            # Inputs are padded to a fixed length, so pass the mask
            # explicitly rather than relying on it being inferred.
            attention_mask=source_attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            do_sample=do_sample,
            bad_words_ids=bad_words_ids,
        )
    else:  # 'dec' (validated above)
        outputs = self.language_model.generate(
            input_ids=source_input_ids,
            attention_mask=source_attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            do_sample=do_sample,
            max_new_tokens=max_length)

    outputs_string = self.tokenizer.batch_decode(outputs,
                                                 skip_special_tokens=True)

    print(texts[:2])
    print(outputs_string[:2])

    outputs_tokens, cleaned_text = self.motion_string_to_token(
        outputs_string)

    return outputs_tokens, cleaned_text
| 284 | |
| 285 | def generate_conditional(self, |
| 286 | texts: Optional[List[str]] = None, |
no test coverage detected