# Source code for langchain_core.runnables.graph_mermaid

"""Mermaid graph drawing utilities."""

from __future__ import annotations

import asyncio
import base64
import random
import re
import string
import time
from dataclasses import asdict
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal, Optional

import yaml

from langchain_core.runnables.graph import (
    CurveStyle,
    MermaidDrawMethod,
    NodeStyles,
)

if TYPE_CHECKING:
    from langchain_core.runnables.graph import Edge, Node


# Optional dependency for the Mermaid.INK API renderer
# (``_render_mermaid_using_api``), which raises ImportError when it is missing.
try:
    import requests

    _HAS_REQUESTS = True
except ImportError:
    _HAS_REQUESTS = False

# Optional dependency for the local Pyppeteer renderer
# (``_render_mermaid_using_pyppeteer``), which raises ImportError when it is missing.
try:
    from pyppeteer import launch  # type: ignore[import-not-found]

    _HAS_PYPPETEER = True
except ImportError:
    _HAS_PYPPETEER = False

# If a node name both starts and ends with one of these characters, ``draw_mermaid``
# wraps the label in a ``<p>`` tag (presumably so Mermaid does not treat names such
# as ``__start__`` as markdown formatting).
MARKDOWN_SPECIAL_CHARS = "*_`"


def draw_mermaid(
    nodes: dict[str, Node],
    edges: list[Edge],
    *,
    first_node: Optional[str] = None,
    last_node: Optional[str] = None,
    with_styles: bool = True,
    curve_style: CurveStyle = CurveStyle.LINEAR,
    node_styles: Optional[NodeStyles] = None,
    wrap_label_n_words: int = 9,
    frontmatter_config: Optional[dict[str, Any]] = None,
) -> str:
    """Draws a Mermaid graph using the provided graph data.

    Args:
        nodes (dict[str, Node]): Mapping of node id to node. Ids containing ``":"``
            are placed in the subgraph named by everything before the last ``":"``.
        edges (list[Edge]): List of edges, each with a source id, a target id,
            optional label data and a ``conditional`` flag.
        first_node (str, optional): Id of the first node. Defaults to None.
        last_node (str, optional): Id of the last node. Defaults to None.
        with_styles (bool, optional): Whether to include styles (frontmatter,
            node declarations and class definitions) in the graph.
            Defaults to True.
        curve_style (CurveStyle, optional): Curve style for the edges, written to
            ``config.flowchart.curve`` in the frontmatter.
            Defaults to CurveStyle.LINEAR.
        node_styles (NodeStyles, optional): Node colors for different types.
            Defaults to NodeStyles().
        wrap_label_n_words (int, optional): Words to wrap the edge labels.
            Values <= 0 disable wrapping. Defaults to 9.
        frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
            Can be used to customize theme and styles. Will be converted to YAML and
            added to the beginning of the mermaid graph (only when ``with_styles``
            is True). Defaults to None.

            See more here: https://mermaid.js.org/config/configuration.html.

            Example config:

            .. code-block:: python

                {
                    "config": {
                        "theme": "neutral",
                        "look": "handDrawn",
                        "themeVariables": { "primaryColor": "#e2e2e2"},
                    }
                }

    Returns:
        str: Mermaid graph syntax.

    Raises:
        ValueError: If two different subgraph groups resolve to the same
            subgraph name.
    """
    # Initialize Mermaid graph configuration. The curve style is merged into a
    # copy of the caller's frontmatter under config.flowchart.curve.
    original_frontmatter_config = frontmatter_config or {}
    original_flowchart_config = original_frontmatter_config.get("config", {}).get(
        "flowchart", {}
    )
    frontmatter_config = {
        **original_frontmatter_config,
        "config": {
            **original_frontmatter_config.get("config", {}),
            "flowchart": {**original_flowchart_config, "curve": curve_style.value},
        },
    }
    mermaid_graph = (
        (
            "---\n"
            + yaml.dump(frontmatter_config, default_flow_style=False)
            + "---\ngraph TD;\n"
        )
        if with_styles
        else "graph TD;\n"
    )

    # Group nodes by subgraph
    subgraph_nodes: dict[str, dict[str, Node]] = {}
    regular_nodes: dict[str, Node] = {}
    for key, node in nodes.items():
        if ":" in key:
            # For nodes with colons, add them only to their deepest subgraph level
            prefix = ":".join(key.split(":")[:-1])
            subgraph_nodes.setdefault(prefix, {})[key] = node
        else:
            regular_nodes[key] = node

    # Node formatting templates
    default_class_label = "default"
    format_dict = {default_class_label: "{0}({1})"}
    if first_node is not None:
        format_dict[first_node] = "{0}([{1}]):::first"
    if last_node is not None:
        format_dict[last_node] = "{0}([{1}]):::last"

    def render_node(key: str, node: Node, indent: str = "\t") -> str:
        """Helper function to render a node with consistent formatting."""
        node_name = node.name.split(":")[-1]
        label = (
            f"<p>{node_name}</p>"
            if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))
            and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))
            else node_name
        )
        if node.metadata:
            label = (
                f"{label}<hr/><small><em>"
                + "\n".join(f"{k} = {value}" for k, value in node.metadata.items())
                + "</em></small>"
            )
        node_label = format_dict.get(key, format_dict[default_class_label]).format(
            _to_safe_id(key), label
        )
        return f"{indent}{node_label}\n"

    # Add non-subgraph nodes to the graph
    if with_styles:
        for key, node in regular_nodes.items():
            mermaid_graph += render_node(key, node)

    # Group edges by the longest common prefix of their source and target ids
    edge_groups: dict[str, list[Edge]] = {}
    for edge in edges:
        src_parts = edge.source.split(":")
        tgt_parts = edge.target.split(":")
        # Stop at the first differing segment: segments that happen to match after
        # the ids have diverged (e.g. "a:x" -> "b:x") are not a shared prefix.
        common_parts: list[str] = []
        for src, tgt in zip(src_parts, tgt_parts):
            if src != tgt:
                break
            common_parts.append(src)
        common_prefix = ":".join(common_parts)
        edge_groups.setdefault(common_prefix, []).append(edge)

    seen_subgraphs = set()

    def add_subgraph(edges: list[Edge], prefix: str) -> None:
        nonlocal mermaid_graph
        # A lone self-loop edge gets its own group; render it without a wrapper.
        self_loop = len(edges) == 1 and edges[0].source == edges[0].target
        if prefix and not self_loop:
            subgraph = prefix.rsplit(":", maxsplit=1)[-1]
            if subgraph in seen_subgraphs:
                msg = (
                    f"Found duplicate subgraph '{subgraph}' -- this likely means that "
                    "you're reusing a subgraph node with the same name. "
                    "Please adjust your graph to have subgraph nodes with unique names."
                )
                raise ValueError(msg)

            seen_subgraphs.add(subgraph)
            mermaid_graph += f"\tsubgraph {subgraph}\n"

            # Add nodes that belong to this subgraph
            if with_styles and prefix in subgraph_nodes:
                for key, node in subgraph_nodes[prefix].items():
                    mermaid_graph += render_node(key, node)

        for edge in edges:
            source, target = edge.source, edge.target

            # Add BR every wrap_label_n_words words
            if edge.data is not None:
                edge_data = edge.data
                words = str(edge_data).split()  # Split the string into words
                # Group words into chunks of wrap_label_n_words size; a
                # non-positive chunk size would break range(), so skip wrapping.
                if wrap_label_n_words > 0 and len(words) > wrap_label_n_words:
                    edge_data = "&nbsp<br>&nbsp".join(
                        " ".join(words[i : i + wrap_label_n_words])
                        for i in range(0, len(words), wrap_label_n_words)
                    )
                if edge.conditional:
                    edge_label = f" -. &nbsp;{edge_data}&nbsp; .-> "
                else:
                    edge_label = f" -- &nbsp;{edge_data}&nbsp; --> "
            else:
                edge_label = " -.-> " if edge.conditional else " --> "

            mermaid_graph += (
                f"\t{_to_safe_id(source)}{edge_label}{_to_safe_id(target)};\n"
            )

        # Recursively add nested subgraphs
        for nested_prefix, edges_ in edge_groups.items():
            if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
                continue
            # only go to first level subgraphs
            if ":" in nested_prefix[len(prefix) + 1 :]:
                continue
            add_subgraph(edges_, nested_prefix)

        if prefix and not self_loop:
            mermaid_graph += "\tend\n"

    # Start with the top-level edges (no common prefix)
    add_subgraph(edge_groups.get("", []), "")

    # Add remaining subgraphs with edges
    for prefix, edges_ in edge_groups.items():
        if not prefix or ":" in prefix:
            continue
        add_subgraph(edges_, prefix)
        seen_subgraphs.add(prefix)

    # Add empty subgraphs (subgraphs with no internal edges)
    if with_styles:
        for prefix, subgraph_node in subgraph_nodes.items():
            if ":" not in prefix and prefix not in seen_subgraphs:
                mermaid_graph += f"\tsubgraph {prefix}\n"

                # Add nodes that belong to this subgraph
                for key, node in subgraph_node.items():
                    mermaid_graph += render_node(key, node)

                mermaid_graph += "\tend\n"
                seen_subgraphs.add(prefix)

    # Add custom styles for nodes
    if with_styles:
        mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())
    return mermaid_graph
