(model_kvargs, batch_size, image_size)
| 34 | |
| 35 | |
def tppart_model_infer(model_kvargs, batch_size, image_size):
    """Benchmark the forward pass of one VisionTransformer partition.

    Initializes the vision distributed environment from ``model_kvargs``,
    builds a ``VisionTransformer``, runs a fixed number of warm-up forward
    passes, then times a fixed number of forward passes on a random
    bfloat16 input. The process whose ``tp_rank_id`` is 0 prints the
    average latency per forward pass and the image throughput.

    Args:
        model_kvargs: Mapping handed to ``init_vision_distributed_env`` and
            ``VisionTransformer``. Must contain the key ``"tp_rank_id"``.
        batch_size: Number of images in the synthetic input batch.
        image_size: Height and width of each synthetic 3-channel image.

    Returns:
        None. Results are reported through ``print`` on rank 0 only.
    """
    import torch

    # Iteration counts are named once so the loops and the throughput
    # arithmetic below cannot drift apart.
    warmup_iters = 10
    timed_iters = 50

    rank_id = model_kvargs["tp_rank_id"]
    init_vision_distributed_env(model_kvargs)

    torch.cuda.empty_cache()
    model_part = VisionTransformer(model_kvargs)
    test_data = torch.randn((batch_size, 3, image_size, image_size)).cuda().to(torch.bfloat16)

    # warm up
    torch.cuda.synchronize()
    for _ in range(warmup_iters):
        model_part.forward(test_data)
    torch.cuda.synchronize()

    # Synchronize before reading the clock so previously queued GPU work is
    # not attributed to the timed region.
    torch.cuda.synchronize()
    # perf_counter is monotonic and high resolution; time.time() can jump
    # when the system clock is adjusted and would corrupt the measurement.
    start_time = time.perf_counter()
    for _ in range(timed_iters):
        model_part.forward(test_data)
    # Wait for the queued kernels to finish before stopping the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start_time

    if rank_id == 0:
        # Despite the label, this value is the average per forward pass.
        print("time total cost(ms):", elapsed / timed_iters * 1000)
        print("image per second:", batch_size * timed_iters / elapsed)
| 63 | |
| 64 | |
| 65 | if __name__ == "__main__": |
nothing calls this directly
no test coverage detected