from __future__ import annotations

import json
import logging
from copy import deepcopy
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from langchain_core.documents import Document
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage
from langchain_core.prompts.chat import ChatPromptTemplate, SystemMessagePromptTemplate
from langchain_core.runnables import RunnableSequence
from pymongo import MongoClient, UpdateOne
from pymongo.collection import Collection
from pymongo.driver_info import DriverInfo
from pymongo.errors import OperationFailure
from pymongo.results import BulkWriteResult

from langchain_mongodb.graphrag import example_templates, prompts

from .prompts import rag_prompt
from .schema import entity_schema

# Everything under TYPE_CHECKING is seen by static type checkers only; it is
# never executed at runtime.
# NOTE(review): block nesting was reconstructed from whitespace-stripped
# source -- confirm that the `Entity` alias belongs inside this guard.
if TYPE_CHECKING:
    try:
        from typing import TypeAlias  # Python 3.10+
    except ImportError:
        from typing_extensions import TypeAlias  # Python 3.9 fallback

    # Alias used in annotations throughout this module for one graph node.
    Entity: TypeAlias = Dict[str, Any]
    """Represents an Entity in the knowledge graph with specific schema. See .schema"""

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class MongoDBGraphStore:
    """GraphRAG DataStore

    GraphRAG is a ChatModel that provides responses to semantic queries
    based on a Knowledge Graph that an LLM is used to create.
    As in Vector RAG, we augment the Chat Model's training data
    with relevant information that we collect from documents.

    In Vector RAG, one uses an "Embedding" model that converts both
    the query, and the potentially relevant documents, into vectors,
    which can then be compared, and the most similar supplied to the
    Chat Model as context to the query.

    In Graph RAG, one uses an "Entity-Extraction" model that converts
    text into Entities and their relationships, a Knowledge Graph.
    Comparison is done by Graph traversal, finding entities connected
    to the query prompts. These are then supplied to the Chat Model as context.
    The main difference is that GraphRAG's output is typically in a structured format.

    GraphRAG excels in finding links and common entities,
    even if these come from different articles. It can combine information from
    distinct sources providing richer context than Vector RAG in certain cases.

    Here are a few examples of so-called multi-hop questions where GraphRAG excels:

    - What is the connection between ACME Corporation and GreenTech Ltd.?
    - Who is leading the SolarGrid Initiative, and what is their role?
    - Which organizations are participating in the SolarGrid Initiative?
    - What is John Doe's role in ACME's renewable energy projects?
    - Which company is headquartered in San Francisco and involved in the SolarGrid Initiative?

    In Graph RAG, one uses an Entity-Extraction model that interprets the
    text documents it is given, extracting both the query and the potentially
    relevant documents into graphs. These are composed of nodes that are
    entities (nouns) and edges that are relationships.
    The idea is that the graph can find connections between entities and
    hence answer questions that require more than one connection.

    In MongoDB, Knowledge Graphs are stored in a single Collection.
    Each MongoDB Document represents a single entity (node),
    and its relationships (edges) are defined in a nested field named
    "relationships". The schema, and an example, are described in the
    :data:`~langchain_mongodb.graphrag.prompts.entity_context` prompts module.

    When a query is made, the model extracts the entities in it,
    then traverses the graph to find connections.
    The closest entities and their relationships form the context
    that is included with the query to the Chat Model.

    Consider this example Query: "Does John Doe work at MongoDB?"
    GraphRAG can answer this question even if the following two statements come
    from completely different sources.

    - "Jane Smith works with John Doe."
    - "Jane Smith works at MongoDB."
    """

    def __init__(
        self,
        *,
        connection_string: Optional[str] = None,
        database_name: Optional[str] = None,
        collection_name: Optional[str] = None,
        collection: Optional[Collection] = None,
        entity_extraction_model: BaseChatModel,
        entity_prompt: Optional[ChatPromptTemplate] = None,
        query_prompt: Optional[ChatPromptTemplate] = None,
        max_depth: int = 2,
        allowed_entity_types: Optional[List[str]] = None,
        allowed_relationship_types: Optional[List[str]] = None,
        entity_examples: Optional[str] = None,
        entity_name_examples: Optional[str] = None,
        validate: bool = False,
        validation_action: str = "warn",
    ):
        """
        Args:
            connection_string: A valid MongoDB connection URI.
            database_name: The name of the database to connect to.
            collection_name: The name of the collection to connect to.
            collection: A Collection that will represent a Knowledge Graph.
                ** One may pass a Collection in lieu of connection_string,
                database_name, and collection_name.
            entity_extraction_model: LLM for converting documents into Graph of
                Entities and Relationships.
            entity_prompt: Prompt to fill graph store with entities following schema.
                Defaults to .prompts.ENTITY_EXTRACTION_INSTRUCTIONS
            query_prompt: Prompt extracts entities and relationships as search
                starting points.
                Defaults to .prompts.NAME_EXTRACTION_INSTRUCTIONS
            max_depth: Maximum recursion depth in graph traversal.
            allowed_entity_types: If provided, constrains search to these types.
            allowed_relationship_types: If provided, constrains search to these types.
            entity_examples: A string containing any number of additional examples
                to provide as context for entity extraction.
            entity_name_examples: A string appended to
                prompts.NAME_EXTRACTION_INSTRUCTIONS containing examples.
            validate: If True, entity schema will be validated on every insert
                or update.
            validation_action: One of {"warn", "error"}.
                - If "warn", the default, documents will be inserted but errors logged.
                - If "error", an exception will be raised if any document does
                  not match the schema.

        Raises:
            ValueError: If both a connection_string and a collection are given,
                if neither a collection nor database_name/collection_name is
                given, or if collection is not a pymongo Collection.
        """
        self._schema = deepcopy(entity_schema)
        collection_existed = True

        if connection_string and collection is not None:
            raise ValueError(
                "Pass one of: connection_string, database_name, and collection_name "
                "OR a MongoDB Collection."
            )

        if collection is None:
            # The collection is specified by URI and names.
            if database_name is None or collection_name is None:
                raise ValueError(
                    "database_name and collection_name are required "
                    "when a MongoDB Collection is not provided."
                )
            client: MongoClient = MongoClient(
                connection_string,
                driver=DriverInfo(
                    name="Langchain", version=version("langchain-mongodb")
                ),
            )
            db = client[database_name]
            if collection_name not in db.list_collection_names():
                # Only send a validator when one is requested, rather than
                # passing an explicit null to the create command.
                create_options: Dict[str, Any] = {
                    "validationAction": validation_action
                }
                if validate:
                    create_options["validator"] = {"$jsonSchema": self._schema}
                collection = db.create_collection(collection_name, **create_options)
                collection_existed = False
            else:
                collection = db[collection_name]
        elif not isinstance(collection, Collection):
            raise ValueError(
                "collection must be a MongoDB Collection. "
                "Consider using connection_string, database_name, and collection_name."
            )

        if validate and collection_existed:
            # First check for an existing validator.
            collection_info = collection.database.command(
                "listCollections", filter={"name": collection.name}
            )
            collection_options = collection_info.get("cursor", {}).get(
                "firstBatch", []
            )
            # A Collection handle may refer to a collection that has not been
            # created on the server yet, in which case the batch is empty.
            existing_validator = (
                collection_options[0].get("options", {}).get("validator", None)
                if collection_options
                else None
            )
            if not existing_validator:
                try:
                    collection.database.command(
                        "collMod",
                        collection.name,
                        validator={"$jsonSchema": self._schema},
                        validationAction=validation_action,
                    )
                except OperationFailure:
                    logger.warning(
                        "Validation will NOT be performed. "
                        "User must be DB Admin to add validation **after** a Collection is created. \n"
                        "Please add validator when you create collection: "
                        "db.create_collection(coll_name, validator={'$jsonSchema': schema.entity_schema})"
                    )

        self.collection = collection
        self.entity_extraction_model = entity_extraction_model
        # Work on private copies: example messages are inserted below, and
        # mutating the module-level defaults (or the caller's templates)
        # would leak into every other store built from them.
        self.entity_prompt = deepcopy(
            prompts.entity_prompt if entity_prompt is None else entity_prompt
        )
        self.query_prompt = deepcopy(
            prompts.query_prompt if query_prompt is None else query_prompt
        )
        self.max_depth = max_depth

        # NOTE(review): the type constraints below are applied after any
        # validator has been sent to the server, so they shape the schema
        # exposed via `entity_schema` (and given to prompts), not the
        # collection validator itself. Confirm this is intended.
        if allowed_entity_types:
            self.allowed_entity_types = allowed_entity_types
            self._schema["properties"]["type"]["enum"] = allowed_entity_types
        else:
            self.allowed_entity_types = []

        if allowed_relationship_types:
            self.allowed_relationship_types = allowed_relationship_types
            # Update schema. Disallow other keys.
            self._schema["properties"]["relationships"]["properties"]["types"][
                "enum"
            ] = allowed_relationship_types
        else:
            self.allowed_relationship_types = []

        # Include examples
        if entity_examples is None:
            entity_examples = example_templates.entity_extraction
        self.entity_prompt.messages.insert(
            1,
            SystemMessagePromptTemplate.from_template(entity_examples),
        )
        if entity_name_examples:
            self.query_prompt.messages.insert(
                1,
                SystemMessagePromptTemplate.from_template(
                    f"## Additional Examples \n{entity_name_examples}"
                ),
            )

    @property
    def entity_schema(self):
        """JSON Schema Object of Entities. Will be applied if validate is True.

        See Also:
            `$jsonSchema <https://www.mongodb.com/docs/manual/reference/operator/query/jsonSchema/>`_
        """
        return self._schema

    @classmethod
    def from_connection_string(
        cls,
        connection_string: str,
        database_name: str,
        collection_name: str,
        entity_extraction_model: BaseChatModel,
        entity_prompt: Optional[ChatPromptTemplate] = None,
        query_prompt: Optional[ChatPromptTemplate] = None,
        max_depth: int = 2,
        allowed_entity_types: Optional[List[str]] = None,
        allowed_relationship_types: Optional[List[str]] = None,
        entity_examples: Optional[str] = None,
        entity_name_examples: Optional[str] = None,
        validate: bool = False,
        validation_action: str = "warn",
    ) -> MongoDBGraphStore:
        """Construct a `MongoDB KnowLedge Graph for RAG`
        from a MongoDB connection URI.

        The collection is created if it does not exist yet and reused
        otherwise; all of that handling is delegated to ``__init__``.

        Args:
            connection_string: A valid MongoDB connection URI.
            database_name: The name of the database to connect to.
            collection_name: The name of the collection to connect to.
            entity_extraction_model: LLM for converting documents into Graph of
                Entities and Relationships.
            entity_prompt: Prompt to fill graph store with entities following schema.
                Defaults to prompts.entity_prompt.
            query_prompt: Prompt extracts entities and relationships as search
                starting points. Defaults to prompts.query_prompt.
            max_depth: Maximum recursion depth in graph traversal.
            allowed_entity_types: If provided, constrains search to these types.
            allowed_relationship_types: If provided, constrains search to these types.
            entity_examples: A string containing any number of additional examples
                to provide as context for entity extraction.
            entity_name_examples: A string appended to
                prompts.NAME_EXTRACTION_INSTRUCTIONS containing examples.
            validate: If True, entity schema will be validated on every insert
                or update.
            validation_action: One of {"warn", "error"}.
                - If "warn", the default, documents will be inserted but errors logged.
                - If "error", an exception will be raised if any document does
                  not match the schema.

        Returns:
            A new MongoDBGraphStore instance.
        """
        # __init__ is keyword-only, so every argument is forwarded by name.
        return cls(
            connection_string=connection_string,
            database_name=database_name,
            collection_name=collection_name,
            entity_extraction_model=entity_extraction_model,
            entity_prompt=entity_prompt,
            query_prompt=query_prompt,
            max_depth=max_depth,
            allowed_entity_types=allowed_entity_types,
            allowed_relationship_types=allowed_relationship_types,
            entity_examples=entity_examples,
            entity_name_examples=entity_name_examples,
            validate=validate,
            validation_action=validation_action,
        )

    def close(self) -> None:
        """Close the resources used by the MongoDBGraphStore."""
        self.collection.database.client.close()

    @staticmethod
    def _strip_code_fence(text: str) -> str:
        """Return ``text`` without a surrounding Markdown code fence.

        Whitespace is trimmed first so that a reply starting with a newline
        or spaces still has its opening fence recognised. Both "```json"
        and a bare "```" opening fence are handled.
        """
        cleaned = text.strip()
        for opening in ("```json", "```"):
            if cleaned.startswith(opening):
                cleaned = cleaned[len(opening):]
                break
        return cleaned.removesuffix("```").strip()

    def _write_entities(self, entities: List[Entity]) -> Optional[BulkWriteResult]:
        """Isolate logic to insert and aggregate entities.

        Args:
            entities: Entity dictionaries, each with at least "_id" and "type".

        Returns:
            The bulk write result, or None when there was nothing to write.
        """
        operations = []
        for entity in entities:
            # `or` guards against an explicit null emitted by the LLM.
            relationships = entity.get("relationships") or {}
            target_ids = relationships.get("target_ids") or []
            types = relationships.get("types") or []
            attributes = relationships.get("attributes") or []

            # Ensure the lengths of target_ids, types, and attributes align
            if not (len(target_ids) == len(types) == len(attributes)):
                logger.warning(
                    "Targets, types, and attributes do not have the same length for %s!",
                    entity["_id"],
                )

            update: Dict[str, Any] = {
                "$setOnInsert": {  # Set if upsert
                    "_id": entity["_id"],
                    "type": entity["type"],
                },
                "$push": {  # Push new entries into arrays
                    "relationships.target_ids": {"$each": target_ids},
                    "relationships.types": {"$each": types},
                    "relationships.attributes": {"$each": attributes},
                },
            }
            attribute_updates = {
                f"attributes.{key}": {"$each": values}
                for key, values in (entity.get("attributes") or {}).items()
            }
            # Servers older than MongoDB 5.0 reject an empty operator
            # document, so only include $addToSet when it has content.
            if attribute_updates:
                update["$addToSet"] = attribute_updates  # Update without overwriting

            operations.append(
                UpdateOne(
                    filter={"_id": entity["_id"]},  # Match on _id
                    update=update,
                    upsert=True,
                )
            )

        # Execute bulk write for the entities
        if operations:
            return self.collection.bulk_write(operations)
        return None

    def add_documents(
        self, documents: Union[Document, List[Document]]
    ) -> List[Optional[BulkWriteResult]]:
        """Extract entities and upsert into the collection.

        Each entity is represented by a single MongoDB Document.
        Existing entities identified in documents will be updated.

        Args:
            documents: list of textual documents and associated metadata.

        Returns:
            List containing metadata on entities inserted and updated, one value
            for each input document. The value is None for a document in which
            no entities were found.
        """
        documents = [documents] if isinstance(documents, Document) else documents
        results = []
        for doc in documents:
            # Call LLM to find all Entities in doc
            entities = self.extract_entities(doc.page_content)
            logger.debug("Entities found: %s", [e["_id"] for e in entities])
            # Insert new or combine with existing entities
            results.append(self._write_entities(entities))
        return results

    def extract_entities(self, raw_document: str, **kwargs: Any) -> List[Entity]:
        """Extract entities and their relations using chosen prompt and LLM.

        Args:
            raw_document: A single text document as a string. Typically prose.

        Returns:
            List of Entity dictionaries.
        """
        # Combine the LLM with the prompt template to form a chain
        chain: RunnableSequence = self.entity_prompt | self.entity_extraction_model
        # Invoke on a document to extract entities and relationships
        response: AIMessage = chain.invoke(
            dict(
                input_document=raw_document,
                entity_schema=self.entity_schema,
                allowed_entity_types=self.allowed_entity_types,
                allowed_relationship_types=self.allowed_relationship_types,
            )
        )
        # Post-process output string into list of entity json documents
        extracted = json.loads(self._strip_code_fence(response.content))
        return extracted["entities"]

    def extract_entity_names(self, raw_document: str, **kwargs: Any) -> List[str]:
        """Extract entity names from a document for similarity_search.

        The second entity extraction has a different form and purpose than
        the first as we are looking for starting points of our search and
        paths to follow. We aim to find source nodes, but no target nodes or edges.

        Args:
            raw_document: A single text document as a string. Typically prose.

        Returns:
            List of entity names / _ids.
        """
        # Combine the llm with the prompt template to form a chain
        chain: RunnableSequence = self.query_prompt | self.entity_extraction_model
        # Invoke on a document to extract entity names
        response: AIMessage = chain.invoke(
            dict(
                input_document=raw_document,
                allowed_entity_types=self.allowed_entity_types,
            )
        )
        # Post-process output string into list of entity names
        return json.loads(self._strip_code_fence(response.content))

    def find_entity_by_name(self, name: str) -> Optional[Entity]:
        """Utility to get Entity dict from Knowledge Graph / Collection.

        Args:
            name: _id string to look for.

        Returns:
            The Entity dict whose _id equals name, or None if there is no match.
        """
        return self.collection.find_one({"_id": name})

    def related_entities(
        self,
        starting_entities: List[str],
        max_depth: Optional[int] = None,
    ) -> List[Entity]:
        """Traverse Graph along relationship edges to find connected entities.

        Args:
            starting_entities: Traversal begins with documents whose _id fields
                match these strings.
            max_depth: Recursion continues until no more matching documents
                are found, or until the operation reaches a recursion depth
                specified by this parameter. Defaults to ``self.max_depth``
                when None; an explicit 0 is honoured.

        Returns:
            List of connected entities.
        """
        # Compare against None so that an explicit depth of 0 is not
        # mistaken for "not provided".
        depth_limit = self.max_depth if max_depth is None else max_depth
        pipeline = [
            # Match starting entities
            {"$match": {"_id": {"$in": starting_entities}}},
            {
                "$graphLookup": {
                    "from": self.collection.name,
                    # Start traversal with relationships.target_ids
                    "startWith": "$relationships.target_ids",
                    # Traverse via relationships.target_ids
                    "connectFromField": "relationships.target_ids",
                    "connectToField": "_id",  # Match to entity _id field
                    "as": "connections",  # Store connections
                    "maxDepth": depth_limit,  # Limit traversal depth
                    "depthField": "depth",  # Track depth
                }
            },
            # Exclude connections from the original document
            {
                "$project": {
                    "_id": 0,
                    "original": {
                        "_id": "$_id",
                        "type": "$type",
                        "attributes": "$attributes",
                        "relationships": "$relationships",
                    },
                    "connections": 1,  # Retain connections for deduplication
                }
            },
            # Combine original and connections into one array
            {
                "$project": {
                    "combined": {
                        "$concatArrays": [
                            ["$original"],  # Include original as an array
                            "$connections",  # Include connections
                        ]
                    }
                }
            },
            # Unwind the combined array into individual documents
            {"$unwind": "$combined"},
            # Remove duplicates by grouping on `_id` and keeping the first document
            {
                "$group": {
                    "_id": "$combined._id",  # Group by entity _id
                    "entity": {"$first": "$combined"},  # Keep the first occurrence
                }
            },
            # Format the final output
            {
                "$replaceRoot": {
                    "newRoot": "$entity"  # Use the deduplicated document as the root
                }
            },
        ]
        return list(self.collection.aggregate(pipeline))

    def similarity_search(self, input_document: str) -> List[Entity]:
        """Retrieve list of connected Entities found via traversal of KnowledgeGraph.

        1. Use LLM & Prompt to find entities within the input_document itself.
        2. Find Entity Nodes that match those found in the input_document.
        3. Traverse the graph using these as starting points.

        Args:
            input_document: String to find relevant documents for.

        Returns:
            List of connected Entity dictionaries.
        """
        starting_ids: List[str] = self.extract_entity_names(input_document)
        return self.related_entities(starting_ids)

    def chat_response(
        self,
        query: str,
        chat_model: Optional[BaseChatModel] = None,
        prompt: Optional[ChatPromptTemplate] = None,
    ) -> AIMessage:
        """Responds to a query given information found in Knowledge Graph.

        Args:
            query: Prompt before it is augmented by Knowledge Graph.
            chat_model: ChatBot. Defaults to entity_extraction_model.
            prompt: Alternative Prompt Template. Defaults to prompts.rag_prompt.

        Returns:
            Response Message. response.content contains text.
        """
        if chat_model is None:
            chat_model = self.entity_extraction_model
        if prompt is None:
            prompt = rag_prompt

        # Perform Retrieval on knowledge graph
        related_entities = self.similarity_search(query)
        # Combine the LLM with the prompt template to form a chain
        chain: RunnableSequence = prompt | chat_model
        # Invoke with query. The store's own schema is used so that the
        # model sees the same (possibly type-constrained) schema that
        # entity extraction was given.
        return chain.invoke(
            dict(
                query=query,
                related_entities=related_entities,
                entity_schema=self.entity_schema,
            )
        )