def _to_safe_id(label: str) -> str:
    """Convert a string into a Mermaid-compatible node id.

    Keep [a-zA-Z0-9_-] characters unchanged.
    Map every other character -> backslash + lowercase hex codepoint.

    The result only contains Mermaid-safe characters, so nodes with special
    characters render correctly. Note that the escapes are variable-width and
    unterminated, so the mapping is not strictly injective. For example, ":b" and
    the single character U+03AB both become backslash + "3ab". Such collisions need
    an escaped character immediately followed by a hex-digit character.
    """
    allowed = string.ascii_letters + string.digits + "_-"
    return "".join(
        ch if ch in allowed else "\\" + format(ord(ch), "x") for ch in label
    )


def _generate_mermaid_graph_styles(node_colors: NodeStyles) -> str:
    """Generates Mermaid graph styles for different node types.

    Emits one ``classDef <field name> <field value>`` line per dataclass field of
    ``node_colors``, in field order.
    """
    return "".join(
        f"\tclassDef {class_name} {style}\n"
        for class_name, style in asdict(node_colors).items()
    )
def draw_mermaid_png(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
    background_color: Optional[str] = "white",
    padding: int = 10,
    max_retries: int = 1,
    retry_delay: float = 1.0,
    base_url: Optional[str] = None,
) -> bytes:
    """Render Mermaid syntax to PNG bytes, optionally saving them to a file.

    Args:
        mermaid_syntax (str): Mermaid graph syntax to render.
        output_file_path (str, optional): If given, the image is also written to
            this path. Defaults to None.
        draw_method (MermaidDrawMethod, optional): Rendering backend, either
            ``MermaidDrawMethod.API`` (Mermaid.ink service) or
            ``MermaidDrawMethod.PYPPETEER`` (local browser via Pyppeteer).
            Defaults to MermaidDrawMethod.API.
        background_color (str, optional): Background color of the image.
            Defaults to "white".
        padding (int, optional): Padding around the image; only used by the
            Pyppeteer backend. Defaults to 10.
        max_retries (int, optional): Maximum number of retries; only used by the
            API backend. Defaults to 1.
        retry_delay (float, optional): Delay between retries; only used by the
            API backend. Defaults to 1.0.
        base_url (str, optional): Base URL for the Mermaid.ink API; only used by
            the API backend. Defaults to None.

    Returns:
        bytes: PNG image bytes.

    Raises:
        ValueError: If an invalid draw method is provided.
        ImportError: If the chosen backend's optional dependency is missing.
    """
    if draw_method == MermaidDrawMethod.PYPPETEER:
        # The Pyppeteer renderer is a coroutine; drive it to completion here.
        return asyncio.run(
            _render_mermaid_using_pyppeteer(
                mermaid_syntax, output_file_path, background_color, padding
            )
        )

    if draw_method == MermaidDrawMethod.API:
        return _render_mermaid_using_api(
            mermaid_syntax,
            output_file_path=output_file_path,
            background_color=background_color,
            max_retries=max_retries,
            retry_delay=retry_delay,
            base_url=base_url,
        )

    valid_methods = ", ".join(method.value for method in MermaidDrawMethod)
    msg = (
        f"Invalid draw method: {draw_method}. "
        f"Supported draw methods are: {valid_methods}"
    )
    raise ValueError(msg)
