Return one vector *per* span. Args: text: Full document text. chunk_spans: List of (char_start, char_end) offsets for each chunk. Returns: List of numpy float32 arrays – one per chunk.
(self, text: str, chunk_spans: List[Tuple[int, int]])
| 41 | |
@torch.inference_mode()
def encode(self, text: str, chunk_spans: List[Tuple[int, int]]) -> List[np.ndarray]:
    """Return one pooled embedding vector *per* character span.

    The whole document is tokenised and run through the model once; each
    chunk vector is the mean of the last-layer hidden states of the tokens
    that lie entirely inside that chunk's character span.

    Args:
        text: Full document text.
        chunk_spans: List of (char_start, char_end) offsets for each chunk.

    Returns:
        List of numpy float32 arrays – one per chunk, in the order of
        ``chunk_spans``. An empty list when ``chunk_spans`` is empty.

    Notes:
        * The input is truncated to ``self.max_len`` tokens, so spans past
          the truncation point contain no tokens and use the fallback below.
        * A token that straddles a span boundary is not counted for either
          neighbouring span (strict containment).
    """
    if not chunk_spans:
        return []

    # Tokenise once for the whole document, asking for character offsets.
    encoded = self.tokenizer(
        text,
        return_tensors="pt",
        return_offsets_mapping=True,
        truncation=True,
        max_length=self.max_len,
    )

    # Offsets are only used for Python-side bookkeeping, so read them before
    # the device transfer instead of moving them to the device and back.
    offsets = encoded["offset_mapping"].squeeze(0).tolist()  # (seq_len, 2)
    inputs = {
        name: tensor.to(self.device)
        for name, tensor in encoded.items()
        if name != "offset_mapping"
    }

    out = self.model(**inputs)
    # Pool in float32: averaging in half precision loses accuracy, and
    # bfloat16 tensors cannot be converted to numpy at all.
    last_hidden = out.last_hidden_state.squeeze(0).cpu().float()  # (seq_len, dim)

    # Tokens with an empty offset (start == end) map to no text. Special
    # tokens are reported this way as (0, 0), which would otherwise satisfy
    # the containment test of every span that starts at character 0.
    content_tokens = [
        (idx, tok_start, tok_end)
        for idx, (tok_start, tok_end) in enumerate(offsets)
        if tok_end > tok_start
    ]

    vectors: List[np.ndarray] = []
    for start_char, end_char in chunk_spans:
        token_indices = [
            idx
            for idx, tok_start, tok_end in content_tokens
            if tok_start >= start_char and tok_end <= end_char
        ]
        if not token_indices:
            # Fallback: no token lies inside the span (e.g. it was cut off by
            # truncation). Use the hidden state at position 0 – presumably
            # the [CLS]/BOS token; confirm for the tokenizer in use.
            token_indices = [0]
        chunk_vec = last_hidden[token_indices].mean(dim=0).numpy().astype(np.float32)

        # Never hand NaN/inf values to downstream consumers.
        if not np.isfinite(chunk_vec).all():
            print(f"⚠️ Warning: Invalid values detected in late chunk embedding for span ({start_char}, {end_char})")
            # Replace invalid values with zeros
            chunk_vec = np.nan_to_num(chunk_vec, nan=0.0, posinf=0.0, neginf=0.0)
            print("🔄 Replaced invalid values with zeros")

        vectors.append(chunk_vec)
    return vectors
no outgoing calls
no test coverage detected