MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / generate_direct

Method generate_direct

mGPT/archs/mgpt_lm.py:232–283  ·  view source on GitHub ↗
(self,
                        texts: List[str],
                        max_length: int = 256,
                        num_beams: int = 1,
                        do_sample: bool = True,
                        bad_words_ids: List[int] = None)

Source from the content-addressed store, hash-verified

230 return outputs
231
def generate_direct(self,
                    texts: List[str],
                    max_length: int = 256,
                    num_beams: int = 1,
                    do_sample: bool = True,
                    bad_words_ids: Optional[List[List[int]]] = None):
    """Generate from raw prompt strings and parse the output into motion tokens.

    The prompts are tokenized with ``self.tokenizer`` (padded and truncated
    to ``self.max_length``) and passed to ``self.language_model.generate``.
    The result is decoded back to strings and handed to
    ``self.motion_string_to_token``.

    Args:
        texts: Prompt strings, one per batch element.
        max_length: Generation length budget. The encoder-decoder model
            receives it as ``max_length``. The decoder-only model receives
            it as ``max_new_tokens``.
        num_beams: Beam count forwarded to ``generate`` (both model types).
        do_sample: Whether ``generate`` samples (both model types).
        bad_words_ids: Token-id sequences that ``generate`` must not emit,
            in the list-of-lists form Hugging Face expects. ``None`` applies
            no constraint.

    Returns:
        ``(outputs_tokens, cleaned_text)`` as produced by
        ``self.motion_string_to_token`` for the decoded strings.

    Raises:
        ValueError: If ``self.lm_type`` is neither ``'encdec'`` nor ``'dec'``.

    Side effects:
        Sets ``self.device`` to the language model's device. For the
        decoder-only model, leaves ``self.tokenizer.padding_side`` set to
        ``'left'``.
    """
    # Fail early with a clear message instead of an unbound `outputs` later.
    if self.lm_type not in ('encdec', 'dec'):
        raise ValueError(f"Unsupported lm_type: {self.lm_type!r}")

    # Device
    self.device = self.language_model.device

    if self.lm_type == 'dec':
        # Decoder-only prompts are terminated with a " \n " separator
        # (presumably marking where the continuation starts).
        texts = [text + " \n " for text in texts]
        # Batched decoder-only generation continues from the last position of
        # each row, so padding must sit on the left. This must be configured
        # before tokenizing, otherwise this batch uses the previous setting.
        self.tokenizer.padding_side = 'left'

    source_encoding = self.tokenizer(texts,
                                     padding='max_length',
                                     max_length=self.max_length,
                                     truncation=True,
                                     return_attention_mask=True,
                                     add_special_tokens=True,
                                     return_tensors="pt")

    source_input_ids = source_encoding.input_ids.to(self.device)
    source_attention_mask = source_encoding.attention_mask.to(self.device)

    if self.lm_type == 'encdec':
        outputs = self.language_model.generate(
            source_input_ids,
            # Pass the mask explicitly so padded positions are ignored rather
            # than relying on generate() to infer it from the pad token.
            attention_mask=source_attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            do_sample=do_sample,
            bad_words_ids=bad_words_ids,
        )
    else:  # self.lm_type == 'dec'
        outputs = self.language_model.generate(
            input_ids=source_input_ids,
            attention_mask=source_attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=max_length,
            num_beams=num_beams,
            do_sample=do_sample,
            bad_words_ids=bad_words_ids,
        )

    outputs_string = self.tokenizer.batch_decode(outputs,
                                                 skip_special_tokens=True)

    outputs_tokens, cleaned_text = self.motion_string_to_token(outputs_string)

    return outputs_tokens, cleaned_text
284
285 def generate_conditional(self,
286 texts: Optional[List[str]] = None,

Callers 2

generate_conditional (Method) · 0.95
forward (Method) · 0.80

Calls 2

to (Method) · 0.45

Tested by

no test coverage detected