| 197 | self.data_args = data_args |
| 198 | |
def prediction_step(
    self,
    model,
    inputs,
    prediction_loss_only: bool,
    ignore_keys=None,
):
    """Run a single evaluation step and return ``(loss, predictions, labels)``.

    Three paths are taken depending on the trainer configuration:

    * When only the loss is requested, or pipeline parallelism is enabled
      (``self.args.pipeline_parallel_degree > 1``), the parent
      implementation is called and its result is returned unchanged.
    * When ``self.do_generation`` is falsy, the parent implementation is
      called and its logits are reduced to their argmax over the last axis
      (``keepdim=True``) before being returned.
    * Otherwise ``model.generate`` is called with the sampling decode
      strategy. Pad-token entries are removed from every generated row, and
      the rows are then right-padded with ``-100`` to the longest row.

    Args:
        model: Model to evaluate; must expose ``eval()`` and ``generate()``
            on the generation path.
        inputs: Mapping holding ``"input_ids"`` and, optionally,
            ``"attention_mask"``, ``"position_ids"`` and ``"labels"``.
        prediction_loss_only: Forwarded to the parent; when true the parent
            result is returned as is.
        ignore_keys: Forwarded to the parent implementation.

    Returns:
        A ``(loss, predictions, labels)`` tuple. On the generation path
        ``loss`` is always ``None`` and ``labels`` is ``None`` when
        ``inputs`` has no ``"labels"`` entry.
    """
    # Guard clause: loss-only evaluation and pipeline-parallel runs are left to the parent.
    if prediction_loss_only or self.args.pipeline_parallel_degree > 1:
        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)

    if not self.do_generation:
        loss, logits, labels = super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
        if isinstance(logits, (list, tuple)):
            logits = logits[0]
        # With tensor_parallel_output enabled, collect the logits from every rank of the
        # model-parallel group and join them along the last axis.
        if self.args.tensor_parallel_degree > 1 and getattr(self.args, "tensor_parallel_output", False):
            mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
            logit_shards = []
            dist.all_gather(logit_shards, logits, group=mp_group)
            logits = paddle.concat(logit_shards, axis=-1)
        # Take the argmax here so the full logits never have to be gathered (too
        # memory-consuming); keepdim keeps the same number of dimensions as the logits.
        return (loss, logits.argmax(axis=-1, keepdim=True), labels)

    # Generation path: no loss is computed.
    loss = None

    model.eval()
    with paddle.no_grad():
        prompt_ids = inputs["input_ids"]

        attention_mask = None
        if "attention_mask" in inputs:
            attention_mask = inputs["attention_mask"]

        position_ids = None
        if "position_ids" in inputs:
            position_ids = inputs["position_ids"]

        # Budget left after the prompt, never less than one token.
        generation_budget = max(self.data_args.max_length - prompt_ids.shape[-1], 1)

        generated_tokens = model.generate(
            input_ids=prompt_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            max_length=generation_budget,
            decode_strategy="sampling",
            top_k=self.gen_args.top_k,
            top_p=self.gen_args.top_p,
            bos_token_id=self.tokenizer.bos_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
        )[0]

        # Drop pad-token entries from every row, then right-pad with -100 so all rows
        # share the length of the longest one and can be stacked into one tensor.
        stripped_rows = [
            row[row != self.tokenizer.pad_token_id].tolist() for row in generated_tokens
        ]
        longest = max(len(row) for row in stripped_rows)
        all_preds = paddle.to_tensor(
            [row + [-100] * (longest - len(row)) for row in stripped_rows]
        )

        all_labels = paddle.to_tensor(inputs["labels"]) if "labels" in inputs else None

    return (loss, all_preds, all_labels)
| 255 | |
| 256 | def log(self, logs: Dict[str, float], **kwargs) -> None: |