def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
    """Embed every context concurrently and stack the vectors into an array.

    Args:
        contexts: Texts to embed, in the order they should appear in the index.
        embeddings: Model whose ``embed_query`` is applied to each context.

    Returns:
        A NumPy array built from the per-context embeddings, one entry per
        context, in the same order as ``contexts``.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # ``Executor.map`` yields results in input order, so entry i of the
        # result always corresponds to contexts[i] even though calls overlap.
        vectors = list(pool.map(embeddings.embed_query, contexts))
    return np.array(vectors)
class SVMRetriever(BaseRetriever):
    """`SVM` retriever.

    Largely based on
    https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb
    """

    embeddings: Embeddings
    """Embeddings model to use."""
    index: Any
    """Index of embeddings."""
    texts: List[str]
    """List of texts to index."""
    metadatas: Optional[List[dict]] = None
    """List of metadatas corresponding with each text."""
    k: int = 4
    """Number of results to return."""
    relevancy_threshold: Optional[float] = None
    """Threshold for relevancy."""

    class Config:
        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Rank the indexed texts against ``query`` with a linear SVM.

        A ``LinearSVC`` is fit with the query embedding as the only positive
        sample and every indexed embedding as a negative sample. The decision
        function is then used as a similarity score for each indexed text.

        Args:
            query: Text to search for.
            run_manager: Callback manager for the retriever run (unused here).

        Returns:
            Up to ``self.k`` documents ordered by descending SVM score. If
            ``self.relevancy_threshold`` is set, a document is kept only when
            its min-max normalized score (normalized over the query and all
            indexed rows) is at least the threshold. Returns ``[]`` when the
            index is empty.

        Raises:
            ImportError: If scikit-learn is not installed.
        """
        try:
            from sklearn import svm
        except ImportError as e:
            raise ImportError(
                "Could not import scikit-learn, please install with `pip install "
                "scikit-learn`."
            ) from e

        # With nothing indexed, concatenation would fail on mismatched shapes
        # and the SVM would see a single class; there is nothing to return.
        if len(self.index) == 0:
            return []

        query_embeds = np.array(self.embeddings.embed_query(query))
        # Row 0 of ``x`` is the query (label 1); rows 1..n are the indexed texts.
        x = np.concatenate([query_embeds[None, ...], self.index])
        y = np.zeros(x.shape[0])
        y[0] = 1

        clf = svm.LinearSVC(
            class_weight="balanced", verbose=False, max_iter=10000, tol=1e-6, C=0.1
        )
        clf.fit(x, y)

        similarities = clf.decision_function(x)
        sorted_ix = np.argsort(-similarities)
        # svm.LinearSVC is non-deterministic: if a text equals the query, the
        # query row is not guaranteed to rank first. Drop the query row from
        # the ranking (rather than swapping it to the front) so the remaining
        # texts keep exactly the order the SVM assigned them.
        sorted_ix = sorted_ix[sorted_ix != 0]

        # Min-max normalize to [0, 1]; the epsilon avoids division by zero
        # when every score is identical.
        denominator = np.max(similarities) - np.min(similarities) + 1e-6
        normalized_similarities = (similarities - np.min(similarities)) / denominator

        top_k_results = []
        for row in sorted_ix[: self.k]:
            if (
                self.relevancy_threshold is None
                or normalized_similarities[row] >= self.relevancy_threshold
            ):
                # ``row`` indexes ``x``; subtract 1 to skip the query row.
                metadata = self.metadatas[row - 1] if self.metadatas else {}
                doc = Document(page_content=self.texts[row - 1], metadata=metadata)
                top_k_results.append(doc)
        return top_k_results