hub / github.com/hkust-nlp/simpleRL-reason / test

Function test

tests/ray/test_driverfunc_to_worker.py:55–81

Source from the content-addressed store, hash-verified

def test():
    # initialize Ray
    ray.init()

    # create 2 workers, each holding a GPU
    resource_pool = RayResourcePool([2], use_gpu=True, name_prefix='a')

    class_with_args = RayClassWithInitArgs(cls=ModelActor)
    shard_wg = RayWorkerGroup(resource_pool, class_with_args)

    # build a test batch of 8 sequences, 2048 token ids each
    test_bs = 8
    test_proto = DataProto(
        TensorDict(
            {
                "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
            },
            batch_size=test_bs,
        ),
        meta_info={"query_length": 1536},
    )

    # run on the workers, sharding the batch across ranks
    ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

    # run the same function directly on the driver for comparison
    hs = HackSelf()
    ret_proto2 = get_aux_metrics(hs, test_proto)

    # the sharded result must match the driver result
    torch.testing.assert_close(ret_proto1.batch["decode_count"], ret_proto2.batch["decode_count"])

    ray.shutdown()
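
The test above checks that running a function on sharded data across workers gives the same result as running it once on the driver. A minimal sketch of that pattern without Ray or torch follows. Everything in it is an assumption made for illustration: `split_batch`, `run_sharded`, and `count_decoded` are hypothetical helpers, the equal contiguous split is a guess at how a worker group might divide a batch, and the "length minus query length" metric is only a stand-in for whatever `get_aux_metrics` actually computes.

```python
from typing import Callable, List


def split_batch(batch: List[int], num_workers: int) -> List[List[int]]:
    """Split a batch into equal contiguous shards, one per worker (assumed scheme)."""
    if len(batch) % num_workers != 0:
        raise ValueError("batch size must be divisible by the number of workers")
    shard_size = len(batch) // num_workers
    return [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)]


def run_sharded(fn: Callable[[List[int]], List[int]],
                batch: List[int],
                num_workers: int) -> List[int]:
    """Apply fn to each shard in turn and concatenate the outputs in rank order."""
    shards = split_batch(batch, num_workers)
    outputs = [fn(shard) for shard in shards]  # local stand-in for remote execution
    return [item for out in outputs for item in out]


def count_decoded(sequence_lengths: List[int], query_length: int = 1536) -> List[int]:
    """Hypothetical per-row metric: number of tokens after the query."""
    return [length - query_length for length in sequence_lengths]


# Same shape as the test: 8 rows of length 2048, query length 1536, 2 workers.
batch = [2048] * 8
sharded = run_sharded(count_decoded, batch, num_workers=2)
direct = count_decoded(batch)

# The sharded path and the direct path must agree row for row.
assert sharded == direct
```

The comparison only holds for functions that treat each row independently; a function that aggregates across the whole batch would need an explicit reduction step after the shards are gathered.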

Callers

nothing calls this directly

Calls 8

RayResourcePool · Class · 0.90
RayWorkerGroup · Class · 0.90
DataProto · Class · 0.90
HackSelf · Class · 0.85
get_aux_metrics · Function · 0.85
init · Method · 0.45

Tested by

no test coverage detected