MCPcopy
hub / github.com/hpcaitech/ColossalAI / test_tp_engine

Function test_tp_engine

tests/test_infer/test_rpc_engine.py:84–102  ·  view source on GitHub ↗
(prompt_template, do_sample)

Source from the content-addressed store, hash-verified

@parameterize("do_sample", [False])
@rerun_if_address_is_in_use()
def test_tp_engine(prompt_template, do_sample):
    """Check that inference output agrees across TP sizes and against Transformers.

    Runs ``run_engine`` three times with the same prompt template and sampling
    flag:

    * ColossalAI engine (``use_engine=True`` with
      ``NoPaddingLlamaModelInferPolicy``) at tensor-parallel size 1,
    * the same engine at tensor-parallel size 2,
    * the non-engine path (``use_engine=False``, ``policy=None``) at size 1,
      which the assertion messages label as the Transformers output.

    It then asserts that all three runs produced the same number of outputs
    and that corresponding outputs are equal.

    Args:
        prompt_template: Forwarded unchanged to ``run_engine``; its values are
            supplied by a ``parameterize`` decorator.
        do_sample: Forwarded unchanged to ``run_engine``. Only ``False`` is
            exercised here, presumably so outputs are deterministic and can be
            compared exactly — confirm against ``run_engine``.
    """
    # set_start_method raises RuntimeError if a start method has already been
    # fixed, so only choose "spawn" when nothing has been selected yet.
    if torch.multiprocessing.get_start_method(allow_none=True) is None:
        torch.multiprocessing.set_start_method("spawn")

    engine_kwargs = {
        "use_engine": True,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": NoPaddingLlamaModelInferPolicy(),
    }
    transformers_kwargs = {
        "use_engine": False,
        "prompt_template": prompt_template,
        "do_sample": do_sample,
        "policy": None,
    }

    # Materialize each result once so its length can be checked. Iterating a
    # list yields the same items zip() would have consumed from the original.
    colossal_tp_1_output = list(run_engine(1, **engine_kwargs))
    colossal_tp_2_output = list(run_engine(2, **engine_kwargs))
    transformer_tp_1_output = list(run_engine(1, **transformers_kwargs))

    # zip() stops at the shortest input. Without these checks, a run that
    # returned fewer (or no) outputs would let the test pass on the overlap only.
    assert len(colossal_tp_1_output) == len(transformer_tp_1_output), (
        f"\nColossalAI TP=1 produced {len(colossal_tp_1_output)} outputs"
        f"\nTransformers produced {len(transformer_tp_1_output)} outputs"
    )
    assert len(colossal_tp_1_output) == len(colossal_tp_2_output), (
        f"\nColossalAI TP=1 produced {len(colossal_tp_1_output)} outputs"
        f"\nColossalAI TP=2 produced {len(colossal_tp_2_output)} outputs"
    )

    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
103
104
105if __name__ == "__main__":

Callers 1

test_rpc_engine.py (File · 0.70)

Calls 2

run_engine (Function · 0.70)

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…