MCPcopy Index your code
hub / github.com/huggingface/diffusers / _context_parallel_worker

Function _context_parallel_worker

tests/models/testing_utils/parallelism.py:53–110  ·  view source on GitHub ↗

Worker function for context parallel testing.

(
    rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
)

Source from the content-addressed store, hash-verified

51
52
def _context_parallel_worker(
    rank, world_size, master_port, model_class, init_dict, cp_dict, inputs_dict, return_dict, attention_backend=None
):
    """Worker function for context parallel testing.

    Runs once per rank in its own process. It sets up a localhost process group,
    builds the model on this rank's device, enables context parallelism and runs
    a single forward pass.

    Args:
        rank: Index of this process in the group; also used as the device index.
        world_size: Total number of processes in the group.
        master_port: Port exported as ``MASTER_PORT`` for rendezvous on localhost.
        model_class: Model class, instantiated as ``model_class(**init_dict)``.
        init_dict: Keyword arguments for the model constructor.
        cp_dict: Keyword arguments for ``ContextParallelConfig``.
        inputs_dict: Keyword arguments for the forward pass. Tensor values are
            moved to this rank's device; other values are passed through as-is.
        return_dict: Mapping used to report results back to the parent process
            (presumably a ``multiprocessing.Manager().dict()`` — confirm with caller).
        attention_backend: Optional backend name, forwarded to
            ``_maybe_cast_to_bf16`` and, if truthy, ``model.set_attention_backend``.

    Results written to ``return_dict``:
        * Rank 0 sets ``"status"`` to ``"success"`` (plus ``"output_shape"``) or
          ``"error"`` (plus ``"error"`` and ``"traceback"``).
        * Any other rank that fails sets ``"error_rank_<rank>"``, so its failure
          is not silently lost.

    The process group, if created, is always destroyed before returning.
    """
    try:
        # Set up a single-node distributed environment.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(master_port)
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)

        # Device configuration; unknown device types fall back to the "cuda" entry.
        device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
        backend = device_config["backend"]
        device_module = device_config["module"]

        # Select this rank's device before creating the process group (the order
        # recommended by torch.distributed), so the backend binds to this device.
        device_module.set_device(rank)
        device = torch.device(f"{torch_device}:{rank}")

        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)

        # Create the model on this rank's device in inference mode.
        model = model_class(**init_dict)
        model.to(device)
        model.eval()

        # Cast model/inputs if the attention backend requires it.
        model, inputs_dict = _maybe_cast_to_bf16(attention_backend, model, inputs_dict)

        # Move tensor inputs to this rank's device; leave non-tensors untouched.
        inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}

        # The attention backend is set before parallelism is enabled.
        if attention_backend:
            model.set_attention_backend(attention_backend)

        cp_config = ContextParallelConfig(**cp_dict)
        model.enable_parallelism(config=cp_config)

        with torch.no_grad():
            output = model(**inputs_on_device, return_dict=False)[0]

        # Only rank 0 reports results on success.
        if rank == 0:
            return_dict["status"] = "success"
            return_dict["output_shape"] = list(output.shape)

    except Exception as e:
        import traceback

        # Prefix the exception type: str(e) alone is empty for e.g. a bare `assert`.
        message = f"{type(e).__name__}: {e}"
        if rank == 0:
            return_dict["status"] = "error"
            return_dict["error"] = message
            return_dict["traceback"] = traceback.format_exc()
        else:
            # Record non-zero-rank failures instead of discarding them, so the
            # parent can tell which rank failed and why.
            return_dict[f"error_rank_{rank}"] = message
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()

Callers

nothing calls this directly

Calls 7

_maybe_cast_to_bf16 — Function · 0.85
enable_parallelism — Method · 0.80
get — Method · 0.45
device — Method · 0.45
to — Method · 0.45
set_attention_backend — Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…