github.com/THUDM/GLM

Class BeamSearchScorer

generation_utils.py:141–341

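BeamSearchScorer implements standard beam search decoding. The sketch below is a rough, self-contained illustration of that technique, not this class's API. At each step it keeps the `num_beams` highest-scoring prefixes. Any hypothesis that emits EOS or reaches `max_length` moves into a finished pool. The search stops once `num_beams` hypotheses have finished, which corresponds to the `do_early_stopping=True` behavior described in the docstring. `beam_search`, `toy_model`, and the toy probabilities are hypothetical. Length penalties are left out, so hypotheses are ranked by raw summed log-probabilities.

```python
import math
from typing import Callable, Dict, List, Tuple

Prefix = Tuple[int, ...]


def beam_search(
    next_log_probs: Callable[[Prefix], Dict[int, float]],
    bos: int,
    eos: int,
    num_beams: int,
    max_length: int,
) -> List[Tuple[Prefix, float]]:
    """Hypothetical plain beam search ranked by summed log-probabilities."""
    beams: List[Tuple[Prefix, float]] = [((bos,), 0.0)]  # (prefix, summed log-prob)
    finished: List[Tuple[Prefix, float]] = []
    # Stop once num_beams hypotheses have finished (early-stopping style).
    while beams and len(finished) < num_beams:
        # Expand every live beam by every candidate next token.
        candidates = [
            (prefix + (tok,), score + logp)
            for prefix, score in beams
            for tok, logp in next_log_probs(prefix).items()
        ]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for prefix, score in candidates:
            if prefix[-1] == eos or len(prefix) >= max_length:
                finished.append((prefix, score))  # hypothesis is complete
            else:
                beams.append((prefix, score))  # keep extending it
            if len(beams) == num_beams:
                break
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:num_beams]


def toy_model(prefix: Prefix) -> Dict[int, float]:
    # Hypothetical fixed distributions: the first step prefers token 2 over 3,
    # and later steps prefer EOS (token 1).
    if len(prefix) == 1:
        return {2: math.log(0.6), 3: math.log(0.4)}
    return {1: math.log(0.7), 2: math.log(0.3)}


best = beam_search(toy_model, bos=0, eos=1, num_beams=2, max_length=4)
```

With `num_beams=2`, the two finished hypotheses are `(0, 2, 1)` with probability 0.42 and `(0, 3, 1)` with probability 0.28, returned in that order.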


```python
# Excerpt from generation_utils.py. torch, typing.Optional, BeamScorer and
# BeamHypotheses are imported or defined elsewhere in the module.
class BeamSearchScorer(BeamScorer):
    r"""
    :class:`transformers.BeamScorer` implementing standard beam search decoding.

    Adapted in part from `Facebook's XLM beam search code
    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.

    Args:
        batch_size (:obj:`int`):
            Batch size of :obj:`input_ids` for which beam search decoding is run in parallel.
        max_length (:obj:`int`):
            The maximum length of the sequence to be generated.
        num_beams (:obj:`int`):
            Number of beams for beam search.
        device (:obj:`torch.device`):
            Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
            :obj:`BeamSearchScorer` will be allocated.
        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
            Exponential penalty on the length. 1.0 means no penalty. Set it to a value < 1.0 to encourage the
            model to generate shorter sequences, or to a value > 1.0 to encourage longer sequences.
        do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to stop the beam search once at least ``num_beams`` sentences are finished per batch.
        num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
            The number of beam hypotheses returned by
            :meth:`~transformers.BeamSearchScorer.finalize`.
    """

    def __init__(
        self,
        batch_size: int,
        max_length: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
    ):
        self.max_length = max_length
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep

        self._is_init = False
        # One independent BeamHypotheses container per batch element.
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                max_length=self.max_length,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        # Per-batch-element completion flags.
        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)

        # if not isinstance(num_beams, int) or num_beams <= 1:
```
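`__init__` allocates one `BeamHypotheses` object and one `_done` flag per batch element. The hypothetical `ToyScorerState` below mirrors that bookkeeping in plain Python, using a list of bools in place of the `torch.bool` tensor. It also re-enables the `num_beams` check that is commented out in the listing. Whether GLM actually intends to reject `num_beams <= 1` is an assumption.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyScorerState:
    """Hypothetical per-batch bookkeeping in the spirit of BeamSearchScorer.__init__."""

    batch_size: int
    num_beams: int
    hyps: List[list] = field(init=False)  # one hypothesis store per batch element
    done: List[bool] = field(init=False)  # plain-list stand-in for the `_done` tensor

    def __post_init__(self) -> None:
        # Same condition as the commented-out check; enabling it is an assumption.
        if not isinstance(self.num_beams, int) or self.num_beams <= 1:
            raise ValueError(f"num_beams must be an int > 1, got {self.num_beams!r}")
        # A comprehension creates a distinct list per element; `[[]] * batch_size`
        # would alias one shared list across the whole batch.
        self.hyps = [[] for _ in range(self.batch_size)]
        self.done = [False] * self.batch_size  # bools are immutable, so repetition is safe

    def mark_done(self, batch_idx: int) -> None:
        self.done[batch_idx] = True

    @property
    def all_done(self) -> bool:
        # Decoding for the whole batch can stop only when every element is done.
        return all(self.done)
```

The comprehension matters in the original too: repeating a single `BeamHypotheses` instance with `*` would make every batch element share one hypothesis store.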

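The docstring calls `length_penalty` an exponential penalty on the length. `BeamHypotheses` is not shown here. The check below assumes it follows the Hugging Face convention, dividing a hypothesis's summed log-probabilities by `length ** length_penalty`, and tests the effect on two hypothetical hypotheses:

```python
from typing import List, Tuple

Hyp = Tuple[str, int, float]  # hypothetical (label, length in tokens, summed log-probs)


def length_normalized_score(sum_logprobs: float, length: int, length_penalty: float = 1.0) -> float:
    # Assumed convention: divide the (negative) summed log-probs by length ** length_penalty.
    return sum_logprobs / (length ** length_penalty)


def best_hypothesis(hyps: List[Hyp], length_penalty: float) -> str:
    # Return the label of the highest-scoring hypothesis under the given penalty.
    return max(hyps, key=lambda h: length_normalized_score(h[2], h[1], length_penalty))[0]


hyps = [("short", 2, -2.0), ("long", 4, -3.0)]
for lp in (0.5, 1.0, 2.0):
    print(lp, best_hypothesis(hyps, lp))
```

Summed log-probabilities are negative, so a larger exponent pulls long hypotheses' scores toward zero and favors them. With these numbers, `0.5` picks `short` (-1.41 vs -1.50). Both `1.0` and `2.0` pick `long`, at -1.00 vs -0.75 and at -0.50 vs -0.19 respectively.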
Callers (2)

sample_sequence · function · 0.90
evaluate · method · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected
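`num_beam_hyps_to_keep` sets how many hypotheses per batch element `finalize` returns. Each `BeamHypotheses` instance is built with `num_beams`, `max_length`, `length_penalty` and `early_stopping`, but neither that class nor `finalize` appears in the listing. Below is a minimal sketch of such a bounded container. It assumes Hugging Face-style scoring and stopping rules. `ToyBeamHypotheses`, `add`, `is_done` and `best` are hypothetical names, and `max_length` is omitted.

```python
import heapq
from typing import List, Tuple


class ToyBeamHypotheses:
    """Hypothetical bounded store for the num_beams best finished hypotheses."""

    def __init__(self, num_beams: int, length_penalty: float = 1.0, early_stopping: bool = False):
        self.num_beams = num_beams
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        # Min-heap of (score, insertion_id, tokens); heap[0] is the worst kept hypothesis.
        self._heap: List[Tuple[float, int, Tuple[int, ...]]] = []
        self._next_id = 0  # tie-breaker so token tuples are never compared

    def __len__(self) -> int:
        return len(self._heap)

    def _score(self, sum_logprobs: float, length: int) -> float:
        # Assumed Hugging Face-style length normalization.
        return sum_logprobs / (length ** self.length_penalty)

    def add(self, tokens: Tuple[int, ...], sum_logprobs: float) -> None:
        item = (self._score(sum_logprobs, len(tokens)), self._next_id, tuple(tokens))
        self._next_id += 1
        if len(self._heap) < self.num_beams:
            heapq.heappush(self._heap, item)
        elif item[0] > self._heap[0][0]:
            heapq.heapreplace(self._heap, item)  # evict the current worst

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        if len(self._heap) < self.num_beams:
            return False
        if self.early_stopping:
            return True
        # Heuristic: stop if the worst kept score beats the best live beam's current score.
        return self._heap[0][0] >= self._score(best_sum_logprobs, cur_len)

    def best(self, k: int) -> List[Tuple[int, ...]]:
        # Top-k token sequences, highest score first (k plays the role of num_beam_hyps_to_keep).
        return [tokens for _, _, tokens in sorted(self._heap, reverse=True)[:k]]
```

A min-heap keyed on score keeps the worst retained hypothesis at `heap[0]`, so each eviction costs O(log num_beams). Without early stopping, `is_done` ends the search only when the worst kept score already beats the best live beam's score at the current length. That rule is a heuristic, not a strict bound.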