MCPcopy
hub / github.com/erikwijmans/Pointnet2_PyTorch / pdist2

Function pdist2

pointnet2/utils/linalg_utils.py:7–58  ·  view source on GitHub ↗

r""" Calculates the pairwise distance between X and Z: D[b, i, j] = l2 distance between X[b, i] and Z[b, j]. Parameters — X (torch.Tensor): X is a (B, N, d) tensor. There are B batches, and N vectors of dimension d. Z (torch.Tensor): Z is a (B, M, d) tensor. If Z is None, then Z = X.

(
        X: torch.Tensor,
        Z: torch.Tensor = None,
        order: PDist2Order = PDist2Order.d_second
)

Source from the content-addressed store, hash-verified

5
6
def pdist2(
    X: torch.Tensor,
    Z: torch.Tensor = None,
    order: PDist2Order = PDist2Order.d_second
) -> torch.Tensor:
    r"""Compute pairwise squared Euclidean distances between X and Z.

    The result is ``|R + S - 2 * G|`` where ``S`` holds the squared norms
    of the vectors in X, ``R`` the squared norms of the vectors in Z and
    ``G`` their inner products, i.e.

        D[b, i, j] = || X[b, i] - Z[b, j] ||^2

    No square root is taken, so the values are *squared* l2 distances.

    Parameters
    ---------
    X : torch.Tensor
        With ``order == PDist2Order.d_second``, a (B, N, d) tensor: B
        batches of N vectors of dimension d.  With any other ``order``
        the vector components are read along dim -2 instead, i.e. X is
        laid out as (B, d, N).  A 2-D X gets a leading batch dimension
        added.
    Z : torch.Tensor
        A (B, M, d) tensor (or (B, d, M) for the other layout).  A 2-D Z
        gets a leading batch dimension added.  If Z is None, then Z = X.
    order : PDist2Order
        Selects which of the last two dims holds the vector components,
        as described for X.

    Returns
    -------
    torch.Tensor
        Distance matrix of size (B, N, M).  The result is passed through
        ``squeeze(0)``, so a leading dimension of size 1 is dropped and
        the output is (N, M) in that case.
    """
    components_last = order == PDist2Order.d_second
    # Dimension along which the d vector components are summed.
    component_dim = -1 if components_last else -2

    if X.dim() == 2:
        X = X.unsqueeze(0)

    z_is_x = Z is None
    if z_is_x:
        Z = X
    elif Z.dim() == 2:
        Z = Z.unsqueeze(0)

    # Inner products between every vector of X and every vector of Z: (B, N, M).
    if components_last:
        gram = X @ Z.transpose(-2, -1)
    else:
        gram = X.transpose(-2, -1) @ Z

    # Squared norms; reuse X's when Z is X.
    x_sq = (X * X).sum(component_dim, keepdim=True)
    z_sq = x_sq if z_is_x else (Z * Z).sum(component_dim, keepdim=True)

    # Orient the norms so they broadcast against gram:
    # X's term as a column (B, N, 1), Z's term as a row (B, 1, M).
    if components_last:
        x_term = x_sq
        z_term = z_sq.transpose(-2, -1)
    else:
        x_term = x_sq.transpose(-2, -1)
        z_term = z_sq

    # abs() presumably masks small negative values caused by floating-point
    # cancellation -- NOTE(review): confirm intent.
    return torch.abs(z_term + x_term - 2 * gram).squeeze(0)
59
60
61def pdist2_slow(X, Z=None):

Callers 1

linalg_utils.pyFile · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected