(chat_format, tokenizer)
| 107 | |
| 108 | |
def get_stop_words_ids(chat_format, tokenizer):
    """Return the stop-word token-id sequences for ``chat_format``.

    Args:
        chat_format: Either ``"raw"`` or ``"chatml"``.
        tokenizer: Object supplying the ids. The ``"raw"`` format calls
            ``tokenizer.encode("Human:")`` and reads ``tokenizer.eod_id``;
            the ``"chatml"`` format reads ``tokenizer.im_end_id`` and
            ``tokenizer.im_start_id``. Only the attributes of the selected
            format are accessed.

    Returns:
        A two-element list. For ``"raw"``: the result of encoding
        ``"Human:"`` followed by ``[eod_id]``. For ``"chatml"``:
        ``[im_end_id]`` followed by ``[im_start_id]``.

    Raises:
        NotImplementedError: If ``chat_format`` is any other value.
    """
    if chat_format == "raw":
        # The encoded "Human:" marker comes first, then the single eod id.
        return [tokenizer.encode("Human:"), [tokenizer.eod_id]]
    if chat_format == "chatml":
        # Each special id is wrapped in its own one-element sequence.
        return [[tokenizer.im_end_id], [tokenizer.im_start_id]]
    raise NotImplementedError(f"Unknown chat format {chat_format!r}")
| 117 | |
| 118 | |
| 119 | def make_context( |
no test coverage detected