def draw_mermaid(
    nodes: dict[str, Node],
    edges: list[Edge],
    *,
    first_node: Optional[str] = None,
    last_node: Optional[str] = None,
    with_styles: bool = True,
    curve_style: CurveStyle = CurveStyle.LINEAR,
    node_styles: Optional[NodeStyles] = None,
    wrap_label_n_words: int = 9,
    frontmatter_config: Optional[dict[str, Any]] = None,
) -> str:
    """Draws a Mermaid graph using the provided graph data.

    Node ids containing ``":"`` are placed inside subgraphs named after their
    colon-separated prefix segments (``"a:b:x"`` lives in subgraph ``b``
    nested inside subgraph ``a``).

    Args:
        nodes (dict[str, Node]): Mapping of node id to node.
        edges (list[Edge]): List of edges, object with a source, target and data.
        first_node (str, optional): Id of the first node. Defaults to None.
        last_node (str, optional): Id of the last node. Defaults to None.
        with_styles (bool, optional): Whether to include styles in the graph.
            When False, the frontmatter, node declarations and ``classDef``
            lines are omitted and only edges (and subgraph wrappers) are
            emitted. Defaults to True.
        curve_style (CurveStyle, optional): Curve style for the edges.
            Defaults to CurveStyle.LINEAR.
        node_styles (NodeStyles, optional): Node colors for different types.
            Defaults to NodeStyles().
        wrap_label_n_words (int, optional): Words to wrap the edge labels.
            Values <= 0 disable wrapping. Defaults to 9.
        frontmatter_config (dict[str, Any], optional): Mermaid frontmatter config.
            Can be used to customize theme and styles. Will be converted to YAML and
            added to the beginning of the mermaid graph. Defaults to None.

            See more here: https://mermaid.js.org/config/configuration.html.

            Example:

            ```python
            {
                "config": {
                    "theme": "neutral",
                    "look": "handDrawn",
                    "themeVariables": { "primaryColor": "#e2e2e2"},
                }
            }
            ```

    Returns:
        str: Mermaid graph syntax.

    Raises:
        ValueError: If two different subgraphs end in the same name segment.
    """
    # Initialize Mermaid graph configuration
    original_frontmatter_config = frontmatter_config or {}
    original_flowchart_config = original_frontmatter_config.get("config", {}).get(
        "flowchart", {}
    )
    frontmatter_config = {
        **original_frontmatter_config,
        "config": {
            **original_frontmatter_config.get("config", {}),
            "flowchart": {**original_flowchart_config, "curve": curve_style.value},
        },
    }

    mermaid_graph = (
        (
            "---\n"
            + yaml.dump(frontmatter_config, default_flow_style=False)
            + "---\ngraph TD;\n"
        )
        if with_styles
        else "graph TD;\n"
    )

    # Group nodes by subgraph
    subgraph_nodes: dict[str, dict[str, Node]] = {}
    regular_nodes: dict[str, Node] = {}

    for key, node in nodes.items():
        if ":" in key:
            # For nodes with colons, add them only to their deepest subgraph level
            prefix = ":".join(key.split(":")[:-1])
            subgraph_nodes.setdefault(prefix, {})[key] = node
        else:
            regular_nodes[key] = node

    # Node formatting templates
    default_class_label = "default"
    format_dict = {default_class_label: "{0}({1})"}
    if first_node is not None:
        format_dict[first_node] = "{0}([{1}]):::first"
    if last_node is not None:
        format_dict[last_node] = "{0}([{1}]):::last"

    def render_node(key: str, node: Node, indent: str = "\t") -> str:
        """Helper function to render a node with consistent formatting."""
        node_name = node.name.split(":")[-1]
        label = (
            f"<p>{node_name}</p>"
            if node_name.startswith(tuple(MARKDOWN_SPECIAL_CHARS))
            and node_name.endswith(tuple(MARKDOWN_SPECIAL_CHARS))
            else node_name
        )
        if node.metadata:
            label = (
                f"{label}<hr/><small><em>"
                + "\n".join(f"{k} = {value}" for k, value in node.metadata.items())
                + "</em></small>"
            )
        node_label = format_dict.get(key, format_dict[default_class_label]).format(
            _escape_node_label(key), label
        )
        return f"{indent}{node_label}\n"

    # Add non-subgraph nodes to the graph
    if with_styles:
        for key, node in regular_nodes.items():
            mermaid_graph += render_node(key, node)

    # Group edges by the leading id segments their endpoints share. Matching
    # must stop at the first differing segment; otherwise "a:b" -> "c:b"
    # would be filed under a non-existent subgraph "b".
    edge_groups: dict[str, list[Edge]] = {}
    for edge in edges:
        common_prefix = _common_prefix(edge.source, edge.target)
        edge_groups.setdefault(common_prefix, []).append(edge)

    # Nested groups are only reached by recursing from their parent group, so
    # make sure every ancestor has an entry (possibly with no edges of its own).
    # Without this, edges grouped under e.g. "a:b" were silently dropped
    # whenever no edge had the common prefix "a".
    for prefix in list(edge_groups):
        segments = prefix.split(":")
        for depth in range(1, len(segments)):
            edge_groups.setdefault(":".join(segments[:depth]), [])

    seen_subgraphs: set[str] = set()

    def add_subgraph(edges: list[Edge], prefix: str) -> None:
        """Emit the edges of one group, wrapped in a subgraph when ``prefix`` is set."""
        nonlocal mermaid_graph

        # A lone self-loop's "common prefix" is the node id itself, not a
        # subgraph, so it is emitted without a subgraph wrapper.
        self_loop = len(edges) == 1 and edges[0].source == edges[0].target
        if prefix and not self_loop:
            subgraph = prefix.split(":")[-1]
            if subgraph in seen_subgraphs:
                msg = (
                    f"Found duplicate subgraph '{subgraph}' -- this likely means that "
                    "you're reusing a subgraph node with the same name. "
                    "Please adjust your graph to have subgraph nodes with unique names."
                )
                raise ValueError(msg)

            seen_subgraphs.add(subgraph)
            mermaid_graph += f"\tsubgraph {subgraph}\n"

            # Add nodes that belong to this subgraph
            if with_styles and prefix in subgraph_nodes:
                for key, node in subgraph_nodes[prefix].items():
                    mermaid_graph += render_node(key, node)

        for edge in edges:
            source, target = edge.source, edge.target

            # Add BR every wrap_label_n_words words
            if edge.data is not None:
                edge_data = edge.data
                words = str(edge_data).split()  # Split the string into words
                # Group words into chunks of wrap_label_n_words size; a
                # non-positive chunk size would make range() raise (0) or
                # produce an empty label (negative), so it disables wrapping.
                if 0 < wrap_label_n_words < len(words):
                    edge_data = " <br> ".join(
                        " ".join(words[i : i + wrap_label_n_words])
                        for i in range(0, len(words), wrap_label_n_words)
                    )
                if edge.conditional:
                    edge_label = f" -. {edge_data} .-> "
                else:
                    edge_label = f" -- {edge_data} --> "
            else:
                edge_label = " -.-> " if edge.conditional else " --> "

            mermaid_graph += (
                f"\t{_escape_node_label(source)}{edge_label}"
                f"{_escape_node_label(target)};\n"
            )

        # Recursively add nested subgraphs
        for nested_prefix in edge_groups:
            if not nested_prefix.startswith(prefix + ":") or nested_prefix == prefix:
                continue
            # only go to first level subgraphs
            if ":" in nested_prefix[len(prefix) + 1 :]:
                continue
            add_subgraph(edge_groups[nested_prefix], nested_prefix)

        if prefix and not self_loop:
            mermaid_graph += "\tend\n"

    # Start with the top-level edges (no common prefix)
    add_subgraph(edge_groups.get("", []), "")

    # Add remaining subgraphs with edges
    for prefix in edge_groups:
        if ":" in prefix or prefix == "":
            continue
        add_subgraph(edge_groups[prefix], prefix)
        seen_subgraphs.add(prefix)

    # Add empty subgraphs (subgraphs with no internal edges)
    if with_styles:
        for prefix in subgraph_nodes:
            if ":" not in prefix and prefix not in seen_subgraphs:
                mermaid_graph += f"\tsubgraph {prefix}\n"

                # Add nodes that belong to this subgraph
                for key, node in subgraph_nodes[prefix].items():
                    mermaid_graph += render_node(key, node)

                mermaid_graph += "\tend\n"
                seen_subgraphs.add(prefix)

    # Add custom styles for nodes
    if with_styles:
        mermaid_graph += _generate_mermaid_graph_styles(node_styles or NodeStyles())
    return mermaid_graph


