MCPcopy
hub / github.com/deepspeedai/DeepSpeed / SDAA_Accelerator

Class SDAA_Accelerator

accelerator/sdaa_accelerator.py:45–323  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

43
44
45class SDAA_Accelerator(DeepSpeedAccelerator):
46
def __init__(self):
    """Initialise the identifying attributes of the SDAA accelerator.

    Sets the accelerator name to 'sdaa', the communication backend name to
    'tccl' and the compile backend to "inductor". ``class_dict`` starts out
    as ``None``.
    """
    # NOTE(review): class_dict is presumably filled in later by code outside
    # this method -- confirm against the rest of the class.
    self.class_dict = None
    self._compile_backend = "inductor"
    self._communication_backend_name = 'tccl'
    self._name = 'sdaa'
52
def is_synchronized_device(self):
    """Report whether this is a synchronized device.

    Returns:
        bool: always ``False`` for this accelerator.
    """
    synchronized = False
    return synchronized
55
def use_host_timers(self):
    """Report whether host timers should be used.

    Returns:
        Whatever ``self.is_synchronized_device()`` returns; the answer is
        delegated entirely to that method.
    """
    synchronized = self.is_synchronized_device()
    return synchronized
58
def resolves_data_dependency(self):
    """Report whether the device resolves data dependencies.

    Returns:
        Whatever ``self.is_synchronized_device()`` returns; the answer is
        delegated entirely to that method.
    """
    synchronized = self.is_synchronized_device()
    return synchronized
61
def handles_memory_backpressure(self):
    """Report whether the device handles memory backpressure.

    Returns:
        Whatever ``self.is_synchronized_device()`` returns; the answer is
        delegated entirely to that method.
    """
    synchronized = self.is_synchronized_device()
    return synchronized
64
65 # Device APIs
def device_name(self, device_index=None):
    """Build the device string for this accelerator.

    Args:
        device_index: Optional device index. When ``None`` the bare
            accelerator name is returned.

    Returns:
        str: ``'sdaa'`` when no index is given, otherwise ``'sdaa:<index>'``.
    """
    if device_index is not None:
        return 'sdaa:{}'.format(device_index)
    return 'sdaa'
70
def device(self, device_index=None):
    """Return the object produced by ``torch.sdaa.device`` for an index.

    Args:
        device_index: Index forwarded unchanged to ``torch.sdaa.device``
            (``None`` by default).

    Returns:
        The result of ``torch.sdaa.device(device_index)``.
    """
    device_obj = torch.sdaa.device(device_index)
    return device_obj
73
def set_device(self, device_index):
    """Select a device by forwarding the index to ``torch.sdaa.set_device``.

    Args:
        device_index: Index passed unchanged to ``torch.sdaa.set_device``.

    Returns:
        None. The result of the underlying call is not propagated.
    """
    torch.sdaa.set_device(device_index)
76
def current_device(self):
    """Return the value reported by ``torch.sdaa.current_device()``."""
    active = torch.sdaa.current_device()
    return active
79
def current_device_name(self):
    """Return the device string for the currently selected device.

    Returns:
        str: ``'sdaa:<index>'`` where ``<index>`` is whatever
        ``torch.sdaa.current_device()`` reports.
    """
    active = torch.sdaa.current_device()
    return 'sdaa:{}'.format(active)
82
def device_count(self):
    """Return the value reported by ``torch.sdaa.device_count()``."""
    count = torch.sdaa.device_count()
    return count
85
def synchronize(self, device_index=None):
    """Forward a synchronize request to ``torch.sdaa.synchronize``.

    Args:
        device_index: Index passed unchanged to ``torch.sdaa.synchronize``
            (``None`` by default).

    Returns:
        Whatever ``torch.sdaa.synchronize(device_index)`` returns.
    """
    outcome = torch.sdaa.synchronize(device_index)
    return outcome
88
89 # RNG APIs
def random(self):
    """Return the ``torch.random`` module object."""
    rng_module = torch.random
    return rng_module
92
def set_rng_state(self, new_state, device_index=None):
    """Apply an RNG state through ``torch.sdaa.set_rng_state``.

    Args:
        new_state: State object forwarded unchanged to the torch call.
        device_index: Optional device index. When ``None`` the torch call
            is made without an index argument.

    Returns:
        Whatever ``torch.sdaa.set_rng_state`` returns.
    """
    if device_index is not None:
        return torch.sdaa.set_rng_state(new_state, device_index)

    # No index supplied: let the backend choose its own default device.
    return torch.sdaa.set_rng_state(new_state)
98
99 def get_rng_state(self, device_index=None):
100 if device_index is None:
101 return torch.sdaa.get_rng_state()
102

Callers 1

get_accelerator · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…