from typing import Any, Optional
from langchain_core.runnables.graph import Graph, LabelsDict
class PngDrawer:
    """Helper class to draw a state graph into a PNG file.

    It requires `graphviz` and `pygraphviz` to be installed.

    Args:
        fontname: The font to use for the labels. Defaults to "arial".
        labels: A dictionary of label overrides. The dictionary
            should have the following format:
            {
                "nodes": {
                    "node1": "CustomLabel1",
                    "node2": "CustomLabel2",
                    "__end__": "End Node"
                },
                "edges": {
                    "continue": "ContinueLabel",
                    "end": "EndLabel"
                }
            }
            The keys are the original labels, and the values are the new labels.

    Usage:
        drawer = PngDrawer()
        drawer.draw(state_graph, 'graph.png')
    """

    def __init__(
        self, fontname: Optional[str] = None, labels: Optional[LabelsDict] = None
    ) -> None:
        """Initializes the PNG drawer.

        Args:
            fontname: The font to use for the labels. Defaults to "arial"
                (any falsy value, including "", also falls back to "arial").
            labels: A dictionary of label overrides. The dictionary
                should have the following format:
                {
                    "nodes": {
                        "node1": "CustomLabel1",
                        "node2": "CustomLabel2",
                        "__end__": "End Node"
                    },
                    "edges": {
                        "continue": "ContinueLabel",
                        "end": "EndLabel"
                    }
                }
                The keys are the original labels, and the values are the new
                labels. Defaults to None, meaning no overrides.
        """
        self.fontname = fontname or "arial"
        # Always store a LabelsDict so the label lookups below never see None.
        self.labels = labels or LabelsDict(nodes={}, edges={})

    def get_node_label(self, label: str) -> str:
        """Returns the label to use for a node.

        Args:
            label: The original label (the node id).

        Returns:
            The overridden label (or the original one when no override exists),
            wrapped in a Graphviz HTML-like label that renders it in bold.
        """
        label = self.labels.get("nodes", {}).get(label, label)
        # NOTE(review): the text is not HTML-escaped, so '<', '>' or '&' in a
        # node id / override is interpreted as Graphviz HTML-like markup.
        return f"<<B>{label}</B>>"

    def get_edge_label(self, label: str) -> str:
        """Returns the label to use for an edge.

        Args:
            label: The original label.

        Returns:
            The overridden label (or the original one when no override exists),
            wrapped in a Graphviz HTML-like label that renders it underlined.
        """
        label = self.labels.get("edges", {}).get(label, label)
        # NOTE(review): as with node labels, the text is not HTML-escaped.
        return f"<<U>{label}</U>>"

    def add_node(self, viz: Any, node: str) -> None:
        """Adds a node to the graph.

        The node is drawn filled in yellow; `update_styles` later recolors the
        first and last nodes.

        Args:
            viz: The graphviz object.
            node: The node to add.
        """
        viz.add_node(
            node,
            label=self.get_node_label(node),
            style="filled",
            fillcolor="yellow",
            fontsize=15,
            fontname=self.fontname,
        )

    def add_edge(
        self,
        viz: Any,
        source: str,
        target: str,
        label: Optional[str] = None,
        conditional: bool = False,
    ) -> None:
        """Adds an edge to the graph.

        Args:
            viz: The graphviz object.
            source: The source node.
            target: The target node.
            label: The label for the edge. Defaults to None. A falsy label
                (None or "") produces an unlabeled edge.
            conditional: Whether the edge is conditional. Conditional edges are
                drawn dotted, all others solid. Defaults to False.
        """
        viz.add_edge(
            source,
            target,
            label=self.get_edge_label(label) if label else "",
            fontsize=12,
            fontname=self.fontname,
            style="dotted" if conditional else "solid",
        )

    def draw(self, graph: Graph, output_path: Optional[str] = None) -> Optional[bytes]:
        """Draw the given state graph into a PNG file.

        Requires `graphviz` and `pygraphviz` to be installed.

        Args:
            graph: The graph to draw.
            output_path: The path to save the PNG. If None, PNG bytes are
                returned instead of being written to a file.

        Returns:
            The value returned by pygraphviz's `AGraph.draw`: the PNG bytes when
            `output_path` is None.

        Raises:
            ImportError: If `pygraphviz` is not installed.
        """
        # Imported lazily so the module can be imported without pygraphviz.
        try:
            import pygraphviz as pgv  # type: ignore[import]
        except ImportError as exc:
            msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`."
            raise ImportError(msg) from exc

        # Create a directed graph
        viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)

        # Add nodes, conditional edges, and edges to the graph
        self.add_nodes(viz, graph)
        self.add_edges(viz, graph)

        # Update entrypoint and END styles
        self.update_styles(viz, graph)

        # Save the graph as PNG; always release the underlying graphviz graph,
        # even if rendering fails.
        try:
            return viz.draw(output_path, format="png", prog="dot")
        finally:
            viz.close()

    def add_nodes(self, viz: Any, graph: Graph) -> None:
        """Add nodes to the graph.

        Args:
            viz: The graphviz object.
            graph: The graph to draw.
        """
        # Iterating `graph.nodes` yields the node ids.
        for node in graph.nodes:
            self.add_node(viz, node)

    def add_edges(self, viz: Any, graph: Graph) -> None:
        """Add edges to the graph.

        Args:
            viz: The graphviz object.
            graph: The graph to draw.
        """
        # Each edge unpacks as (source, target, data, conditional); non-None
        # edge data is stringified to serve as the edge label.
        for start, end, data, cond in graph.edges:
            self.add_edge(
                viz, start, end, str(data) if data is not None else None, cond
            )

    def update_styles(self, viz: Any, graph: Graph) -> None:
        """Update the styles of the entrypoint and END nodes.

        The first node is recolored light blue and the last node orange; either
        step is skipped when the graph reports no such node.

        Args:
            viz: The graphviz object.
            graph: The graph to draw.
        """
        if first := graph.first_node():
            viz.get_node(first.id).attr.update(fillcolor="lightblue")
        if last := graph.last_node():
            viz.get_node(last.id).attr.update(fillcolor="orange")