def _common_prefix(source: str, target: str) -> str:
    """Return the leading ``":"``-separated segments shared by two node ids.

    Matching stops at the first differing segment, so ``"a:x:c"`` and
    ``"a:y:c"`` share ``"a"``, and ``"a:b"`` / ``"c:b"`` share ``""``.
    """
    shared: list[str] = []
    for src, tgt in zip(source.split(":"), target.split(":")):
        if src != tgt:
            break
        shared.append(src)
    return ":".join(shared)
def_escape_node_label(node_label:str)->str:"""Escapes the node label for Mermaid syntax."""returnre.sub(r"[^a-zA-Z-_0-9]","_",node_label)def_generate_mermaid_graph_styles(node_colors:NodeStyles)->str:"""Generates Mermaid graph styles for different node types."""styles=""forclass_name,styleinasdict(node_colors).items():styles+=f"\tclassDef {class_name}{style}\n"returnstyles
def draw_mermaid_png(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
    background_color: Optional[str] = "white",
    padding: int = 10,
) -> bytes:
    """Render Mermaid syntax to image bytes, optionally saving them to a file.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): Path to save the PNG image.
            Defaults to None.
        draw_method (MermaidDrawMethod, optional): Method to draw the graph.
            Defaults to MermaidDrawMethod.API.
        background_color (str, optional): Background color of the image.
            Defaults to "white".
        padding (int, optional): Padding around the image. Only used by the
            Pyppeteer method. Defaults to 10.

    Returns:
        bytes: PNG image bytes.

    Raises:
        ValueError: If an invalid draw method is provided.
    """
    if draw_method == MermaidDrawMethod.API:
        return _render_mermaid_using_api(
            mermaid_syntax, output_file_path, background_color
        )

    if draw_method == MermaidDrawMethod.PYPPETEER:
        import asyncio

        # The Pyppeteer renderer is a coroutine; drive it on a fresh event loop.
        return asyncio.run(
            _render_mermaid_using_pyppeteer(
                mermaid_syntax, output_file_path, background_color, padding
            )
        )

    # Neither supported method matched: report what the caller may choose from.
    supported_methods = ", ".join(method.value for method in MermaidDrawMethod)
    msg = (
        f"Invalid draw method: {draw_method}. "
        f"Supported draw methods are: {supported_methods}"
    )
    raise ValueError(msg)
