from __future__ import annotations
import re
from typing import Any, Dict, List, Optional, Sequence, Tuple
import numpy as np
from langchain_core.callbacks import (
CallbackManagerForChainRun,
)
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import BasePromptTemplate
from langchain_core.retrievers import BaseRetriever
from langchain_core.runnables import Runnable
from pydantic import Field
from langchain.chains.base import Chain
from langchain.chains.flare.prompts import (
PROMPT,
QUESTION_GENERATOR_PROMPT,
FinishedOutputParser,
)
from langchain.chains.llm import LLMChain
def _extract_tokens_and_log_probs(response: AIMessage) -> Tuple[List[str], List[float]]:
"""Extract tokens and log probabilities from chat model response."""
tokens = []
log_probs = []
for token in response.response_metadata["logprobs"]["content"]:
tokens.append(token["token"])
log_probs.append(token["logprob"])
return tokens, log_probs
class QuestionGeneratorChain(LLMChain):
    """Chain that generates questions from uncertain spans."""

    prompt: BasePromptTemplate = QUESTION_GENERATOR_PROMPT
    """Prompt template for the chain."""

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return False: this chain is not marked as LangChain-serializable."""
        return False

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain.

        Derived from the prompt's declared input variables. A fixed list of
        ``["user_input", "context", "response"]`` did not match the keys that
        ``FlareChain._do_retrieval`` supplies to the question generator
        (``user_input``, ``current_response``, ``uncertain_span``), so input
        validation could reject well-formed inputs.
        """
        return list(self.prompt.input_variables)
