MCPcopy
hub / github.com/deepspeedai/DeepSpeed / MPS_Accelerator

Class MPS_Accelerator

accelerator/mps_accelerator.py:18–279  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

16
17
18class MPS_Accelerator(DeepSpeedAccelerator):
19
20 def __init__(self):
21 self._name = "mps"
22 self._communication_backend_name = None
23 self._compile_backend = "inductor"
24
25 def is_synchronized_device(self):
26 return False
27
28 def use_host_timers(self):
29 # Event timers are not supported on MPS
30 return True
31
32 def resolves_data_dependency(self):
33 return self.is_synchronized_device()
34
35 def handles_memory_backpressure(self):
36 return self.is_synchronized_device()
37
38 # Device APIs
39 def device_name(self, device_index=None):
40 if device_index is None:
41 return "mps"
42 return "mps:{}".format(device_index)
43
44 def device(self, device_index):
45 return torch.device("mps", index=0)
46
47 def set_device(self, device_index):
48 return
49
50 def current_device(self):
51 return torch.device("mps", index=0)
52
53 def current_device_name(self):
54 return "mps:0"
55
56 def device_count(self):
57 return 1
58
59 def synchronize(self, device_index=None):
60 return torch.mps.synchronize()
61
62 # RNG APIs
63 def random(self):
64 return torch.random
65
66 def set_rng_state(self, new_state, device_index=None):
67 return torch.mps.set_rng_state(new_state)
68
69 def get_rng_state(self, device_index=None):
70 return torch.mps.get_rng_state()
71
72 def manual_seed(self, seed):
73 return torch.mps.manual_seed(seed)
74
75 def manual_seed_all(self, seed):

Callers 1

get_accelerator (function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…