from __future__ import annotations

import logging
import re
from typing import Any, Dict, List, Optional, Sequence, Tuple

from langchain_core.callbacks import (
    CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from pydantic import Field

from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
    PROMPT,
    QUESTION_GENERATOR_PROMPT,
    FinishedOutputParser,
)
from langchain.chains.llm import LLMChain

logger = logging.getLogger(__name__)


def _extract_tokens_and_log_probs(
    response: AIMessage,
) -> Tuple[List[str], List[float]]:
    """Extract tokens and log probabilities from chat model response."""
    tokens = []
    log_probs = []
    for token in response.response_metadata["logprobs"]["content"]:
        tokens.append(token["token"])
        log_probs.append(token["logprob"])
    return tokens, log_probs
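# Example (a sketch, not part of the library): this helper assumes the chat
# model was invoked with logprobs enabled, so `response_metadata` carries a
# "logprobs" entry shaped like the one below (values are made up):
#
#     msg = AIMessage(
#         content="Paris",
#         response_metadata={
#             "logprobs": {
#                 "content": [
#                     {"token": "Par", "logprob": -0.01},
#                     {"token": "is", "logprob": -0.02},
#                 ]
#             }
#         },
#     )
#     _extract_tokens_and_log_probs(msg)  # -> (["Par", "is"], [-0.01, -0.02])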
class QuestionGeneratorChain(LLMChain):
    """Chain that generates questions from uncertain spans."""

    prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
    """Prompt template for the chain."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        return False

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input", "context", "response"]
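# Usage note (illustrative values, not part of the library): in
# FlareChain._do_retrieval below, the question generator is invoked with one
# input dict per uncertain span, e.g.:
#
#     {
#         "user_input": "When was the Eiffel Tower built?",
#         "current_response": "The Eiffel Tower was built in",
#         "uncertain_span": "1889",
#     }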
def _low_confidence_spans(
    tokens: Sequence[str],
    log_probs: Sequence[float],
    min_prob: float,
    min_token_gap: int,
    num_pad_tokens: int,
) -> List[str]:
    # Find indices of tokens whose probability falls below the threshold,
    # using NumPy when available and a pure-Python fallback otherwise.
    try:
        import numpy as np

        _low_idx = np.where(np.exp(log_probs) < min_prob)[0]
    except ImportError:
        logger.warning(
            "NumPy not found in the current Python environment. FlareChain will use a "
            "pure Python implementation for internal calculations, which may "
            "significantly impact performance, especially for large datasets. For "
            "optimal speed and efficiency, consider installing NumPy: pip install numpy"
        )
        import math

        _low_idx = [  # type: ignore[assignment]
            idx
            for idx, log_prob in enumerate(log_probs)
            if math.exp(log_prob) < min_prob
        ]
    # Keep only low-confidence tokens that contain at least one word character
    # (i.e. skip punctuation and whitespace tokens).
    low_idx = [i for i in _low_idx if re.search(r"\w", tokens[i])]
    if len(low_idx) == 0:
        return []
    # Merge low-confidence tokens that are fewer than `min_token_gap` indices
    # apart into a single span, extending each span by `num_pad_tokens`.
    spans = [[low_idx[0], low_idx[0] + num_pad_tokens + 1]]
    for i, idx in enumerate(low_idx[1:]):
        end = idx + num_pad_tokens + 1
        if idx - low_idx[i] < min_token_gap:
            spans[-1][1] = end
        else:
            spans.append([idx, end])
    return ["".join(tokens[start:end]) for start, end in spans]
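# Worked example (made-up numbers): with tokens ["The", " capital", " is",
# " Paris"], log_probs [-0.1, -0.2, -0.9, -1.2], min_prob=0.5, min_token_gap=2,
# and num_pad_tokens=1: exp(-0.9) ~= 0.41 and exp(-1.2) ~= 0.30 fall below
# min_prob, so indices 2 and 3 are low confidence. They are only 1 index apart
# (< min_token_gap), so they merge into one padded span [2, 5), and the
# function returns [" is Paris"].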
class FlareChain(Chain):
    """Chain that combines a retriever, a question generator,
    and a response generator.

    See the [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983)
    paper.
    """

    question_generator_chain: Runnable
    """Chain that generates questions from uncertain spans."""
    response_chain: Runnable
    """Chain that generates responses from user input and context."""
    output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
    """Parser that determines whether the chain is finished."""
    retriever: BaseRetriever
    """Retriever that retrieves relevant documents from a user input."""
    min_prob: float = 0.2
    """Minimum probability for a token to be considered low confidence."""
    min_token_gap: int = 5
    """Minimum number of tokens between two low confidence spans."""
    num_pad_tokens: int = 2
    """Number of tokens to pad around a low confidence span."""
    max_iter: int = 10
    """Maximum number of iterations."""
    start_with_retrieval: bool = True
    """Whether to start with retrieval."""

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for the chain."""
        return ["response"]

    def _do_generation(
        self,
        questions: List[str],
        user_input: str,
        response: str,
        _run_manager: CallbackManagerForChainRun,
    ) -> Tuple[str, bool]:
        callbacks = _run_manager.get_child()
        # Retrieve documents for every generated question and pool them into
        # a single context string.
        docs = []
        for question in questions:
            docs.extend(self.retriever.invoke(question))
        context = "\n\n".join(d.page_content for d in docs)
        result = self.response_chain.invoke(
            {
                "user_input": user_input,
                "context": context,
                "response": response,
            },
            {"callbacks": callbacks},
        )
        if isinstance(result, AIMessage):
            result = result.content
        marginal, finished = self.output_parser.parse(result)
        return marginal, finished

    def _do_retrieval(
        self,
        low_confidence_spans: List[str],
        _run_manager: CallbackManagerForChainRun,
        user_input: str,
        response: str,
        initial_response: str,
    ) -> Tuple[str, bool]:
        question_gen_inputs = [
            {
                "user_input": user_input,
                "current_response": initial_response,
                "uncertain_span": span,
            }
            for span in low_confidence_spans
        ]
        callbacks = _run_manager.get_child()
        if isinstance(self.question_generator_chain, LLMChain):
            question_gen_outputs = self.question_generator_chain.apply(
                question_gen_inputs, callbacks=callbacks
            )
            questions = [
                output[self.question_generator_chain.output_keys[0]]
                for output in question_gen_outputs
            ]
        else:
            questions = self.question_generator_chain.batch(
                question_gen_inputs, config={"callbacks": callbacks}
            )
        _run_manager.on_text(
            f"Generated Questions: {questions}", color="yellow", end="\n"
        )
        return self._do_generation(questions, user_input, response, _run_manager)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        user_input = inputs[self.input_keys[0]]

        response = ""

        for i in range(self.max_iter):
            _run_manager.on_text(
                f"Current Response: {response}", color="blue", end="\n"
            )
            # Generate a tentative continuation (with empty context) and
            # inspect the log probabilities of its tokens.
            _input = {"user_input": user_input, "context": "", "response": response}
            tokens, log_probs = _extract_tokens_and_log_probs(
                self.response_chain.invoke(
                    _input, {"callbacks": _run_manager.get_child()}
                )
            )
            low_confidence_spans = _low_confidence_spans(
                tokens,
                log_probs,
                self.min_prob,
                self.min_token_gap,
                self.num_pad_tokens,
            )
            initial_response = response.strip() + " " + "".join(tokens)
            if not low_confidence_spans:
                # Every token was confident: accept the continuation and stop
                # if the output parser says the response is finished.
                response = initial_response
                final_response, finished = self.output_parser.parse(response)
                if finished:
                    return {self.output_keys[0]: final_response}
                continue
            # Otherwise, generate questions about the uncertain spans, retrieve
            # supporting documents, and regenerate the continuation with them.
            marginal, finished = self._do_retrieval(
                low_confidence_spans,
                _run_manager,
                user_input,
                response,
                initial_response,
            )
            response = response.strip() + " " + marginal
            if finished:
                break
        return {self.output_keys[0]: response}
    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
    ) -> FlareChain:
        """Creates a FlareChain from a language model.

        Args:
            llm: Language model to use.
            max_generation_len: Maximum length of the generated response.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            FlareChain class with the given language model.
        """
        try:
            from langchain_openai import ChatOpenAI
        except ImportError:
            raise ImportError(
                "OpenAI is required for FlareChain. "
                "Please install langchain-openai: "
                "pip install langchain-openai"
            )
        # Note: the passed `llm` is replaced here with a ChatOpenAI instance
        # configured to return log probabilities, which FlareChain requires.
        llm = ChatOpenAI(
            max_completion_tokens=max_generation_len, logprobs=True, temperature=0
        )
        response_chain = PROMPT | llm
        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
            **kwargs,
        )
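# End-to-end usage sketch (assumptions: langchain-openai is installed,
# OPENAI_API_KEY is set, and `my_retriever` is a hypothetical BaseRetriever):
#
#     from langchain_openai import ChatOpenAI
#
#     flare = FlareChain.from_llm(
#         ChatOpenAI(),
#         retriever=my_retriever,
#         max_generation_len=164,
#         min_prob=0.3,
#     )
#     result = flare.invoke({"user_input": "How does FLARE decide to retrieve?"})
#     print(result["response"])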