async def _render_mermaid_using_pyppeteer(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
    padding: int = 10,
    device_scale_factor: int = 3,
) -> bytes:
    """Renders Mermaid graph using Pyppeteer.

    Loads Mermaid JS from the jsDelivr CDN in a Pyppeteer-launched browser, renders
    ``mermaid_syntax`` to SVG and screenshots a viewport sized to the SVG.

    Args:
        mermaid_syntax: Mermaid graph syntax to render.
        output_file_path: If given, the screenshot bytes are also written here.
        background_color: CSS background applied to the page body.
        padding: Pixels added to the SVG width and height for the viewport.
        device_scale_factor: ``deviceScaleFactor`` passed to the viewport.

    Returns:
        The screenshot bytes.

    Raises:
        ImportError: If pyppeteer is not installed.
    """
    if not _HAS_PYPPETEER:
        msg = "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
        raise ImportError(msg)

    browser = await launch()
    # Close the browser even if any rendering step fails, so no browser process
    # is left running.
    try:
        page = await browser.newPage()

        # Setup Mermaid JS
        await page.goto("about:blank")
        await page.addScriptTag(
            {"url": "https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"}
        )
        await page.evaluate(
            """() => {
                    mermaid.initialize({startOnLoad:true});
                }"""
        )

        # Render SVG
        svg_code = await page.evaluate(
            """(mermaidGraph) => {
                    return mermaid.mermaidAPI.render('mermaid', mermaidGraph);
                }""",
            mermaid_syntax,
        )

        # Inject the SVG and set the page background to background_color
        await page.evaluate(
            """(svg, background_color) => {
                document.body.innerHTML = svg;
                document.body.style.background = background_color;
            }""",
            svg_code["svg"],
            background_color,
        )

        # Size the viewport to the rendered SVG (plus padding), then screenshot
        dimensions = await page.evaluate(
            """() => {
                const svgElement = document.querySelector('svg');
                const rect = svgElement.getBoundingClientRect();
                return { width: rect.width, height: rect.height };
            }"""
        )
        await page.setViewport(
            {
                "width": int(dimensions["width"] + padding),
                "height": int(dimensions["height"] + padding),
                "deviceScaleFactor": device_scale_factor,
            }
        )

        img_bytes = await page.screenshot({"fullPage": False})
    finally:
        await browser.close()

    if output_file_path is not None:
        # Write off the event loop so the file I/O does not block it.
        await asyncio.get_running_loop().run_in_executor(
            None, Path(output_file_path).write_bytes, img_bytes
        )

    return img_bytes


