Worker function for context parallel testing.
(
rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
)
| 51 | |
| 52 | |
| 53 | def _context_parallel_worker( |
| 54 | rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None |
| 55 | ): |
| 56 | """Worker function for context parallel testing.""" |
| 57 | try: |
| 58 | # Set up distributed environment |
| 59 | os.environ["MASTER_ADDR"] = "localhost" |
| 60 | os.environ["MASTER_PORT"] = str(master_port) |
| 61 | os.environ["RANK"] = str(rank) |
| 62 | os.environ["WORLD_SIZE"] = str(world_size) |
| 63 | |
| 64 | # Get device configuration |
| 65 | device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"]) |
| 66 | backend = device_config["backend"] |
| 67 | device_module = device_config["module"] |
| 68 | |
| 69 | # Initialize process group |
| 70 | dist.init_process_group(backend=backend, rank=rank, world_size=world_size) |
| 71 | |
| 72 | # Set device for this process |
| 73 | device_module.set_device(rank) |
| 74 | device = torch.device(f"{torch_device}:{rank}") |
| 75 | |
| 76 | # Create model |
| 77 | model = model_class(**init_dict) |
| 78 | model.to(device) |
| 79 | model.eval() |
| 80 | |
| 81 | # Cast as needed. |
| 82 | model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict) |
| 83 | |
| 84 | # Move inputs to device |
| 85 | inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()} |
| 86 | |
| 87 | # Enable attention backend |
| 88 | if attention_backend: |
| 89 | model.set_attention_backend(attention_backend) |
| 90 | |
| 91 | # Enable context parallelism |
| 92 | cp_config = ContextParallelConfig(**cp_dict) |
| 93 | model.enable_parallelism(config=cp_config) |
| 94 | |
| 95 | # Run forward pass |
| 96 | with torch.no_grad(): |
| 97 | output = model(**inputs_on_device, return_dict=False)[0] |
| 98 | |
| 99 | # Only rank 0 reports results |
| 100 | if rank == 0: |
| 101 | return_dict["status"] = "success" |
| 102 | return_dict["output_shape"] = list(output.shape) |
| 103 | |
| 104 | except Exception as e: |
| 105 | if rank == 0: |
| 106 | return_dict["status"] = "error" |
| 107 | return_dict["error"] = str(e) |
| 108 | finally: |
| 109 | if dist.is_initialized(): |
| 110 | dist.destroy_process_group() |
nothing calls this directly
no test coverage detected
searching dependent graphs…