Source code for langchain_cohere.rag_retrievers

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, List, Optional

from langchain_core.callbacks import (
    AsyncCallbackManagerForRetrieverRun,
    CallbackManagerForRetrieverRun,
)
from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import HumanMessage
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict, Field

if TYPE_CHECKING:
    from langchain_core.messages import BaseMessage


def _get_docs(response: Any) -> List[Document]:
    """Unpack a Cohere RAG generation into a list of Documents.

    Each retrieved document becomes its own ``Document``; the model's reply is
    appended last, with citations and search metadata attached.
    """
    docs = []
    if (
        "documents" in response.generation_info
        and response.generation_info["documents"]
    ):
        for doc in response.generation_info["documents"]:
            content = doc.get("snippet", None) or doc.get("text", None)
            if content is not None:
                docs.append(Document(page_content=content, metadata=doc))

    docs.append(
        Document(
            page_content=response.message.content,
            metadata={
                "type": "model_response",
                "citations": response.generation_info["citations"],
                "search_results": response.generation_info["search_results"],
                "search_queries": response.generation_info["search_queries"],
                "token_count": response.generation_info["token_count"],
            },
        )
    )
    return docs


class CohereRagRetriever(BaseRetriever):
    """Cohere Chat API with RAG."""

    connectors: List[Dict] = Field(
        default_factory=lambda: [{"id": "web-search"}],
        deprecated="The 'connectors' parameter is deprecated as of version 0.3.3. Please use the 'tools' parameter instead.",  # noqa: E501
    )
    """
    When specified, the model's reply will be enriched with information found by
    querying each of the connectors (RAG). These will be returned as langchain
    documents. Currently only accepts {"id": "web-search"}.
    """

    llm: BaseChatModel
    """Cohere ChatModel to use."""

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        documents: Optional[List[Dict[str, str]]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        res = self.llm.generate(
            messages,
            connectors=self.connectors if documents is None else None,
            documents=documents,
            callbacks=run_manager.get_child(),
            **kwargs,
        ).generations[0][0]
        return _get_docs(res)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        documents: Optional[List[Dict[str, str]]] = None,
        **kwargs: Any,
    ) -> List[Document]:
        messages: List[List[BaseMessage]] = [[HumanMessage(content=query)]]
        res = (
            await self.llm.agenerate(
                messages,
                connectors=self.connectors if documents is None else None,
                documents=documents,
                callbacks=run_manager.get_child(),
                **kwargs,
            )
        ).generations[0][0]
        return _get_docs(res)
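
A minimal usage sketch (not part of the module source): it assumes `ChatCohere` is importable from `langchain_cohere` and that a valid COHERE_API_KEY is set in the environment. As `_get_docs` shows above, the retrieved sources come first and the model's reply is appended as the final document.

from langchain_cohere import ChatCohere, CohereRagRetriever

rag = CohereRagRetriever(llm=ChatCohere())
docs = rag.invoke("What is retrieval-augmented generation?")

sources, answer = docs[:-1], docs[-1]
print(answer.page_content)            # the model's reply
print(answer.metadata["citations"])   # citation metadata returned by Cohere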