github.com/cornellius-gp/gpytorch

Method cuda

gpytorch/utils/nearest_neighbors.py:68–75
(self, device=None)


        self.to(device)

    def cuda(self, device=None):
        super().cuda(device=device)
        if self.nnlib == "faiss":
            from faiss import GpuIndexFlatL2, StandardGpuResources

            self.res = StandardGpuResources()
            self.index = [GpuIndexFlatL2(self.res, self.dim) for _ in range(self.batch_shape.numel())]
        return self

    def cpu(self):
        super().cpu()
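The method first calls `super().cuda(device=device)` and then, only when `nnlib` is `"faiss"`, replaces `self.index` with one `GpuIndexFlatL2` per batch element (`self.batch_shape.numel()` of them), all sharing a single `StandardGpuResources` object. Presumably the rebuild exists because faiss indexes are not tensors, so the module-level `.cuda()` call does not touch them. The dependency-free sketch below illustrates that pattern under stated assumptions: `FlatL2IndexSketch` and `NNUtilSketch` are hypothetical names, not gpytorch or faiss APIs, and "device" is only a string label. A flat index answers queries by exact, exhaustive L2 search, which is what `FlatL2IndexSketch.search` does in plain Python.

```python
import math
from typing import List, Sequence, Tuple


class FlatL2IndexSketch:
    """Hypothetical stand-in for a flat (exhaustive, exact) L2 index."""

    def __init__(self, dim: int) -> None:
        self.dim = dim
        self.vectors: List[Tuple[float, ...]] = []

    def add(self, xs: Sequence[Sequence[float]]) -> None:
        # Store every vector verbatim; a flat index does no compression or clustering.
        for x in xs:
            if len(x) != self.dim:
                raise ValueError(f"expected dimension {self.dim}, got {len(x)}")
            self.vectors.append(tuple(x))

    def search(self, q: Sequence[float], k: int) -> List[int]:
        # Squared L2 distance from the query to every stored vector.
        dists = [sum((a - b) ** 2 for a, b in zip(q, v)) for v in self.vectors]
        # Positions of the k nearest vectors (ties broken by insertion order).
        return sorted(range(len(dists)), key=lambda i: (dists[i], i))[:k]


class NNUtilSketch:
    """Hypothetical class mirroring the index-rebuild pattern of `cuda` above."""

    def __init__(self, dim: int, batch_shape: Tuple[int, ...]) -> None:
        self.dim = dim
        self.batch_shape = batch_shape
        self.device = "cpu"  # label only; nothing is actually moved
        self.index = self._fresh_indexes()

    def _fresh_indexes(self) -> List[FlatL2IndexSketch]:
        # One independent index per batch element. math.prod plays the role of
        # batch_shape.numel() and returns 1 for an empty batch shape.
        return [FlatL2IndexSketch(self.dim) for _ in range(math.prod(self.batch_shape))]

    def cuda(self) -> "NNUtilSketch":
        # Stand-in for super().cuda(device=device), which moves the module's tensors.
        self.device = "cuda"
        # Index objects are not tensors, so they are replaced by new, empty ones.
        self.index = self._fresh_indexes()
        return self  # returning self allows chaining, as in the original


util = NNUtilSketch(dim=2, batch_shape=(2, 3))
util.index[0].add([[0.0, 0.0], [1.0, 1.0], [3.0, 0.0]])
print(len(util.index))                        # 6
print(util.index[0].search([0.9, 0.8], k=2))  # [1, 0]
util.cuda()
print(len(util.index[0].vectors))             # 0
```

In this sketch, as with the freshly constructed `GpuIndexFlatL2` objects in the listing, the replacement indexes start empty, so vectors added before the move are not carried over.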
