class TFIDFRetriever(BaseRetriever):
    """`TF-IDF` retriever.

    Largely based on
    https://github.com/asvskartheek/Text-Retrieval/blob/master/TF-IDF%20Search%20Engine%20(SKLEARN).ipynb
    """

    vectorizer: Any = None
    """TF-IDF vectorizer."""
    docs: List[Document]
    """Documents."""
    tfidf_array: Any = None
    """TF-IDF array."""
    k: int = 4
    """Number of documents to return."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    @classmethod
    def from_texts(
        cls,
        texts: Iterable[str],
        metadatas: Optional[Iterable[dict]] = None,
        tfidf_params: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> TFIDFRetriever:
        """Build a retriever by fitting a TF-IDF vectorizer on ``texts``.

        Args:
            texts: Texts to index. Any iterable is accepted; it is materialized
                into a list because it is read twice (once to fit the
                vectorizer, once to build the documents).
            metadatas: Optional per-text metadata, paired with ``texts`` by
                position. Defaults to an empty dict per text. Pairing uses
                ``zip``, so surplus items on either side are dropped.
            tfidf_params: Keyword arguments forwarded to scikit-learn's
                ``TfidfVectorizer``.
            **kwargs: Extra fields forwarded to the retriever constructor
                (e.g. ``k``).

        Returns:
            A retriever holding the fitted vectorizer, the TF-IDF matrix it
            produced for ``texts``, and one ``Document`` per text.

        Raises:
            ImportError: If scikit-learn is not installed.
        """
        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
        except ImportError as e:
            raise ImportError(
                "Could not import scikit-learn, please install with `pip install "
                "scikit-learn`."
            ) from e

        # Materialize once: a one-shot iterator (e.g. a generator) would be
        # exhausted by fit_transform and leave zip() below with nothing to pair.
        texts = list(texts)
        tfidf_params = tfidf_params or {}
        vectorizer = TfidfVectorizer(**tfidf_params)
        # One row per text, in the same order as ``texts``/``docs``.
        tfidf_array = vectorizer.fit_transform(texts)
        metadatas = metadatas or ({} for _ in texts)
        docs = [Document(page_content=t, metadata=m) for t, m in zip(texts, metadatas)]
        return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Return up to ``k`` stored documents most similar to ``query``.

        Similarity is the cosine similarity between the query's TF-IDF vector
        and each row of ``tfidf_array``. Results are ordered most similar
        first. An empty list is returned when ``k`` is not positive.
        """
        # Guard: with k == 0 the slice ``[-0:]`` below would select *every*
        # index instead of none.
        if self.k <= 0:
            return []

        from sklearn.metrics.pairwise import cosine_similarity

        # transform() on a one-element list yields a single row: (1, n_features).
        query_vec = self.vectorizer.transform([query])
        # (n_docs, n_features) vs (1, n_features) -> (n_docs, 1), flattened to
        # (n_docs,): the cosine similarity of the query with each document.
        results = cosine_similarity(self.tfidf_array, query_vec).reshape((-1,))
        # argsort is ascending, so the last k indices are the best matches;
        # reverse them so the most similar document comes first.
        return_docs = [self.docs[i] for i in results.argsort()[-self.k :][::-1]]
        return return_docs

    def save_local(
        self,
        folder_path: str,
        file_name: str = "tfidf_vectorizer",
    ) -> None:
        """Persist the retriever under ``folder_path``.

        Writes ``{file_name}.joblib`` (the vectorizer, via joblib) and
        ``{file_name}.pkl`` (a pickled ``(docs, tfidf_array)`` tuple), creating
        the folder if needed. ``k`` is not saved.

        Args:
            folder_path: Directory to write into.
            file_name: Base name (without extension) for both files.

        Raises:
            ImportError: If joblib is not installed.
        """
        try:
            import joblib
        except ImportError as e:
            raise ImportError(
                "Could not import joblib, please install with `pip install joblib`."
            ) from e

        path = Path(folder_path)
        path.mkdir(exist_ok=True, parents=True)

        # Save vectorizer with joblib dump.
        joblib.dump(self.vectorizer, path / f"{file_name}.joblib")

        # Save docs and tfidf array as pickle.
        with open(path / f"{file_name}.pkl", "wb") as f:
            pickle.dump((self.docs, self.tfidf_array), f)

    @classmethod
    def load_local(
        cls,
        folder_path: str,
        *,
        allow_dangerous_deserialization: bool = False,
        file_name: str = "tfidf_vectorizer",
        **kwargs: Any,
    ) -> TFIDFRetriever:
        """Load the retriever from local storage.

        Args:
            folder_path: Folder path to load from.
            allow_dangerous_deserialization: Whether to allow dangerous
                deserialization. Defaults to False.
                The deserialization relies on .joblib and .pkl files, which can
                be modified to deliver a malicious payload that results in
                execution of arbitrary code on your machine. You will need to
                set this to `True` to use deserialization. If you do this, make
                sure you trust the source of the file.
            file_name: File name to load from. Defaults to "tfidf_vectorizer".
            **kwargs: Extra fields forwarded to the retriever constructor
                (e.g. ``k``, which ``save_local`` does not persist).

        Returns:
            TFIDFRetriever: Loaded retriever.

        Raises:
            ImportError: If joblib is not installed.
            ValueError: If ``allow_dangerous_deserialization`` is not True.
        """
        try:
            import joblib
        except ImportError as e:
            raise ImportError(
                "Could not import joblib, please install with `pip install joblib`."
            ) from e

        if not allow_dangerous_deserialization:
            raise ValueError(
                "The de-serialization of this retriever is based on .joblib and "
                ".pkl files. "
                "Such files can be modified to deliver a malicious payload that "
                "results in execution of arbitrary code on your machine. "
                "You will need to set `allow_dangerous_deserialization` to `True` to "
                "load this retriever. If you do this, make sure you trust the source "
                "of the file, and you are responsible for validating the file "
                "came from a trusted source."
            )

        path = Path(folder_path)

        # Load vectorizer with joblib load.
        vectorizer = joblib.load(path / f"{file_name}.joblib")

        # Load docs and tfidf array as pickle.
        with open(path / f"{file_name}.pkl", "rb") as f:
            # This code path can only be triggered if the user
            # passed allow_dangerous_deserialization=True
            docs, tfidf_array = pickle.load(f)  # ignore[pickle]: explicit-opt-in

        return cls(vectorizer=vectorizer, docs=docs, tfidf_array=tfidf_array, **kwargs)