(self, config: PredictorArgument)
| 803 | self.used_list = [[] for _ in range(config.batch_size)] |
| 804 | |
| 805 | def init_inputs(self, config: PredictorArgument): |
| 806 | self.inputs = {} |
| 807 | |
| 808 | if config.export_precache: |
| 809 | self.inputs["src_mask"] = (self.pre_cache_mask - 1) * 1e4 |
| 810 | self.inputs["pre_ids"] = paddle.full([config.batch_size, self.total_max_length], -1, dtype="int64") |
| 811 | self.inputs["bad_tokens"] = paddle.to_tensor( |
| 812 | [ |
| 813 | -1, |
| 814 | ], |
| 815 | dtype="int64", |
| 816 | ) |
| 817 | self.inputs["penalty_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=1.0, dtype="float32") |
| 818 | self.inputs["frequency_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32") |
| 819 | self.inputs["presence_score"] = paddle.full(shape=[config.batch_size, 1], fill_value=0.0, dtype="float32") |
| 820 | |
| 821 | self.inputs["min_length"] = paddle.full( |
| 822 | shape=[config.batch_size, 1], fill_value=self.min_length, dtype="int64" |
| 823 | ) |
| 824 | self.inputs["max_length"] = paddle.full( |
| 825 | shape=[config.batch_size, 1], fill_value=config.max_length, dtype="int64" |
| 826 | ) |
| 827 | self.inputs["stop_nums"] = paddle.full(shape=[1], fill_value=config.batch_size, dtype="int64") |
| 828 | self.inputs["rope_emb"] = self._get_rotary_position_embedding( |
| 829 | paddle.arange(self.total_max_length).reshape((1, -1)), self.head_dim |
| 830 | ) |
| 831 | eos_token_id = get_eos_token_id(self.tokenizer, self.generation_config) |
| 832 | if isinstance(eos_token_id, int): |
| 833 | eos_token_id = [eos_token_id] |
| 834 | self.inputs["eos_token_id"] = paddle.to_tensor( |
| 835 | np.array(eos_token_id * config.batch_size).reshape(-1, 1).astype("int64") |
| 836 | ) |
| 837 | # bloom model needs src_mask and tgt_mask! |
| 838 | if "bloom" in self.architectures: |
| 839 | lower_one_tril = paddle.tril( |
| 840 | paddle.ones(shape=(self.total_max_length, self.total_max_length), dtype=self.dtype) |
| 841 | ) |
| 842 | lower_one_tril = lower_one_tril[None, None, :, :] |
| 843 | self.inputs["src_mask"] = lower_one_tril.tile([self.batch_size, 1, 1, 1]) |
| 844 | self.inputs["tgt_mask"] = paddle.full( |
| 845 | shape=[config.batch_size, 1, 1, self.total_max_length], fill_value=1, dtype=self.dtype |
| 846 | ) |
| 847 | arange_tensor_encoder = paddle.arange(self.total_max_length).astype(self.dtype) |
| 848 | alibi_slopes = get_alibi_slopes(self.num_attention_heads) |
| 849 | alibi = alibi_slopes[None, :, None, None] * arange_tensor_encoder |
| 850 | alibi_encoder = alibi.tile([self.batch_size, 1, self.total_max_length, 1]) |
| 851 | alibi_decoder = alibi.tile( |
| 852 | [ |
| 853 | self.batch_size, |
| 854 | 1, |
| 855 | 1, |
| 856 | 1, |
| 857 | ] |
| 858 | ) |
| 859 | # self.inputs["src_mask/tgt_mask"] is read only, will not be updated! |
| 860 | self.inputs["src_mask"] = ( |
| 861 | alibi_encoder + (1 - self.inputs["src_mask"]) * paddle.finfo(self.dtype).min |
| 862 | ).cast(self.dtype) |
no test coverage detected