asyncdef_render_mermaid_using_pyppeteer(mermaid_syntax:str,output_file_path:Optional[str]=None,background_color:Optional[str]="white",padding:int=10,device_scale_factor:int=3,)->bytes:"""Renders Mermaid graph using Pyppeteer."""try:frompyppeteerimportlaunch# type: ignore[import]exceptImportErrorase:msg="Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."raiseImportError(msg)fromebrowser=awaitlaunch()page=awaitbrowser.newPage()# Setup Mermaid JSawaitpage.goto("about:blank")awaitpage.addScriptTag({"url":"https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"})awaitpage.evaluate("""() => { mermaid.initialize({startOnLoad:true}); }""")# Render SVGsvg_code=awaitpage.evaluate("""(mermaidGraph) => { return mermaid.mermaidAPI.render('mermaid', mermaidGraph); }""",mermaid_syntax,)# Set the page background to whiteawaitpage.evaluate("""(svg, background_color) => { document.body.innerHTML = svg; document.body.style.background = background_color; }""",svg_code["svg"],background_color,)# Take a screenshotdimensions=awaitpage.evaluate("""() => { const svgElement = document.querySelector('svg'); const rect = svgElement.getBoundingClientRect(); return { width: rect.width, height: rect.height }; }""")awaitpage.setViewport({"width":int(dimensions["width"]+padding),"height":int(dimensions["height"]+padding),"deviceScaleFactor":device_scale_factor,})img_bytes=awaitpage.screenshot({"fullPage":False})awaitbrowser.close()ifoutput_file_pathisnotNone:awaitasyncio.get_event_loop().run_in_executor(None,Path(output_file_path).write_bytes,img_bytes)returnimg_bytesdef_render_mermaid_using_api(mermaid_syntax:str,output_file_path:Optional[str]=None,background_color:Optional[str]="white",file_type:Optional[Literal["jpeg","png","webp"]]="png",)->bytes:"""Renders Mermaid graph using the Mermaid.INK API."""try:importrequests# type: ignore[import]exceptImportErrorase:msg=("Install the `requests` module to use the Mermaid.INK API: ""`pip install 
requests`.")raiseImportError(msg)frome# Use Mermaid API to render the imagemermaid_syntax_encoded=base64.b64encode(mermaid_syntax.encode("utf8")).decode("ascii")# Check if the background color is a hexadecimal color code using regexifbackground_colorisnotNone:hex_color_pattern=re.compile(r"^#(?:[0-9a-fA-F]{3}){1,2}$")ifnothex_color_pattern.match(background_color):background_color=f"!{background_color}"image_url=(f"https://mermaid.ink/img/{mermaid_syntax_encoded}"f"?type={file_type}&bgColor={background_color}")response=requests.get(image_url,timeout=10)ifresponse.status_code==200:img_bytes=response.contentifoutput_file_pathisnotNone:Path(output_file_path).write_bytes(response.content)returnimg_byteselse:msg=(f"Failed to render the graph using the Mermaid.INK API. "f"Status code: {response.status_code}.")raiseValueError(msg)