(prompt_template, do_sample)
| 82 | @parameterize("do_sample", [False]) |
| 83 | @rerun_if_address_is_in_use() |
def test_tp_engine(prompt_template, do_sample):
    """Compare generated outputs across three ``run_engine`` configurations.

    ``run_engine`` is called three times with the same ``prompt_template`` and
    ``do_sample`` values:

    * first argument 1, ``use_engine=True`` (reported as "ColossalAI TP=1"),
    * first argument 2, ``use_engine=True`` (reported as "ColossalAI TP=2"),
    * first argument 1, ``use_engine=False`` and no policy (reported as
      "Transformers").

    The three results must contain the same number of items and be pairwise
    equal item by item.

    Args:
        prompt_template: Forwarded unchanged to ``run_engine``.
        do_sample: Forwarded unchanged to ``run_engine``.

    Raises:
        AssertionError: If the outputs differ in length or in any item.
    """
    # Only choose a start method when none has been set for this process yet.
    if torch.multiprocessing.get_start_method(allow_none=True) is None:
        torch.multiprocessing.set_start_method("spawn")

    engine_kwargs = {
        "use_engine": True,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": NoPaddingLlamaModelInferPolicy(),
    }
    baseline_kwargs = {
        "use_engine": False,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": None,
    }

    # Materialise each result so its length can be checked before comparing.
    colossal_tp_1_output = list(run_engine(1, **engine_kwargs))
    colossal_tp_2_output = list(run_engine(2, **engine_kwargs))
    transformer_tp_1_output = list(run_engine(1, **baseline_kwargs))

    # zip() stops at the shortest input, so a truncated or empty result would
    # otherwise let the element-wise checks below pass without comparing
    # anything. Check the lengths explicitly first.
    assert len(colossal_tp_1_output) == len(transformer_tp_1_output), (
        f"\nColossalAI TP=1 produced {len(colossal_tp_1_output)} outputs, "
        f"Transformers produced {len(transformer_tp_1_output)}"
    )
    assert len(colossal_tp_1_output) == len(colossal_tp_2_output), (
        f"\nColossalAI TP=1 produced {len(colossal_tp_1_output)} outputs, "
        f"ColossalAI TP=2 produced {len(colossal_tp_2_output)}"
    )

    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
| 103 | |
| 104 | |
| 105 | if __name__ == "__main__": |
no test coverage detected
searching dependent graphs…