| 66 | |
| 67 | |
class SemanticSearch:
    """Nearest-neighbour text search over Universal Sentence Encoder embeddings.

    ``fit`` embeds a corpus of texts and indexes the embeddings with
    ``NearestNeighbors``; calling the instance with a query string then
    returns the closest corpus entries (or their indices).
    """

    def __init__(self):
        # Encoder callable loaded from the TF Hub URL below; it is invoked
        # with a list of strings and returns one embedding row per string.
        self.use = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
        # Flipped to True by a successful ``fit``; checked before querying.
        self.fitted = False

    def fit(self, data, batch=1000, n_neighbors=5):
        """Embed *data* and build the nearest-neighbour index over it.

        Args:
            data: Sequence of texts; must support ``len`` and slicing.
            batch: Number of texts handed to the encoder per call.
            n_neighbors: Neighbours returned per query; clamped to the
                number of embedded texts so small corpora still work.

        Raises:
            ValueError: If *data* is empty or *batch* is not positive.
        """
        embeddings = self.get_text_embedding(data, batch=batch)
        index = NearestNeighbors(n_neighbors=min(n_neighbors, len(embeddings)))
        index.fit(embeddings)

        # Publish the new state only after every step succeeded, so a failed
        # (re)fit never leaves ``data`` paired with a stale index.
        self.data = data
        self.embeddings = embeddings
        self.nn = index
        self.fitted = True

    def __call__(self, text, return_data=True):
        """Return the corpus entries nearest to *text*.

        Args:
            text: Query string to embed and look up.
            return_data: When true, return the matching items from the fitted
                data; otherwise return their indices as produced by the index.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if not self.fitted:
            raise RuntimeError(
                'SemanticSearch is not fitted yet; call fit() before querying.'
            )

        inp_emb = self.use([text])
        # One query was submitted, so take the single result row.
        neighbors = self.nn.kneighbors(inp_emb, return_distance=False)[0]

        if return_data:
            return [self.data[i] for i in neighbors]
        return neighbors

    def get_text_embedding(self, texts, batch=1000):
        """Embed *texts* in chunks of *batch* and stack the results row-wise.

        Args:
            texts: Sequence of texts; must support ``len`` and slicing.
            batch: Maximum number of texts per encoder call.

        Returns:
            A 2-D array with one embedding row per input text.

        Raises:
            ValueError: If *texts* is empty or *batch* is not positive.
        """
        if batch < 1:
            raise ValueError(f'batch must be a positive integer, got {batch!r}')
        # ``len`` rather than truthiness: *texts* may be an array-like whose
        # truth value is ambiguous.
        if len(texts) == 0:
            raise ValueError('texts must contain at least one item to embed')

        chunks = [
            self.use(texts[start:start + batch])
            for start in range(0, len(texts), batch)
        ]
        return np.vstack(chunks)
| 98 | |
| 99 | |
| 100 | def load_recommender(path, start_page=1): |