| 48 | ) |
| 49 | |
class CodeBlockStopper(StoppingCriteria):
    r"""Stop open-ended code generation once a block-ending token pattern appears.

    Code completion has no natural end token, so generation is halted when
    the first sequence in the batch ends with the token-id pattern
    ``stop_ids``.

    Args:
        stop_ids: Trailing token ids that mark the end of a code block.
            Defaults to ``(628, 198)``, the pattern this class has always
            checked. NOTE(review): these look like GPT-2 BPE ids for "\n\n"
            and "\n" respectively; confirm against the tokenizer in use.
    """

    def __init__(self, stop_ids=(628, 198)):
        super().__init__()
        # Stored as a list of plain ints so it compares directly against
        # ``Tensor.tolist()`` output.
        self.stop_ids = [int(t) for t in stop_ids]

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        """Return True when generation should stop.

        Args:
            input_ids: Token ids generated so far. Only row 0 is inspected,
                so in batched generation the other rows do not influence
                stopping (assumes shape (batch, seq_len); TODO confirm).
            scores: Unused; required by the ``StoppingCriteria`` signature.
            **kwargs: Unused extra arguments passed by the generation loop.

        Returns:
            True if the first sequence ends with ``self.stop_ids``,
            otherwise False.
        """
        n = len(self.stop_ids)
        # Guard: with n == 0, ``seq[-0:]`` would be the whole sequence.
        if n == 0:
            return False
        # ``.tolist()`` yields plain ints. A sequence shorter than n gives a
        # shorter list, which can never equal ``stop_ids``.
        return input_ids[0][-n:].tolist() == self.stop_ids
| 59 | |
| 60 | gen_kwargs = dict( |
| 61 | **encoding, |
no outgoing calls
no test coverage detected
searching dependent graphs…