Source code for langchain_community.document_compressors.rankllm_rerank
from__future__importannotationsfromcopyimportdeepcopyfromenumimportEnumfromtypingimportTYPE_CHECKING,Any,Dict,Optional,Sequencefromlangchain.retrievers.document_compressors.baseimportBaseDocumentCompressorfromlangchain_core.callbacks.managerimportCallbacksfromlangchain_core.documentsimportDocumentfromlangchain_core.pydantic_v1importField,PrivateAttr,root_validatorfromlangchain_core.utilsimportget_from_dict_or_envifTYPE_CHECKING:fromrank_llm.dataimportCandidate,Query,Requestelse:# Avoid pydantic annotation issues when actually instantiating# while keeping this import optionaltry:fromrank_llm.dataimportCandidate,Query,RequestexceptImportError:pass
class RankLLMRerank(BaseDocumentCompressor):
    """Document compressor using the RankLLM reranking interface.

    Unless an explicit ``client`` is supplied, a RankLLM reranker is built
    from ``model`` ('vicuna', 'zephyr' or 'gpt') when the instance is created.
    """

    client: Any = None
    """RankLLM client to use for compressing documents"""
    top_n: int = Field(default=3)
    """Top N documents to return."""
    model: str = Field(default="zephyr")
    """Name of model to use for reranking."""
    step_size: int = Field(default=10)
    """Step size for moving sliding window."""
    gpt_model: str = Field(default="gpt-3.5-turbo")
    """OpenAI model name."""
    _retriever: Any = PrivateAttr()

    class Config:
        arbitrary_types_allowed = True
        extra = "forbid"

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Build the RankLLM client from ``model`` unless one was supplied.

        This validator runs with ``pre=True``, i.e. before field validation,
        so ``values`` contains only the raw constructor arguments. Field
        defaults (``model``, ``gpt_model``, ...) have NOT been filled in yet
        and must be defaulted explicitly here.

        Args:
            values: Raw keyword arguments passed to the constructor.

        Returns:
            ``values``, with ``"client"`` populated when it was missing/falsy.

        Raises:
            ValueError: If ``model`` does not name a supported ``ModelType``.
            ImportError: If the ``rank_llm`` package (or the reranker module
                for the selected model) cannot be imported.
        """
        # A caller-provided client wins; nothing to build.
        if values.get("client"):
            return values

        client_name = values.get("model", "zephyr")
        try:
            model_enum = ModelType(client_name.lower())
        except ValueError:
            raise ValueError(
                "Unsupported model type. Please use 'vicuna', 'zephyr', or 'gpt'."
            ) from None

        try:
            # rank_llm is an optional dependency, so each backend is imported
            # lazily, only when that model is actually requested.
            if model_enum == ModelType.VICUNA:
                from rank_llm.rerank.vicuna_reranker import VicunaReranker

                values["client"] = VicunaReranker()
            elif model_enum == ModelType.ZEPHYR:
                from rank_llm.rerank.zephyr_reranker import ZephyrReranker

                values["client"] = ZephyrReranker()
            elif model_enum == ModelType.GPT:
                from rank_llm.rerank.rank_gpt import SafeOpenai
                from rank_llm.rerank.reranker import Reranker

                # Looks up the "open_api_key" kwarg first, then the
                # OPENAI_API_KEY environment variable.
                # NOTE(review): with extra="forbid" and no matching field, an
                # "open_api_key" constructor kwarg is probably rejected after
                # this validator returns, leaving only the env var usable —
                # confirm and consider adding an explicit field.
                openai_api_key = get_from_dict_or_env(
                    values, "open_api_key", "OPENAI_API_KEY"
                )
                agent = SafeOpenai(
                    # Field defaults are not applied yet in a pre-validator;
                    # indexing values["gpt_model"] raised KeyError whenever the
                    # caller relied on the default.
                    model=values.get("gpt_model", "gpt-3.5-turbo"),
                    context_size=4096,
                    keys=openai_api_key,
                )
                values["client"] = Reranker(agent)
        except ImportError as err:
            raise ImportError(
                "Could not import rank_llm python package. "
                "Please install it with `pip install rank_llm`."
            ) from err

        return values