def _low_confidence_spans(
tokens: Sequence[str],
log_probs: Sequence[float],
min_prob: float,
min_token_gap: int,
num_pad_tokens: int,
) -> List[str]:
_low_idx = np.where(np.exp(log_probs) < min_prob)[0]
low_idx = [i for i in _low_idx if re.search(r"\w", tokens[i])]
if len(low_idx) == 0:
return []
spans = [[low_idx[0], low_idx[0] + num_pad_tokens + 1]]
for i, idx in enumerate(low_idx[1:]):
end = idx + num_pad_tokens + 1
if idx - low_idx[i] < min_token_gap:
spans[-1][1] = end
else:
spans.append([idx, end])
return ["".join(tokens[start:end]) for start, end in spans]
class FlareChain(Chain):
    """Chain that combines a retriever, a question generator,
    and a response generator.

    Each iteration drafts a continuation of the response. Tokens whose
    probability falls below ``min_prob`` are grouped into spans, and one
    question is generated per span. Documents retrieved for those questions
    are used as context to regenerate the continuation. The loop stops when
    the output parser reports completion or after ``max_iter`` iterations.

    See [Active Retrieval Augmented Generation](https://arxiv.org/abs/2305.06983) paper.
    """

    question_generator_chain: Runnable
    """Chain that generates questions from uncertain spans."""
    response_chain: Runnable
    """Chain that generates responses from user input and context."""
    output_parser: FinishedOutputParser = Field(default_factory=FinishedOutputParser)
    """Parser that determines whether the chain is finished."""
    retriever: BaseRetriever
    """Retriever that retrieves relevant documents from a user input."""
    min_prob: float = 0.2
    """Minimum probability for a token to be considered low confidence."""
    min_token_gap: int = 5
    """Low confidence tokens fewer than this many positions apart are merged
    into one span."""
    num_pad_tokens: int = 2
    """Number of tokens appended after the last low confidence token of a span."""
    max_iter: int = 10
    """Maximum number of iterations."""
    # NOTE(review): ``start_with_retrieval`` is not read anywhere in this
    # class; confirm whether it is used elsewhere or should be wired into _call.
    start_with_retrieval: bool = True
    """Whether to start with retrieval."""

    @property
    def input_keys(self) -> List[str]:
        """Input keys for the chain."""
        return ["user_input"]

    @property
    def output_keys(self) -> List[str]:
        """Output keys for the chain."""
        return ["response"]

    def _do_generation(
        self,
        questions: List[str],
        user_input: str,
        response: str,
        _run_manager: CallbackManagerForChainRun,
    ) -> Tuple[str, bool]:
        """Retrieve documents for ``questions`` and generate the next chunk.

        Args:
            questions: Queries sent to the retriever, one call each.
            user_input: The original user input.
            response: The response accumulated so far.
            _run_manager: Callback manager of the enclosing chain run.

        Returns:
            ``(marginal, finished)`` as returned by ``self.output_parser.parse``
            on the response chain's output.
        """
        callbacks = _run_manager.get_child()
        docs = []
        for question in questions:
            # Pass the child callbacks so retrieval runs are traced under this
            # chain run, like the question-generation and response sub-runs.
            docs.extend(self.retriever.invoke(question, {"callbacks": callbacks}))
        context = "\n\n".join(d.page_content for d in docs)
        result = self.response_chain.invoke(
            {
                "user_input": user_input,
                "context": context,
                "response": response,
            },
            {"callbacks": callbacks},
        )
        # Chat-model chains return a message; the parser works on its content.
        if isinstance(result, AIMessage):
            result = result.content
        marginal, finished = self.output_parser.parse(result)
        return marginal, finished

    def _do_retrieval(
        self,
        low_confidence_spans: List[str],
        _run_manager: CallbackManagerForChainRun,
        user_input: str,
        response: str,
        initial_response: str,
    ) -> Tuple[str, bool]:
        """Generate one question per uncertain span, then regenerate with context.

        Args:
            low_confidence_spans: Text spans the model was unsure about.
            _run_manager: Callback manager of the enclosing chain run.
            user_input: The original user input.
            response: The response accumulated before this iteration's draft.
            initial_response: ``response`` plus this iteration's tentative
                draft. It is shown to the question generator as the current
                response.

        Returns:
            ``(marginal, finished)`` from :meth:`_do_generation`.
        """
        question_gen_inputs = [
            {
                "user_input": user_input,
                "current_response": initial_response,
                "uncertain_span": span,
            }
            for span in low_confidence_spans
        ]
        callbacks = _run_manager.get_child()
        if isinstance(self.question_generator_chain, LLMChain):
            # Legacy LLMChain path: outputs are dicts keyed by the chain's
            # first output key.
            question_gen_outputs = self.question_generator_chain.apply(
                question_gen_inputs, callbacks=callbacks
            )
            questions = [
                output[self.question_generator_chain.output_keys[0]]
                for output in question_gen_outputs
            ]
        else:
            questions = self.question_generator_chain.batch(
                question_gen_inputs, config={"callbacks": callbacks}
            )
        _run_manager.on_text(
            f"Generated Questions: {questions}", color="yellow", end="\n"
        )
        return self._do_generation(questions, user_input, response, _run_manager)

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """Run the iterative draft / retrieve / regenerate loop.

        Args:
            inputs: Must contain ``"user_input"``.
            run_manager: Optional callback manager; a no-op one is used if absent.

        Returns:
            ``{"response": <final response text>}``.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()

        user_input = inputs[self.input_keys[0]]

        response = ""

        for _ in range(self.max_iter):
            _run_manager.on_text(
                f"Current Response: {response}", color="blue", end="\n"
            )
            # Draft a continuation without retrieved context to obtain
            # per-token log probabilities.
            _input = {"user_input": user_input, "context": "", "response": response}
            tokens, log_probs = _extract_tokens_and_log_probs(
                self.response_chain.invoke(
                    _input, {"callbacks": _run_manager.get_child()}
                )
            )
            low_confidence_spans = _low_confidence_spans(
                tokens,
                log_probs,
                self.min_prob,
                self.min_token_gap,
                self.num_pad_tokens,
            )
            initial_response = response.strip() + " " + "".join(tokens)
            if not low_confidence_spans:
                # Confident draft: keep it as-is.
                response = initial_response
                final_response, finished = self.output_parser.parse(response)
                if finished:
                    return {self.output_keys[0]: final_response}
                continue

            # Uncertain draft: discard it and regenerate with retrieved context.
            marginal, finished = self._do_retrieval(
                low_confidence_spans,
                _run_manager,
                user_input,
                response,
                initial_response,
            )
            response = response.strip() + " " + marginal
            if finished:
                break
        return {self.output_keys[0]: response}

    @classmethod
    def from_llm(
        cls, llm: BaseLanguageModel, max_generation_len: int = 32, **kwargs: Any
    ) -> FlareChain:
        """Creates a FlareChain from a language model.

        NOTE(review): the ``llm`` argument is currently not used. A
        ``ChatOpenAI`` model with ``logprobs=True`` and ``temperature=0`` is
        always constructed, because ``_call`` reads per-token log probabilities
        from the response metadata. The same model backs both the response
        chain and the question-generator chain.

        Args:
            llm: Language model to use. Currently ignored (see note above).
            max_generation_len: Maximum length of the generated response,
                passed as ``max_tokens`` to ``ChatOpenAI``.
            kwargs: Additional arguments to pass to the constructor.

        Returns:
            FlareChain class with the given language model.

        Raises:
            ImportError: If ``langchain-openai`` is not installed.
        """
        try:
            from langchain_openai import ChatOpenAI
        except ImportError as e:
            raise ImportError(
                "OpenAI is required for FlareChain. "
                "Please install langchain-openai. "
                "pip install langchain-openai"
            ) from e
        llm = ChatOpenAI(max_tokens=max_generation_len, logprobs=True, temperature=0)
        response_chain = PROMPT | llm
        question_gen_chain = QUESTION_GENERATOR_PROMPT | llm | StrOutputParser()
        return cls(
            question_generator_chain=question_gen_chain,
            response_chain=response_chain,
            **kwargs,
        )