Repository: github.com/deepspeedai/DeepSpeed

Function get_accelerator

Defined in accelerator/real_accelerator.py, lines 51–246.
Signature: get_accelerator() (no parameters)

Source (excerpt):

```python
# Excerpt: relies on module-level names defined elsewhere in the file
# (os, ds_accelerator, SUPPORTED_ACCELERATOR_LIST).
def get_accelerator():
    global ds_accelerator
    if ds_accelerator is not None:
        return ds_accelerator

    accelerator_name = None
    ds_set_method = None
    # 1. Detect whether there is override of DeepSpeed accelerators from environment variable.
    if "DS_ACCELERATOR" in os.environ.keys():
        accelerator_name = os.environ["DS_ACCELERATOR"]
        if accelerator_name == "xpu":
            try:
                import torch
                assert hasattr(torch, 'xpu') and torch.xpu.is_available(), \
                    "XPU_Accelerator requires PyTorch with XPU support (torch.xpu)."
            except (ImportError, AssertionError) as e:
                raise ValueError(f"XPU_Accelerator requires PyTorch with XPU support: {e}")
        elif accelerator_name == "cpu":
            pass
        elif accelerator_name == "npu":
            try:
                import torch_npu  # noqa: F401 # type: ignore
            except ImportError as e:
                raise ValueError("NPU_Accelerator requires torch_npu, which is not installed on this system.")
            pass
        elif accelerator_name == "sdaa":
            try:
                import torch_sdaa  # noqa: F401 # type: ignore
            except ImportError as e:
                raise ValueError("SDAA_Accelerator requires torch_sdaa, which is not installed on this system.")
            pass
        elif accelerator_name == "mps":
            try:
                import torch.mps

                # should use torch.mps.is_available() if it exists someday but this is used as proxy
                torch.mps.current_allocated_memory()
            except (RuntimeError, ImportError) as e:
                raise ValueError("MPS_Accelerator requires torch.mps, which is not installed on this system.")
        elif accelerator_name == "hpu":
            try:
                import habana_frameworks.torch.hpu  # noqa: F401
            except ImportError as e:
                raise ValueError(
                    "HPU_Accelerator requires habana_frameworks.torch.hpu, which is not installed on this system.")
        elif accelerator_name == "mlu":
            try:
                import torch_mlu  # noqa: F401
            except ImportError as e:
                raise ValueError("MLU_Accelerator requires torch_mlu, which is not installed on this system.")
        elif accelerator_name == "supa":
            try:
                import torch_supa  # noqa: F401 # type: ignore
            except ImportError as e:
                raise ValueError("SUPA_Accelerator requires torch_supa, which is not installed on this system.")
        elif accelerator_name not in SUPPORTED_ACCELERATOR_LIST:
            raise ValueError(f'DS_ACCELERATOR must be one of {SUPPORTED_ACCELERATOR_LIST}. '
                             f'Value "{accelerator_name}" is not supported')
    # ... function continues through line 246 (not shown in this excerpt)
```
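The override branch above follows one pattern: read DS_ACCELERATOR, reject unknown names, and treat a failed import of the vendor backend as a configuration error. Below is a minimal, stdlib-only sketch of that pattern. `resolve_accelerator_name` and `_REQUIRED_MODULE` are hypothetical names, the table is an illustrative subset, and "xpu" and "mps" are left out because the excerpt probes them through torch attributes rather than a plain import.

```python
import importlib
import os

# Hypothetical table: accelerator name -> backend module that must import
# successfully before the name is accepted (None = accepted without a probe).
_REQUIRED_MODULE = {
    "cpu": None,
    "cuda": None,
    "npu": "torch_npu",
    "sdaa": "torch_sdaa",
    "hpu": "habana_frameworks.torch.hpu",
    "mlu": "torch_mlu",
    "supa": "torch_supa",
}


def resolve_accelerator_name(env=None):
    """Return the name requested via DS_ACCELERATOR, or None if it is unset."""
    env = os.environ if env is None else env
    name = env.get("DS_ACCELERATOR")
    if name is None:
        return None  # no override; a caller would fall back to auto-detection
    if name not in _REQUIRED_MODULE:
        raise ValueError(f"DS_ACCELERATOR must be one of {sorted(_REQUIRED_MODULE)}. "
                         f'Value "{name}" is not supported')
    module = _REQUIRED_MODULE[name]
    if module is not None:
        try:
            # Importing the backend is the availability probe, as in the excerpt.
            importlib.import_module(module)
        except ImportError as e:
            # Chain the original ImportError so the root cause stays visible.
            raise ValueError(f"{name} requires {module}, which is not installed "
                             "on this system.") from e
    return name
```

Passing a plain dict as `env` keeps the check testable without touching `os.environ`. `importlib.import_module` mirrors the excerpt in actually importing the backend, import-time side effects included; `importlib.util.find_spec` would be a lighter check that, for a top-level module, only locates it without running it.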

Callers (15 total; 10 shown)

setup.py · File · 0.90
_device · Method · 0.90
_release · Method · 0.90
test_compile.py · File · 0.90
__init__ · Method · 0.90
create_deepspeed_args · Function · 0.90
set_accelerator_visible · Function · 0.90
DistributedExec · Class · 0.90
__call__ · Method · 0.90
_launch_procs · Method · 0.90

Calls (13 total; 12 shown)

CUDA_Accelerator · Class · 0.85
CPU_Accelerator · Class · 0.85
XPU_Accelerator · Class · 0.85
NPU_Accelerator · Class · 0.85
SDAA_Accelerator · Class · 0.85
MPS_Accelerator · Class · 0.85
HPU_Accelerator · Class · 0.85
MLU_Accelerator · Class · 0.85
SUPA_Accelerator · Class · 0.85
_validate_accelerator · Function · 0.85
warning · Method · 0.80
is_available · Method · 0.45
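The Calls list shows that get_accelerator constructs one of several accelerator classes, and the excerpt opens with an early return of the module-level `ds_accelerator` when it is already set. Below is a minimal sketch of that cache-then-dispatch shape. The two classes, the dict-based dispatch, `get_accelerator_sketch`, and the "cpu" default are hypothetical stand-ins rather than DeepSpeed's code; the real selection logic in the unshown part of the function may differ.

```python
import os


class CPUAccelerator:
    """Hypothetical stand-in for CPU_Accelerator."""
    name = "cpu"


class CUDAAccelerator:
    """Hypothetical stand-in for CUDA_Accelerator."""
    name = "cuda"


# Hypothetical dispatch table from accelerator name to class.
_ACCELERATOR_CLASSES = {"cpu": CPUAccelerator, "cuda": CUDAAccelerator}

# Module-level cache, playing the role of ds_accelerator in the excerpt.
_cached_accelerator = None


def get_accelerator_sketch(name=None):
    global _cached_accelerator
    if _cached_accelerator is not None:
        return _cached_accelerator  # every later call returns the first instance
    if name is None:
        # Hypothetical default; the real function's fallback logic is not shown.
        name = os.environ.get("DS_ACCELERATOR", "cpu")
    try:
        cls = _ACCELERATOR_CLASSES[name]
    except KeyError:
        raise ValueError(f"unsupported accelerator: {name!r}") from None
    _cached_accelerator = cls()
    return _cached_accelerator
```

Once the first call succeeds, later calls ignore both the argument and DS_ACCELERATOR. Assuming the rest of get_accelerator stores its result in `ds_accelerator`, which the early return suggests, the same holds there: DS_ACCELERATOR needs to be set before the first call.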

Tested by (15 total; 8 shown)

_device · Method · 0.72
_release · Method · 0.72
init_softmax_inputs · Function · 0.72
init_matmul_inputs · Function · 0.72
test_fused_lion_equal · Method · 0.72
get_q_props · Function · 0.72
get_scale_zero_point · Function · 0.72
test_float_quantize · Function · 0.72
