generate(moss, input_str, tokenizer, method, **kwargs)
```python
import jittor as jt

def generate(moss, input_str, tokenizer, method, **kwargs):
    """
    Choose different methods to generate sentences.

    :param moss: The model used for generation.
    :param input_str: The input text.
    :param tokenizer: Tokenizer.
    :param method: Generation method. Should be one of: ['greedy', 'sample']
    :param kwargs: Other parameters used for generation.
        - max_gen_len: int. Maximum generation length. Used in all methods.
        - temperature: float. Used in ``sample``.
        - top_p: float. Used in ``sample``.
        - top_k: int. Used in ``sample``.
    """
    # Dispatch on ``method``; all extra keyword arguments are forwarded unchanged.
    if method == "greedy":
        return greedy_search(moss, input_str, tokenizer, **kwargs)
    elif method == "sample":
        return sample(moss, input_str, tokenizer, **kwargs)
    else:
        raise NotImplementedError(
            f"Unsupported generation method {method}"
        )

def greedy_search(model, input_str, tokenizer, max_gen_len,
                  eos_token_id=None, pad_token_id=None):
    ...  # body not included in this excerpt
```
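The excerpt stops at the `greedy_search` signature, so its body is not shown. As a rough illustration of what `max_gen_len`, `eos_token_id` and `pad_token_id` typically control in greedy decoding, here is a pure-Python sketch over a batch of token-id lists. `greedy_decode`, `LogitsFn` and `toy_next_logits` are hypothetical names invented for this example; the original function takes `input_str` and a `tokenizer` rather than token ids, and its internals may differ.

```python
from typing import Callable, List, Optional

# Hypothetical model interface (not the original API): given the token ids
# of one sequence so far, return one next-token score per vocabulary entry.
LogitsFn = Callable[[List[int]], List[float]]


def greedy_decode(next_logits: LogitsFn, batch: List[List[int]],
                  max_gen_len: int, eos_token_id: Optional[int] = None,
                  pad_token_id: Optional[int] = None) -> List[List[int]]:
    """Append up to ``max_gen_len`` argmax tokens to every prompt in ``batch``."""
    seqs = [list(prompt) for prompt in batch]
    finished = [False] * len(seqs)
    for _ in range(max_gen_len):
        for i, seq in enumerate(seqs):
            if finished[i]:
                # Pad rows that already emitted EOS so the batch stays aligned.
                if pad_token_id is not None:
                    seq.append(pad_token_id)
                continue
            logits = next_logits(seq)
            # Greedy choice: highest score wins (ties go to the lowest id).
            token = max(range(len(logits)), key=lambda t: logits[t])
            seq.append(token)
            if eos_token_id is not None and token == eos_token_id:
                finished[i] = True
        # Stop early once every sequence has produced EOS.
        if all(finished):
            break
    return seqs


def toy_next_logits(ids: List[int]) -> List[float]:
    """Toy 4-token 'model' (hypothetical): 0 = pad, 1 = EOS.
    After 2 it predicts 3, after 3 it predicts EOS, otherwise it predicts 2."""
    scores = [0.0, 0.0, 0.0, 0.0]
    last = ids[-1]
    if last == 2:
        scores[3] = 1.0
    elif last == 3:
        scores[1] = 1.0
    else:
        scores[2] = 1.0
    return scores


print(greedy_decode(toy_next_logits, [[3], [2]], max_gen_len=4,
                    eos_token_id=1, pad_token_id=0))
# [[3, 1, 0], [2, 3, 1]]
```

In this sketch, a sequence that hits EOS early is padded with `pad_token_id` while the others keep decoding; if `pad_token_id` is `None`, finished rows simply stop growing.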
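Likewise, `sample` is called by `generate` but not shown. The docstring names `temperature`, `top_k` and `top_p` as its parameters; the sketch below applies their standard meanings (divide the logits by the temperature, keep only the `k` most likely tokens, keep the smallest set of most-likely tokens whose probability reaches `p`) to a plain list of logits. `filter_probs` and `sample_token` are hypothetical helpers, and the order in which the original `sample` applies these filters is an assumption here.

```python
import math
import random
from typing import List, Optional


def filter_probs(logits: List[float], temperature: float = 1.0,
                 top_k: Optional[int] = None,
                 top_p: Optional[float] = None) -> List[float]:
    """Hypothetical helper: temperature + top-k + top-p (nucleus) filtering,
    returning a renormalised probability distribution."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Temperature scaling, then a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token ids sorted by probability, most likely first (stable for ties).
    order = sorted(range(len(probs)), key=lambda t: probs[t], reverse=True)
    keep = set(order)
    if top_k is not None and top_k > 0:
        keep &= set(order[:top_k])
    if top_p is not None and top_p < 1.0:
        # Smallest prefix of ``order`` whose cumulative probability reaches top_p.
        nucleus, cumulative = set(), 0.0
        for t in order:
            nucleus.add(t)
            cumulative += probs[t]
            if cumulative >= top_p:
                break
        keep &= nucleus
    # Renormalise over the surviving tokens; everything else gets probability 0.
    kept_mass = sum(probs[t] for t in keep)
    return [probs[t] / kept_mass if t in keep else 0.0 for t in range(len(probs))]


def sample_token(logits: List[float], rng: random.Random, **kwargs) -> int:
    """Hypothetical helper: draw one token id from the filtered distribution."""
    probs = filter_probs(logits, **kwargs)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


rng = random.Random(0)
print(sample_token([2.0, 1.0, 0.0, -1.0], rng, temperature=0.7, top_p=0.9))
```

With `top_k=1` the filtered distribution collapses onto the single most likely token, so sampling reduces to the greedy choice; lowering `temperature` sharpens the distribution in the same direction.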