def _build_mermaid_ink_image_url(
    mermaid_syntax: str,
    *,
    base_url: str,
    background_color: Optional[str],
    file_type: Optional[Literal["jpeg", "png", "webp"]],
) -> str:
    """Build the mermaid.ink ``/img`` URL for ``mermaid_syntax``.

    The graph is base64-encoded into the path. A hex ``background_color``
    (``#rgb`` / ``#rrggbb``) is sent without its ``#``, because a raw ``#`` would
    start the URL fragment and never reach the server. Any other value is sent as
    a named color prefixed with ``!``. When ``background_color`` is None, no
    ``bgColor`` parameter is sent.
    """
    mermaid_syntax_encoded = base64.b64encode(mermaid_syntax.encode("utf8")).decode(
        "ascii"
    )
    image_url = f"{base_url}/img/{mermaid_syntax_encoded}?type={file_type}"
    if background_color is None:
        return image_url

    # Check if the background color is a hexadecimal color code using regex
    hex_color_pattern = re.compile(r"^#(?:[0-9a-fA-F]{3}){1,2}$")
    if hex_color_pattern.match(background_color):
        bg_color_param = background_color[1:]
    else:
        bg_color_param = f"!{background_color}"
    return f"{image_url}&bgColor={bg_color_param}"


def _render_mermaid_using_api(
    mermaid_syntax: str,
    *,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
    file_type: Optional[Literal["jpeg", "png", "webp"]] = "png",
    max_retries: int = 1,
    retry_delay: float = 1.0,
    base_url: Optional[str] = None,
) -> bytes:
    """Renders Mermaid graph using the Mermaid.INK API.

    Args:
        mermaid_syntax: Mermaid graph syntax to render.
        output_file_path: If given, the returned image bytes are also written here.
        background_color: Hex color (``#rgb`` / ``#rrggbb``) or color name; None
            omits the ``bgColor`` query parameter.
        file_type: Image type requested via the ``type`` query parameter.
        max_retries: Retries after the first attempt, for 5xx responses and
            ``requests`` exceptions.
        retry_delay: Base delay in seconds for exponential backoff with jitter.
        base_url: Server base URL; defaults to ``https://mermaid.ink``.

    Returns:
        The response body bytes.

    Raises:
        ImportError: If ``requests`` is not installed.
        ValueError: On a non-retryable status code or once retries are exhausted.
    """
    # Defaults to using the public mermaid.ink server.
    base_url = base_url if base_url is not None else "https://mermaid.ink"

    if not _HAS_REQUESTS:
        msg = (
            "Install the `requests` module to use the Mermaid.INK API: "
            "`pip install requests`."
        )
        raise ImportError(msg)

    # Use Mermaid API to render the image
    image_url = _build_mermaid_ink_image_url(
        mermaid_syntax,
        base_url=base_url,
        background_color=background_color,
        file_type=file_type,
    )

    error_msg_suffix = (
        "To resolve this issue:\n"
        "1. Check your internet connection and try again\n"
        "2. Try with higher retry settings: "
        "`draw_mermaid_png(..., max_retries=5, retry_delay=2.0)`\n"
        "3. Use the Pyppeteer rendering method which will render your graph locally "
        "in a browser: `draw_mermaid_png(..., draw_method=MermaidDrawMethod.PYPPETEER)`"
    )

    for attempt in range(max_retries + 1):
        try:
            response = requests.get(image_url, timeout=10)
            if response.status_code == requests.codes.ok:
                img_bytes = response.content
                if output_file_path is not None:
                    Path(output_file_path).write_bytes(response.content)

                return img_bytes

            # If we get a server error (5xx), retry
            if 500 <= response.status_code < 600 and attempt < max_retries:
                # Exponential backoff with jitter
                sleep_time = retry_delay * (2**attempt) * (0.5 + 0.5 * random.random())  # noqa: S311 not used for crypto
                time.sleep(sleep_time)
                continue

            # For other status codes, fail immediately
            msg = (
                f"Failed to reach {base_url} API while trying to render "
                f"your graph. Status code: {response.status_code}.\n\n"
            ) + error_msg_suffix
            raise ValueError(msg)

        except (requests.RequestException, requests.Timeout) as e:
            if attempt < max_retries:
                # Exponential backoff with jitter
                sleep_time = retry_delay * (2**attempt) * (0.5 + 0.5 * random.random())  # noqa: S311 not used for crypto
                time.sleep(sleep_time)
            else:
                msg = (
                    f"Failed to reach {base_url} API while trying to render "
                    f"your graph after {max_retries} retries. "
                ) + error_msg_suffix
                raise ValueError(msg) from e

    # This should not be reached, but just in case
    msg = (
        f"Failed to reach {base_url} API while trying to render "
        f"your graph after {max_retries} retries. "
    ) + error_msg_suffix
    raise ValueError(msg)