class SearchQueries(BaseModel):
    """Search queries to research for the user's goal."""

    queries: List[str] = Field(
        ..., description="List of search queries to look up on Google"
    )
DEFAULT_LLAMA_SEARCH_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""<<SYS>> \n You are an assistant tasked with improving Google search \
results. \n <</SYS>> \n\n [INST] Generate THREE Google search queries that \
are similar to this question. The output should be a numbered list of questions \
and each should have a question mark at the end: \n\n{question} [/INST]""",
)

DEFAULT_SEARCH_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""You are an assistant tasked with improving Google search \
results. Generate THREE Google search queries that are similar to \
this question. The output should be a numbered list of questions and each \
should have a question mark at the end: {question}""",
)
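# A quick sketch of how the default prompt renders; PromptTemplate.format is
# standard API, and the question string is made up for illustration:
#
#     DEFAULT_SEARCH_PROMPT.format(question="How do I cache LLM calls?")
#     # -> "You are an assistant tasked with improving Google search results.
#     #     Generate THREE Google search queries that are similar to this
#     #     question. The output should be a numbered list of questions and
#     #     each should have a question mark at the end: How do I cache LLM calls?"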
class QuestionListOutputParser(BaseOutputParser[List[str]]):
    """Output parser for a list of numbered questions."""

    def parse(self, text: str) -> List[str]:
        # Pull out each numbered line ("1. ...", "2. ...") up to its newline
        # or the end of the text.
        lines = re.findall(r"\d+\..*?(?:\n|$)", text)
        return lines
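# A minimal sketch of what the parser yields; the sample text stands in for
# real model output, and each match keeps its numbering and trailing newline:
#
#     parser = QuestionListOutputParser()
#     parser.parse("1. What is LangChain?\n2. Who maintains LangChain?\n")
#     # -> ['1. What is LangChain?\n', '2. Who maintains LangChain?\n']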
class WebResearchRetriever(BaseRetriever):
    """`Google Search API` retriever."""

    # Inputs
    vectorstore: VectorStore = Field(
        ..., description="Vector store for storing web pages"
    )
    llm_chain: LLMChain
    search: GoogleSearchAPIWrapper = Field(..., description="Google Search API Wrapper")
    num_search_results: int = Field(1, description="Number of pages per Google search")
    text_splitter: TextSplitter = Field(
        RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=50),
        description="Text splitter for splitting web pages into chunks",
    )
    url_database: List[str] = Field(
        default_factory=list, description="List of processed URLs"
    )
    trust_env: bool = Field(
        False,
        description="Whether to use the http_proxy/https_proxy env variables or "
        "check .netrc for proxy configuration",
    )
    allow_dangerous_requests: bool = False
    """A flag to force users to acknowledge the risks of SSRF attacks when using
    this retriever.

    Users should set this flag to `True` if they have taken the necessary
    precautions to prevent SSRF attacks when using this retriever.

    For example, users can run the requests through a properly configured
    proxy and prevent the crawler from accidentally crawling internal
    resources.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the retriever."""
        allow_dangerous_requests = kwargs.get("allow_dangerous_requests", False)
        if not allow_dangerous_requests:
            raise ValueError(
                "WebResearchRetriever crawls URLs surfaced through "
                "the provided search engine. It is possible that some of those URLs "
                "will end up pointing to machines residing on an internal network, "
                "leading to an SSRF (Server-Side Request Forgery) attack. "
                "To protect yourself against that risk, you can run the requests "
                "through a proxy and prevent the crawler from accidentally crawling "
                "internal resources. "
                "If you've taken the necessary precautions, you can set "
                "`allow_dangerous_requests` to `True`."
            )
        super().__init__(**kwargs)
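    # The guard above is easy to trip over: every construction path, including
    # from_llm below, must explicitly opt in. A hedged sketch (vectorstore and
    # llm_chain are assumed to already exist):
    #
    #     WebResearchRetriever(
    #         vectorstore=vectorstore, llm_chain=llm_chain,
    #         search=GoogleSearchAPIWrapper(),
    #     )                                   # raises ValueError: flag defaults to False
    #
    #     WebResearchRetriever(
    #         vectorstore=vectorstore, llm_chain=llm_chain,
    #         search=GoogleSearchAPIWrapper(),
    #         allow_dangerous_requests=True,  # opt in after mitigating SSRF risk
    #     )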
    @classmethod
    def from_llm(
        cls,
        vectorstore: VectorStore,
        llm: BaseLLM,
        search: GoogleSearchAPIWrapper,
        prompt: Optional[BasePromptTemplate] = None,
        num_search_results: int = 1,
        text_splitter: RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter(
            chunk_size=1500, chunk_overlap=150
        ),
        trust_env: bool = False,
        allow_dangerous_requests: bool = False,
    ) -> "WebResearchRetriever":
        """Initialize from llm using default template.

        Args:
            vectorstore: Vector store for storing web pages
            llm: llm for search question generation
            search: GoogleSearchAPIWrapper
            prompt: prompt for generating search questions
            num_search_results: Number of pages per Google search
            text_splitter: Text splitter for splitting web pages into chunks
            trust_env: Whether to use the http_proxy/https_proxy env variables
                or check .netrc for proxy configuration
            allow_dangerous_requests: A flag to force users to acknowledge the
                risks of SSRF attacks when using this retriever

        Returns:
            WebResearchRetriever
        """
        if not prompt:
            QUESTION_PROMPT_SELECTOR = ConditionalPromptSelector(
                default_prompt=DEFAULT_SEARCH_PROMPT,
                conditionals=[
                    (lambda llm: isinstance(llm, LlamaCpp), DEFAULT_LLAMA_SEARCH_PROMPT)
                ],
            )
            prompt = QUESTION_PROMPT_SELECTOR.get_prompt(llm)

        # Use chat model prompt
        llm_chain = LLMChain(
            llm=llm,
            prompt=prompt,
            output_parser=QuestionListOutputParser(),
        )
        return cls(
            vectorstore=vectorstore,
            llm_chain=llm_chain,
            search=search,
            num_search_results=num_search_results,
            text_splitter=text_splitter,
            trust_env=trust_env,
            allow_dangerous_requests=allow_dangerous_requests,
        )
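    # Note on the prompt selection in from_llm: when no prompt is supplied,
    # DEFAULT_LLAMA_SEARCH_PROMPT is chosen only for LlamaCpp instances; every
    # other llm gets DEFAULT_SEARCH_PROMPT. A hedged sketch (the model path is
    # a placeholder, vectorstore is assumed to exist):
    #
    #     retriever = WebResearchRetriever.from_llm(
    #         vectorstore=vectorstore,
    #         llm=LlamaCpp(model_path="..."),  # routes to the [INST]-style prompt
    #         search=GoogleSearchAPIWrapper(),
    #         allow_dangerous_requests=True,
    #     )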
    def clean_search_query(self, query: str) -> str:
        # Some search tools (e.g., Google) will
        # fail to return results if query has a
        # leading digit: 1. "LangCh..."
        # Check if the first character is a digit
        if query[0].isdigit():
            # Find the position of the first quote
            first_quote_pos = query.find('"')
            if first_quote_pos != -1:
                # Extract the part of the string after the quote
                query = query[first_quote_pos + 1 :]
                # Remove the trailing quote if present
                if query.endswith('"'):
                    query = query[:-1]
        return query.strip()
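    # Behavior sketch for clean_search_query: the leading numbering is only
    # stripped when the query is quoted; otherwise the string is merely
    # whitespace-stripped (example inputs are illustrative):
    #
    #     retriever.clean_search_query('1. "How does LangChain work?"')
    #     # -> 'How does LangChain work?'
    #     retriever.clean_search_query('1. How does LangChain work?')
    #     # -> '1. How does LangChain work?'   (no quotes, numbering kept)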
    def search_tool(self, query: str, num_search_results: int = 1) -> List[dict]:
        """Returns num_search_results pages per Google search."""
        query_clean = self.clean_search_query(query)
        result = self.search.results(query_clean, num_search_results)
        return result
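    # Sketch of a search_tool call, assuming a configured retriever (see the
    # example after the class). GoogleSearchAPIWrapper.results returns dicts
    # with "title", "link", and "snippet" keys:
    #
    #     for res in retriever.search_tool('2. "What is a vector store?"', 2):
    #         print(res["link"])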
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Search Google for documents related to the query input.

        Args:
            query: user query

        Returns:
            Relevant documents from all the various urls.
        """

        # Get search questions
        logger.info("Generating questions for Google Search ...")
        result = self.llm_chain({"question": query})
        logger.info(f"Questions for Google Search (raw): {result}")
        questions = result["text"]
        logger.info(f"Questions for Google Search: {questions}")

        # Get urls
        logger.info("Searching for relevant urls...")
        urls_to_look = []
        for query in questions:
            # Google search
            search_results = self.search_tool(query, self.num_search_results)
            logger.info("Searching for relevant urls...")
            logger.info(f"Search results: {search_results}")
            for res in search_results:
                if res.get("link", None):
                    urls_to_look.append(res["link"])

        # Relevant urls
        urls = set(urls_to_look)

        # Check for any new urls that we have not processed
        new_urls = list(urls.difference(self.url_database))

        logger.info(f"New URLs to load: {new_urls}")
        # Load, split, and add new urls to vectorstore
        if new_urls:
            loader = AsyncHtmlLoader(
                new_urls, ignore_load_errors=True, trust_env=self.trust_env
            )
            html2text = Html2TextTransformer()
            logger.info("Indexing new urls...")
            docs = loader.load()
            docs = list(html2text.transform_documents(docs))
            docs = self.text_splitter.split_documents(docs)
            self.vectorstore.add_documents(docs)
            self.url_database.extend(new_urls)

        # Search for relevant splits
        # TODO: make this async
        logger.info("Grabbing most relevant splits from urls...")
        docs = []
        for query in questions:
            docs.extend(self.vectorstore.similarity_search(query))

        # Get unique docs
        unique_documents_dict = {
            (doc.page_content, tuple(sorted(doc.metadata.items()))): doc
            for doc in docs
        }
        unique_documents = list(unique_documents_dict.values())
        return unique_documents

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        raise NotImplementedError
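# A minimal end-to-end sketch (not part of the library source): build the
# retriever via from_llm and run a query through the Runnable invoke() entry
# point, which routes to _get_relevant_documents. Assumes OpenAI credentials
# plus GOOGLE_API_KEY and GOOGLE_CSE_ID in the environment; any VectorStore
# and BaseLLM pair would work in place of Chroma and OpenAI.
if __name__ == "__main__":
    from langchain_chroma import Chroma
    from langchain_openai import OpenAI, OpenAIEmbeddings

    vectorstore = Chroma(
        embedding_function=OpenAIEmbeddings(), persist_directory="./chroma_db_oai"
    )
    retriever = WebResearchRetriever.from_llm(
        vectorstore=vectorstore,
        llm=OpenAI(temperature=0),
        search=GoogleSearchAPIWrapper(),
        allow_dangerous_requests=True,  # explicit opt-in; see __init__ above
    )
    docs = retriever.invoke("How do LLM-powered autonomous agents plan tasks?")
    for doc in docs:
        # Each Document carries the source URL in its metadata.
        print(doc.metadata.get("source"), len(doc.page_content))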