class PngDrawer:
    """Helper class to draw a state graph into a PNG file.

    It requires ``graphviz`` and ``pygraphviz`` to be installed.

    Node labels are rendered as bold and edge labels as underlined Graphviz
    HTML-like labels. Labels derived from node ids or edge data are
    HTML-escaped, so characters such as ``<``, ``>``, ``&`` and ``"`` cannot
    produce a malformed label. Overrides supplied through ``labels`` are
    inserted verbatim. They may therefore contain Graphviz HTML markup, and
    any literal special characters in them must be escaped by the caller.

    Args:
        fontname: The font to use for the labels. Defaults to "arial".
        labels: A dictionary of label overrides. The dictionary should have
            the following format:
            {
                "nodes": {
                    "node1": "CustomLabel1",
                    "node2": "CustomLabel2",
                    "__end__": "End Node"
                },
                "edges": {
                    "continue": "ContinueLabel",
                    "end": "EndLabel"
                }
            }
            The keys are the original labels, and the values are the new
            labels.

    Usage:
        drawer = PngDrawer()
        drawer.draw(state_graph, 'graph.png')
    """

    def __init__(
        self,
        fontname: Optional[str] = None,
        labels: Optional["LabelsDict"] = None,
    ) -> None:
        """Initializes the PNG drawer.

        Args:
            fontname: The font to use for the labels. Defaults to "arial".
            labels: A dictionary of label overrides with optional "nodes" and
                "edges" sub-dictionaries mapping original labels to new
                labels. Defaults to None (no overrides).
        """
        self.fontname = fontname or "arial"
        self.labels = labels or LabelsDict(nodes={}, edges={})

    @staticmethod
    def _escape_html(text: str) -> str:
        """Escapes text for literal use inside a Graphviz HTML-like label."""
        from html import escape

        # Graphviz parses HTML-like labels as XML-ish markup, so ", &, < and >
        # in literal text must be replaced by entities.
        return escape(text)

    def get_node_label(self, label: str) -> str:
        """Returns the label to use for a node.

        Args:
            label: The original label (usually the node id).

        Returns:
            A bold Graphviz HTML-like label. A user override is used verbatim;
            otherwise the original label is HTML-escaped.
        """
        overrides = self.labels.get("nodes", {})
        text = overrides[label] if label in overrides else self._escape_html(label)
        return f"<<B>{text}</B>>"

    def get_edge_label(self, label: str) -> str:
        """Returns the label to use for an edge.

        Args:
            label: The original label (the stringified edge data).

        Returns:
            An underlined Graphviz HTML-like label. A user override is used
            verbatim; otherwise the original label is HTML-escaped.
        """
        overrides = self.labels.get("edges", {})
        text = overrides[label] if label in overrides else self._escape_html(label)
        return f"<<U>{text}</U>>"

    def add_node(self, viz: Any, node: str) -> None:
        """Adds a node to the graph.

        Args:
            viz: The graphviz object.
            node: The node to add.
        """
        viz.add_node(
            node,
            label=self.get_node_label(node),
            style="filled",
            fillcolor="yellow",
            fontsize=15,
            fontname=self.fontname,
        )

    def add_edge(
        self,
        viz: Any,
        source: str,
        target: str,
        label: Optional[str] = None,
        conditional: bool = False,
    ) -> None:
        """Adds an edge to the graph.

        Args:
            viz: The graphviz object.
            source: The source node.
            target: The target node.
            label: The label for the edge. Falsy labels (None or "") produce
                an unlabeled edge. Defaults to None.
            conditional: Whether the edge is conditional; conditional edges
                are drawn dotted, others solid. Defaults to False.
        """
        viz.add_edge(
            source,
            target,
            label=self.get_edge_label(label) if label else "",
            fontsize=12,
            fontname=self.fontname,
            style="dotted" if conditional else "solid",
        )

    def draw(self, graph: "Graph", output_path: Optional[str] = None) -> Optional[bytes]:
        """Draws the given state graph into a PNG file.

        Requires ``graphviz`` and ``pygraphviz`` to be installed.

        Args:
            graph: The graph to draw.
            output_path: The path to save the PNG. If None, PNG bytes are
                returned instead.

        Returns:
            Whatever ``viz.draw`` returns for the given ``output_path``
            (PNG bytes when ``output_path`` is None, per the contract above).

        Raises:
            ImportError: If ``pygraphviz`` is not installed.
        """
        try:
            import pygraphviz as pgv  # type: ignore[import]
        except ImportError as exc:
            msg = "Install pygraphviz to draw graphs: `pip install pygraphviz`."
            raise ImportError(msg) from exc

        # Create a directed graph
        viz = pgv.AGraph(directed=True, nodesep=0.9, ranksep=1.0)

        # Add nodes, conditional edges, and edges to the graph
        self.add_nodes(viz, graph)
        self.add_edges(viz, graph)

        # Update entrypoint and END styles
        self.update_styles(viz, graph)

        # Render to PNG; always release the underlying graphviz handle.
        try:
            return viz.draw(output_path, format="png", prog="dot")
        finally:
            viz.close()

    def add_nodes(self, viz: Any, graph: "Graph") -> None:
        """Adds every node of the graph to the graphviz object.

        Args:
            viz: The graphviz object.
            graph: The graph to draw.
        """
        for node in graph.nodes:
            self.add_node(viz, node)

    def add_edges(self, viz: Any, graph: "Graph") -> None:
        """Adds every edge of the graph to the graphviz object.

        Args:
            viz: The graphviz object.
            graph: The graph to draw. Each edge is unpacked as
                (source, target, data, conditional).
        """
        for start, end, data, cond in graph.edges:
            # Edge data is stringified for the label; None means no label.
            self.add_edge(viz, start, end, str(data) if data is not None else None, cond)

    def update_styles(self, viz: Any, graph: "Graph") -> None:
        """Updates the styles of the entrypoint and END nodes.

        Args:
            viz: The graphviz object.
            graph: The graph to draw.
        """
        if (first := graph.first_node()) is not None:
            viz.get_node(first.id).attr.update(fillcolor="lightblue")
        if (last := graph.last_node()) is not None:
            viz.get_node(last.id).attr.update(fillcolor="orange")