()
| 53 | |
| 54 | |
def test():
    """Check that a worker-group run of ``get_aux_metrics`` matches a driver run.

    The same ``DataProto`` batch is passed to ``get_aux_metrics`` twice: once
    through ``RayWorkerGroup.execute_with_func_generator`` and once directly on
    the driver with a ``HackSelf`` instance standing in for ``self``. The
    ``decode_count`` tensors of the two results must be close.

    Raises:
        AssertionError: if the two ``decode_count`` tensors differ.
    """
    ray.init()
    # Everything after init runs under try/finally so the Ray runtime is torn
    # down even when setup or the comparison fails; otherwise a failure here
    # would leave Ray initialised for whatever runs next in this process.
    try:
        # Pool spec [2] with use_gpu=True; per the original author's note this
        # creates 2 workers, each holding a GPU.
        resource_pool = RayResourcePool([2], use_gpu=True, name_prefix='a')

        class_with_args = RayClassWithInitArgs(cls=ModelActor)
        shard_wg = RayWorkerGroup(resource_pool, class_with_args)

        test_bs = 8
        test_proto = DataProto(
            TensorDict(
                {
                    "sequence_ids": torch.ones([test_bs, 2048], dtype=torch.int64),
                },
                batch_size=test_bs,
            ),
            meta_info={"query_length": 1536},
        )

        # Run through the worker group (original note: sharded among the
        # different ranks).
        ret_proto1 = shard_wg.execute_with_func_generator(get_aux_metrics, test_proto)

        # Reference result: run the same function directly on the driver.
        hs = HackSelf()
        ret_proto2 = get_aux_metrics(hs, test_proto)

        torch.testing.assert_close(
            ret_proto1.batch["decode_count"],
            ret_proto2.batch["decode_count"],
        )
    finally:
        ray.shutdown()
nothing calls this directly
no test coverage detected