Initialize the Feast RAG retriever.

Args:

- question_encoder_tokenizer: Tokenizer for encoding questions
- question_encoder: Model for encoding questions
- generator_tokenizer: Tokenizer for the generator model
- generator_model: The generator model
(
self,
question_encoder_tokenizer: PreTrainedTokenizer,
question_encoder: PreTrainedModel,
generator_tokenizer: PreTrainedTokenizer,
generator_model: PreTrainedModel,
feast_repo_path: str,
feature_view: FeatureView,
features: list[str],
search_type: str,
config: dict[str, Any],
index: FeastIndex,
format_document: Optional[Callable[[dict[str, Any]], str]] = None,
id_field: str = "",
text_field: str = "",
**kwargs,
)
| 39 | VALID_SEARCH_TYPES = {"text", "vector", "hybrid"} |
| 40 | |
| 41 | def __init__( |
| 42 | self, |
| 43 | question_encoder_tokenizer: PreTrainedTokenizer, |
| 44 | question_encoder: PreTrainedModel, |
| 45 | generator_tokenizer: PreTrainedTokenizer, |
| 46 | generator_model: PreTrainedModel, |
| 47 | feast_repo_path: str, |
| 48 | feature_view: FeatureView, |
| 49 | features: list[str], |
| 50 | search_type: str, |
| 51 | config: dict[str, Any], |
| 52 | index: FeastIndex, |
| 53 | format_document: Optional[Callable[[dict[str, Any]], str]] = None, |
| 54 | id_field: str = "", |
| 55 | text_field: str = "", |
| 56 | **kwargs, |
| 57 | ): |
| 58 | """Initialize the Feast RAG retriever. |
| 59 | |
| 60 | Args: |
| 61 | question_encoder_tokenizer: Tokenizer for encoding questions |
| 62 | question_encoder: Model for encoding questions |
| 63 | generator_tokenizer: Tokenizer for the generator model |
| 64 | generator_model: The generator model |
| 65 | feast_repo_path: Path to the Feast repository |
| 66 | feature_view: Feast FeatureView containing the document data |
| 67 | features: List of feature names to use from the feature view |
| 68 | search_type: Type of search to perform (text, vector, or hybrid) |
| 69 | config: Configuration for the retriever |
| 70 | index: Index instance (must be FeastIndex) |
| 71 | format_document: Optional function to format retrieved documents |
| 72 | id_field: Field to use as document ID |
| 73 | text_field: Field to use as text field name |
| 74 | **kwargs: Additional arguments passed to RagRetriever |
| 75 | """ |
| 76 | if search_type.lower() not in self.VALID_SEARCH_TYPES: |
| 77 | raise ValueError( |
| 78 | f"Unsupported search_type {search_type}. " |
| 79 | f"Must be one of: {self.VALID_SEARCH_TYPES}" |
| 80 | ) |
| 81 | |
| 82 | # move to gpu if available |
| 83 | torch = get_torch() |
| 84 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 85 | self.question_encoder = question_encoder.to(self.device) # type: ignore |
| 86 | self.generator_model = generator_model.to(self.device) # type: ignore |
| 87 | |
| 88 | self.question_encoder_tokenizer = question_encoder_tokenizer |
| 89 | self.generator_tokenizer = generator_tokenizer |
| 90 | |
| 91 | super().__init__( |
| 92 | config=config, |
| 93 | question_encoder_tokenizer=self.question_encoder_tokenizer, |
| 94 | generator_tokenizer=self.generator_tokenizer, |
| 95 | index=index, |
| 96 | init_retrieval=False, |
| 97 | **kwargs, |
| 98 | ) |