
Function validate_distance_matrix

bertopic/_utils.py:106–152

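Validate the distance matrix and convert it to a condensed distance matrix if necessary. A valid distance matrix is either a square matrix of shape (n_samples, n_samples) with zeros on the diagonal and non-negative values, or a condensed distance matrix of shape (n_samples * (n_samples - 1) / 2,) containing the upper triangle of the distance matrix. In practice, a square input's diagonal is overwritten with zeros rather than checked, and the result is always returned in condensed form.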

(X, n_samples)
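The condensed layout described in the summary stores one distance per unordered pair of samples, so n samples give n(n-1)/2 entries. The sketch below makes that concrete. The helper names `condensed_length` and `condensed_index` are hypothetical and not part of BERTopic; the index formula assumes the row-by-row upper-triangle ordering that SciPy's `pdist` and `squareform` use.

```python
def condensed_length(n: int) -> int:
    """Number of entries in a condensed distance matrix for n samples (hypothetical helper)."""
    return n * (n - 1) // 2


def condensed_index(i: int, j: int, n: int) -> int:
    """Position of the distance between samples i and j (i != j) in the condensed vector.

    Assumes the upper triangle is listed row by row:
    (0, 1), (0, 2), ..., (0, n-1), (1, 2), ..., (n-2, n-1).
    """
    if i == j:
        raise ValueError("The diagonal is not stored in a condensed matrix.")
    if i > j:  # distances are symmetric, so (i, j) and (j, i) share one slot
        i, j = j, i
    return n * i - i * (i + 1) // 2 + (j - i - 1)


n = 4
print(condensed_length(n))  # 6
print([condensed_index(i, j, n) for i in range(n) for j in range(i + 1, n)])  # [0, 1, 2, 3, 4, 5]
```

The formula is equivalent to `n * i + j - (i + 1) * (i + 2) // 2`, the form given in the SciPy `pdist` documentation.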


```python
import numpy as np
from scipy.spatial.distance import squareform


def validate_distance_matrix(X, n_samples):
    """Validate the distance matrix and convert it to a condensed distance matrix
    if necessary.

    A valid distance matrix is either a square matrix of shape (n_samples, n_samples)
    with zeros on the diagonal and non-negative values, or a condensed distance matrix
    of shape (n_samples * (n_samples - 1) / 2,) containing the upper triangle of the
    distance matrix.

    Arguments:
        X: Distance matrix to validate.
        n_samples: Number of samples in the dataset.

    Returns:
        X: Validated distance matrix in condensed (1-D) form.

    Raises:
        ValueError: If the distance matrix is not valid.
    """
    s = X.shape
    if len(s) == 1:
        # 1-D input: must already be a condensed matrix with n*(n-1)/2 entries
        n = s[0]
        if n != n_samples * (n_samples - 1) // 2:
            raise ValueError("The condensed distance matrix must have shape (n*(n-1)/2,).")
    elif len(s) == 2:
        # 2-D input: must be a square (n, n) matrix
        if (s[0] != n_samples) or (s[1] != n_samples):
            raise ValueError("The distance matrix must be of shape (n, n) where n is the number of samples.")
        # Force a zero diagonal (this modifies X in place) and convert to condensed form;
        # squareform's default checks also reject non-symmetric matrices
        np.fill_diagonal(X, 0)
        X = squareform(X)
    else:
        raise ValueError(
            "The distance matrix must be either a 1-D condensed "
            "distance matrix of shape (n*(n-1)/2,) or a "
            "2-D square distance matrix of shape (n, n), "
            "where n is the number of documents. "
            "Got a distance matrix of shape %s" % str(s)
        )

    # Make sure its entries are non-negative
    if np.any(X < 0):
        raise ValueError("Distance matrix cannot contain negative values.")

    return X
```
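For experimenting without SciPy, the same checks can be approximated with NumPy alone. This is a minimal sketch, not BERTopic's implementation: `validate_distance_matrix_sketch` is a hypothetical name, `squareform` is swapped for `np.triu_indices` (which yields the same row-major upper-triangle order), an exact `X == X.T` comparison stands in for `squareform`'s default symmetry check, and the square input is copied instead of being modified in place.

```python
import numpy as np


def validate_distance_matrix_sketch(X, n_samples):
    """NumPy-only approximation of validate_distance_matrix (hypothetical helper)."""
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        # A condensed matrix stores one entry per unordered pair of samples
        if X.shape[0] != n_samples * (n_samples - 1) // 2:
            raise ValueError("The condensed distance matrix must have shape (n*(n-1)/2,).")
    elif X.ndim == 2:
        if X.shape != (n_samples, n_samples):
            raise ValueError("The distance matrix must be of shape (n, n) where n is the number of samples.")
        # Work on a copy so the caller's array keeps its original diagonal
        X = X.copy()
        np.fill_diagonal(X, 0)
        # Assumed stand-in for squareform's default symmetry check
        if not np.array_equal(X, X.T):
            raise ValueError("The square distance matrix must be symmetric.")
        # k=1 skips the diagonal; indices come out row by row, as in squareform
        X = X[np.triu_indices(n_samples, k=1)]
    else:
        raise ValueError(f"Expected a 1-D or 2-D distance matrix, got shape {X.shape}.")
    if np.any(X < 0):
        raise ValueError("Distance matrix cannot contain negative values.")
    return X


square = np.array([[0.0, 1.0, 2.0],
                   [1.0, 0.0, 3.0],
                   [2.0, 3.0, 0.0]])
print(validate_distance_matrix_sketch(square, 3))  # [1. 2. 3.]
```

Copying before `np.fill_diagonal` trades one extra n-by-n allocation for not surprising callers who reuse their matrix afterwards; the function shown above instead zeroes the caller's diagonal in place.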

Callers (3)

hierarchical_topics · Method · 0.90
visualize_hierarchy · Function · 0.90
_get_annotations · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected