# Source code for langchain_neo4j.chains.graph_qa.cypher_utils

import re
from collections import namedtuple
from typing import Any, Dict, List, Optional, Tuple

# One allowed relationship pattern: (left node label, relationship type,
# right node label). CypherQueryCorrector.verify_schema reads these fields
# positionally as schema[0], schema[1] and schema[2].
Schema = namedtuple("Schema", ("left_node", "relation", "right_node"))


class CypherQueryCorrector:
    """
    Used to correct relationship direction in generated Cypher statements.
    This code is copied from the winner's submission to the Cypher competition:
    https://github.com/sakusaku-rich/cypher-direction-competition

    The query is checked one ``(node)-[rel]-(node)`` segment at a time against
    a list of ``(left_label, relation_type, right_label)`` schema triples.
    A segment whose direction contradicts the schema is flipped. If a segment
    matches the schema in neither direction, the whole query is rejected and
    ``""`` is returned.
    """

    property_pattern = re.compile(r"\{.+?\}")
    node_pattern = re.compile(r"\(.+?\)")
    path_pattern = re.compile(
        r"(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))(<?-)(\[.*?\])?(->?)(\([^\,\(\)]*?(\{.+\})?[^\,\(\)]*?\))"
    )
    node_relation_node_pattern = re.compile(
        r"(\()+(?P<left_node>[^()]*?)\)(?P<relation>.*?)\((?P<right_node>[^()]*?)(\))+"
    )
    relation_type_pattern = re.compile(r":(?P<relation_type>.+?)?(\{.+\})?]")

    def __init__(self, schemas: "List[Schema]"):
        """
        Args:
            schemas: list of schemas. Each schema is indexable as
                ``(left_node_label, relation_type, right_node_label)``.
        """
        self.schemas = schemas

    def clean_node(self, node: str) -> str:
        """
        Strip property maps, parentheses and surrounding whitespace from a node.

        Args:
            node: node in string format, e.g. ``(p:Person {name: 'x'})``

        Returns:
            The bare ``variable:Label`` text, e.g. ``p:Person``.
        """
        node = re.sub(self.property_pattern, "", node)
        node = node.replace("(", "")
        node = node.replace(")", "")
        node = node.strip()
        return node

    def detect_node_variables(self, query: str) -> Dict[str, List[str]]:
        """
        Map each named node variable in the query to its declared labels.

        Labels are accumulated across every occurrence of the variable.
        Anonymous nodes (``(:Label)``) are not registered. Registering them
        would merge the labels of unrelated anonymous nodes under one shared
        ``""`` key. Their labels are read directly by ``detect_labels``.

        Args:
            query: cypher query

        Returns:
            Dict from variable name to the list of labels seen for it.
        """
        nodes = [self.clean_node(node) for node in re.findall(self.node_pattern, query)]
        res: Dict[str, List[str]] = {}
        for node in nodes:
            variable, *labels = node.split(":")
            if not variable:
                continue
            res.setdefault(variable, []).extend(labels)
        return res

    def extract_paths(self, query: str) -> List[str]:
        """
        Extract every ``(node)-[rel]-(node)`` path segment from a query.

        Consecutive segments of a chained pattern share their middle node, so
        ``(a)-[:R]->(b)<-[:S]-(c)`` yields ``(a)-[:R]->(b)`` and
        ``(b)<-[:S]-(c)``.

        Args:
            query: cypher query

        Returns:
            The matched path strings in order of appearance. A path that
            occurs several times is returned once per occurrence.
        """
        paths = []
        idx = 0
        while (match := self.path_pattern.search(query, idx)) is not None:
            paths.append(match.group(0))
            # Resume at the right-hand node (group 6) so that it can serve as
            # the left-hand node of the next segment. Searching from an
            # absolute offset, instead of re-locating the path text with
            # str.find from the start of the query, guarantees progress when
            # the same path text appears more than once.
            idx = match.start(6)
        return paths

    def judge_direction(self, relation: str) -> str:
        """
        Classify a relationship string by its arrow heads.

        Args:
            relation: relation in string format, e.g. ``-[:R]->``

        Returns:
            ``"OUTGOING"`` if it ends with ``>`` (this check wins when both
            heads are present), ``"INCOMING"`` if it starts with ``<``,
            otherwise ``"BIDIRECTIONAL"``. An empty string yields
            ``"BIDIRECTIONAL"`` instead of raising.
        """
        direction = "BIDIRECTIONAL"
        if relation.startswith("<"):
            direction = "INCOMING"
        if relation.endswith(">"):
            direction = "OUTGOING"
        return direction

    def extract_node_variable(self, part: str) -> Optional[str]:
        """
        Args:
            part: node in string format

        Returns:
            The text before the first ``:`` (after stripping enclosing
            parentheses), or ``None`` if that text is empty.
        """
        part = part.lstrip("(").rstrip(")")
        idx = part.find(":")
        if idx != -1:
            part = part[:idx]
        return None if part == "" else part

    def detect_labels(
        self, str_node: str, node_variable_dict: Dict[str, Any]
    ) -> List[str]:
        """
        Resolve the labels of a node as it appears inside a path segment.

        Args:
            str_node: node in string format (without the parentheses)
            node_variable_dict: dictionary of node variables

        Returns:
            The labels registered for the node's variable. For an anonymous
            node, the labels written on the node itself. Otherwise ``[]``,
            which verify_schema treats as "any label".
        """
        # Remove inline property maps first. For example, ``:Person {name: 'x'}``
        # must yield ["Person"], not fragments split on the colon inside the map.
        splitted_node = self.clean_node(str_node).split(":")
        variable = splitted_node[0]
        if variable in node_variable_dict:
            return node_variable_dict[variable]
        if variable == "" and len(splitted_node) > 1:
            return splitted_node[1:]
        return []

    def verify_schema(
        self,
        from_node_labels: List[str],
        relation_types: List[str],
        to_node_labels: List[str],
    ) -> bool:
        """
        Check whether at least one schema triple matches the given constraints.

        An empty list for any argument means "no constraint". Backticks around
        labels and types are ignored.

        Args:
            from_node_labels: labels of the from node
            relation_types: types of the relation
            to_node_labels: labels of the to node

        Returns:
            True if some schema satisfies every non-empty constraint.
        """
        valid_schemas = self.schemas
        if from_node_labels:
            from_node_labels = [label.strip("`") for label in from_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[0] in from_node_labels
            ]
        if to_node_labels:
            to_node_labels = [label.strip("`") for label in to_node_labels]
            valid_schemas = [
                schema for schema in valid_schemas if schema[2] in to_node_labels
            ]
        if relation_types:
            relation_types = [
                relation_type.strip("`") for relation_type in relation_types
            ]
            valid_schemas = [
                schema for schema in valid_schemas if schema[1] in relation_types
            ]
        return valid_schemas != []

    def detect_relation_types(self, str_relation: str) -> Tuple[str, List[str]]:
        """
        Args:
            str_relation: relation in string format, e.g. ``-[r:A|B]->``

        Returns:
            ``(direction, types)``, where ``direction`` comes from
            judge_direction and ``types`` are the ``|``-separated relationship
            types with whitespace and ``!`` stripped. ``types`` is ``[]`` when
            no type is written.
        """
        relation_direction = self.judge_direction(str_relation)
        relation_type = self.relation_type_pattern.search(str_relation)
        if relation_type is None or relation_type.group("relation_type") is None:
            return relation_direction, []
        relation_types = [
            t.strip().strip("!")
            for t in relation_type.group("relation_type").split("|")
        ]
        return relation_direction, relation_types

    def correct_query(self, query: str) -> str:
        """
        Correct the relationship directions in a Cypher query.

        Args:
            query: cypher query

        Returns:
            The query with every wrongly-directed relationship flipped, or
            ``""`` if a segment matches the schema in neither direction.
            Segments whose relationship type contains ``*`` (variable length)
            are skipped.
        """
        node_variable_dict = self.detect_node_variables(query)
        paths = self.extract_paths(query)
        for path in paths:
            start_idx = 0
            while start_idx < len(path):
                # re.match anchors at position 0 of the slice, i.e. at start_idx.
                match_res = re.match(self.node_relation_node_pattern, path[start_idx:])
                if match_res is None:
                    break
                match_dict = match_res.groupdict()
                left_node = match_dict["left_node"]
                relation = match_dict["relation"]
                right_node = match_dict["right_node"]
                left_node_labels = self.detect_labels(left_node, node_variable_dict)
                right_node_labels = self.detect_labels(right_node, node_variable_dict)
                # 4 == the two pairs of parentheses around left/right nodes.
                end_idx = start_idx + 4 + len(left_node) + len(relation) + len(right_node)
                # NOTE: the slice ends at end_idx + 1, so it extends one character
                # past the segment when the path continues. This is preserved as-is.
                original_partial_path = path[start_idx : end_idx + 1]
                # Offset of the right node, which starts the next segment.
                next_idx = start_idx + len(left_node) + len(relation) + 2

                relation_direction, relation_types = self.detect_relation_types(
                    relation
                )
                if any("*" in t for t in relation_types):
                    start_idx = next_idx
                    continue

                if relation_direction == "OUTGOING":
                    if not self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    ):
                        if not self.verify_schema(
                            right_node_labels, relation_types, left_node_labels
                        ):
                            return ""
                        # "-[...]->" becomes "<-[...]-"
                        corrected_relation = "<" + relation[:-1]
                        corrected_partial_path = original_partial_path.replace(
                            relation, corrected_relation
                        )
                        # str.replace rewrites every occurrence in the query.
                        query = query.replace(
                            original_partial_path, corrected_partial_path
                        )
                elif relation_direction == "INCOMING":
                    if not self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    ):
                        if not self.verify_schema(
                            left_node_labels, relation_types, right_node_labels
                        ):
                            return ""
                        # "<-[...]-" becomes "-[...]->"
                        corrected_relation = relation[1:] + ">"
                        corrected_partial_path = original_partial_path.replace(
                            relation, corrected_relation
                        )
                        query = query.replace(
                            original_partial_path, corrected_partial_path
                        )
                elif not (
                    self.verify_schema(
                        left_node_labels, relation_types, right_node_labels
                    )
                    or self.verify_schema(
                        right_node_labels, relation_types, left_node_labels
                    )
                ):
                    # Undirected: legal if the schema allows either direction.
                    return ""
                start_idx = next_idx
        return query

    def __call__(self, query: str) -> str:
        """Correct the relationship directions in a Cypher query.

        Args:
            query: cypher query

        Returns:
            The corrected query, or ``""`` if it cannot be made consistent
            with the schemas.
        """
        return self.correct_query(query)