r""" :class:`transformers.BeamScorer` implementing standard beam search decoding. Adapted in part from `Facebook's XLM beam search code `__. Args: batch_size
| 139 | |
| 140 | |
| 141 | class BeamSearchScorer(BeamScorer): |
| 142 | r""" |
| 143 | :class:`transformers.BeamScorer` implementing standard beam search decoding. |
| 144 | |
| 145 | Adapted in part from `Facebook's XLM beam search code |
| 146 | <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__. |
| 147 | |
| 148 | Args: |
| 149 | batch_size (:obj:`int`): |
| 150 | Batch size of :obj:`input_ids` for which beam search decoding is run in parallel. |
| 151 | max_length (:obj:`int`): |
| 152 | The maximum length of the sequence to be generated. |
| 153 | num_beams (:obj:`int`): |
| 154 | Number of beams for beam search. |
| 155 | device (:obj:`torch.device`): |
| 156 | Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of |
| 157 | :obj:`BeamSearchScorer` will be allocated. |
| 158 | length_penalty (:obj:`float`, `optional`, defaults to 1.0): |
| 159 | Exponential penalty to the length. 1.0 means no penalty. Set to a value < 1.0 in order to encourage the |
| 160 | model to generate shorter sequences, or to a value > 1.0 in order to encourage the model to produce longer |
| 161 | sequences. |
| 162 | do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`): |
| 163 | Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not. |
| 164 | num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1): |
| 165 | The number of beam hypotheses that shall be returned upon calling |
| 166 | :meth:`~transformers.BeamSearchScorer.finalize`. |
| 167 | """ |
| 168 | |
| 169 | def __init__( |
| 170 | self, |
| 171 | batch_size: int, |
| 172 | max_length: int, |
| 173 | num_beams: int, |
| 174 | device: torch.device, |
| 175 | length_penalty: Optional[float] = 1.0, |
| 176 | do_early_stopping: Optional[bool] = False, |
| 177 | num_beam_hyps_to_keep: Optional[int] = 1, |
| 178 | ): |
| 179 | self.max_length = max_length |
| 180 | self.num_beams = num_beams |
| 181 | self.device = device |
| 182 | self.length_penalty = length_penalty |
| 183 | self.do_early_stopping = do_early_stopping |
| 184 | self.num_beam_hyps_to_keep = num_beam_hyps_to_keep |
| 185 | |
| 186 | self._is_init = False |
| 187 | self._beam_hyps = [ |
| 188 | BeamHypotheses( |
| 189 | num_beams=self.num_beams, |
| 190 | max_length=self.max_length, |
| 191 | length_penalty=self.length_penalty, |
| 192 | early_stopping=self.do_early_stopping, |
| 193 | ) |
| 194 | for _ in range(batch_size) |
| 195 | ] |
| 196 | self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device) |
| 197 | |
| 198 | # if not isinstance(num_beams, int) or num_beams <= 1: |
no outgoing calls
no test coverage detected