| 228 | |
@pytest.fixture
def assert_fn(model_w_task):
    """Return the output-checking function for the task under test.

    ``model_w_task`` is unpacked as a ``(model, task)`` pair. Only ``task`` is
    used here: it selects one of the per-task ``*_assert`` helpers defined
    elsewhere in this module.

    Returns:
        The assert helper registered for ``task``.

    Raises:
        NotImplementedError: if no assert helper is registered for ``task``.
    """
    _, task = model_w_task
    assert_fn_dict = {
        "fill-mask": fill_mask_assert,
        "question-answering": question_answering_assert,
        "text-classification": text_classification_assert,
        "token-classification": token_classification_assert,
        "text-generation": text_generation_assert,
        "text2text-generation": text2text_generation_assert,
        "translation": translation_assert,
        "summarization": summarization_assert,
    }
    try:
        return assert_fn_dict[task]
    except KeyError:
        # Fail loudly at fixture setup. Returning None would only surface
        # later as a confusing "'NoneType' object is not callable" error
        # inside the test body.
        raise NotImplementedError(
            f'assert_fn for task "{task}" is not implemented'
        ) from None
| 246 | |
| 247 | |
| 248 | # Used to verify DeepSpeed kernel injection worked with a model |