# Source code for langchain_community.document_compressors.rankllm_rerank
from __future__ import annotations

from copy import deepcopy
from enum import Enum
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.utils import get_from_dict_or_env
from packaging.version import Version
from pydantic import ConfigDict, Field, PrivateAttr, model_validator

if TYPE_CHECKING:
    from rank_llm.data import Candidate, Query, Request
else:
    # Avoid pydantic annotation issues when actually instantiating
    # while keeping this import optional
    try:
        from rank_llm.data import Candidate, Query, Request
    except ImportError:
        pass
class RankLLMRerank(BaseDocumentCompressor):
    """Document compressor using the RankLLM reranking interface.

    Reranks documents against a query with one of RankLLM's listwise
    rerankers (``vicuna``, ``zephyr``) or an OpenAI GPT model, and returns
    the top ``top_n`` results.
    """

    client: Any = None
    """RankLLM client to use for compressing documents"""
    top_n: int = Field(default=3)
    """Top N documents to return."""
    model: str = Field(default="zephyr")
    """Name of model to use for reranking."""
    step_size: int = Field(default=10)
    """Step size for moving sliding window."""
    gpt_model: str = Field(default="gpt-3.5-turbo")
    """OpenAI model name."""
    _retriever: Any = PrivateAttr()

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    @model_validator(mode="before")
    @classmethod
    def validate_environment(cls, values: Dict) -> Any:
        """Validate python package exists in environment.

        Builds a RankLLM client from the raw input ``values`` when the
        caller did not supply one.

        Raises:
            ValueError: If ``model`` is not one of 'vicuna', 'zephyr', 'gpt'.
            ImportError: If the ``rank_llm`` package is not installed.
        """
        if not values.get("client"):
            client_name = values.get("model", "zephyr")
            # rank_llm moved the listwise rerankers into a subpackage after
            # release 0.12.8; pick the matching import paths below.
            is_pre_rank_llm_revamp = Version(version=version("rank_llm")) <= Version(
                "0.12.8"
            )

            try:
                # ModelType is defined elsewhere in this module (not visible here).
                model_enum = ModelType(client_name.lower())
            except ValueError:
                raise ValueError(
                    "Unsupported model type. Please use 'vicuna', 'zephyr', or 'gpt'."
                )

            try:
                if model_enum == ModelType.VICUNA:
                    if is_pre_rank_llm_revamp:
                        from rank_llm.rerank.vicuna_reranker import VicunaReranker
                    else:
                        from rank_llm.rerank.listwise.vicuna_reranker import (
                            VicunaReranker,
                        )

                    values["client"] = VicunaReranker()
                elif model_enum == ModelType.ZEPHYR:
                    if is_pre_rank_llm_revamp:
                        from rank_llm.rerank.zephyr_reranker import ZephyrReranker
                    else:
                        from rank_llm.rerank.listwise.zephyr_reranker import (
                            ZephyrReranker,
                        )

                    values["client"] = ZephyrReranker()
                elif model_enum == ModelType.GPT:
                    if is_pre_rank_llm_revamp:
                        from rank_llm.rerank.rank_gpt import SafeOpenai
                    else:
                        from rank_llm.rerank.listwise.rank_gpt import SafeOpenai
                    from rank_llm.rerank.reranker import Reranker

                    # NOTE: the dict key is historically "open_api_key" (sic);
                    # kept for backward compatibility with existing callers.
                    openai_api_key = get_from_dict_or_env(
                        values, "open_api_key", "OPENAI_API_KEY"
                    )
                    # This is a mode="before" validator, so pydantic defaults
                    # are not applied yet: values["gpt_model"] would raise
                    # KeyError when the caller omits it. Fall back explicitly.
                    agent = SafeOpenai(
                        model=values.get("gpt_model", "gpt-3.5-turbo"),
                        context_size=4096,
                        keys=openai_api_key,
                    )
                    values["client"] = Reranker(agent)
            except ImportError:
                raise ImportError(
                    "Could not import rank_llm python package. "
                    "Please install it with `pip install rank_llm`."
                )
        return values

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """Rerank ``documents`` against ``query`` and return the top ``top_n``.

        Args:
            documents: Documents to rerank.
            query: Query string used for reranking.
            callbacks: Callbacks to run during compression (unused here).

        Returns:
            At most ``top_n`` documents, best-ranked first, with copied metadata.
        """
        request = Request(
            query=Query(text=query, qid=1),
            candidates=[
                Candidate(doc={"text": doc.page_content}, docid=index, score=1)
                for index, doc in enumerate(documents)
            ],
        )

        rerank_results = self.client.rerank(
            request,
            rank_end=len(documents),
            window_size=min(20, len(documents)),
            # Was hard-coded to 10, silently ignoring the step_size field.
            step=self.step_size,
        )

        # Older rank_llm versions wrap the ranked list in an object with a
        # ``candidates`` attribute; newer ones return the list directly.
        candidates = (
            rerank_results.candidates
            if hasattr(rerank_results, "candidates")
            else rerank_results
        )

        final_results = []
        for res in candidates:
            doc = documents[int(res.docid)]
            final_results.append(
                Document(doc.page_content, metadata=deepcopy(doc.metadata))
            )
        return final_results[: self.top_n]