def draw_mermaid(
    nodes: Dict[str, Node],
    edges: List[Edge],
    *,
    first_node: Optional[str] = None,
    last_node: Optional[str] = None,
    with_styles: bool = True,
    curve_style: CurveStyle = CurveStyle.LINEAR,
    node_styles: Optional[NodeStyles] = None,
    wrap_label_n_words: int = 9,
) -> str:
    """Draws a Mermaid graph using the provided graph data.

    Node declarations, the ``init`` directive and the ``classDef`` styles are
    only emitted when ``with_styles`` is True; edges are always emitted.
    Node ids may contain ``:`` separators: edges whose source and target share
    a leading run of ``:``-separated parts are grouped into a Mermaid
    ``subgraph`` named after the last part of that shared prefix.

    Args:
        nodes (dict[str, Node]): Mapping of node id to node object. Each node
            is read for its ``name`` and ``metadata`` attributes.
        edges (List[Edge]): List of edges, object with a source, target and
            data (plus a ``conditional`` flag selecting dotted arrows).
        first_node (str, optional): Id of the first node. Defaults to None.
        last_node (str, optional): Id of the last node. Defaults to None.
        with_styles (bool, optional): Whether to include styles in the graph.
            Defaults to True.
        curve_style (CurveStyle, optional): Curve style for the edges.
            Defaults to CurveStyle.LINEAR.
        node_styles (NodeStyles, optional): Node colors for different types.
            Defaults to NodeStyles().
        wrap_label_n_words (int, optional): Words to wrap the edge labels.
            Non-positive values disable wrapping. Defaults to 9.

    Returns:
        str: Mermaid graph syntax.

    Raises:
        ValueError: If two different subgraphs would end up with the same name.
    """
    # Output is accumulated as fragments and joined once at the end.
    parts: List[str] = []

    # Initialize Mermaid graph configuration
    if with_styles:
        parts.append(
            f"%%{{init: {{'flowchart': {{'curve': '{curve_style.value}'"
            f"}}}}}}%%\ngraph TD;\n"
        )
    else:
        parts.append("graph TD;\n")

    if with_styles:
        # Node formatting templates. The fallback template is kept apart from
        # the per-node overrides so that a node whose id happens to be
        # "default" cannot replace the template used by every other node.
        default_template = "{0}({1})"
        node_templates: Dict[str, str] = {}
        if first_node is not None:
            node_templates[first_node] = "{0}([{1}]):::first"
        if last_node is not None:
            node_templates[last_node] = "{0}([{1}]):::last"

        markdown_chars = tuple(MARKDOWN_SPECIAL_CHARS)

        # Add nodes to the graph
        for key, node in nodes.items():
            node_name = node.name.split(":")[-1]
            # Names wrapped in markdown special characters are put in <p> so
            # they are not interpreted as markdown.
            if node_name.startswith(markdown_chars) and node_name.endswith(
                markdown_chars
            ):
                label = f"<p>{node_name}</p>"
            else:
                label = node_name
            if node.metadata:
                metadata_lines = "\n".join(
                    f"{meta_key} = {meta_value}"
                    for meta_key, meta_value in node.metadata.items()
                )
                label = f"{label}<hr/><small><em>{metadata_lines}</em></small>"
            template = node_templates.get(key, default_template)
            node_label = template.format(_escape_node_label(key), label)
            parts.append(f"\t{node_label}\n")

    # Group edges by their common prefix: the *leading* run of ":"-separated
    # parts shared by source and target. Comparison stops at the first
    # mismatch, so "sub1:a -> sub2:a" is a top-level edge rather than an edge
    # of a bogus subgraph "a".
    edge_groups: Dict[str, List[Edge]] = {}
    for edge in edges:
        common_parts: List[str] = []
        for src_part, tgt_part in zip(edge.source.split(":"), edge.target.split(":")):
            if src_part != tgt_part:
                break
            common_parts.append(src_part)
        edge_groups.setdefault(":".join(common_parts), []).append(edge)

    def nearest_group(prefix: str) -> str:
        """Return the closest ancestor prefix that has its own edge group.

        Returns "" (the top level) when no ancestor has one.
        """
        prefix_parts = prefix.split(":")
        for depth in range(len(prefix_parts) - 1, 0, -1):
            candidate = ":".join(prefix_parts[:depth])
            if candidate in edge_groups:
                return candidate
        return ""

    # Attach every group to exactly one parent, so that deeply nested groups
    # are rendered once (not once per ancestor) and groups without a rendered
    # ancestor are still drawn instead of being silently dropped.
    child_groups: Dict[str, List[str]] = {}
    for group_prefix in edge_groups:
        if group_prefix:
            child_groups.setdefault(nearest_group(group_prefix), []).append(
                group_prefix
            )

    seen_subgraphs = set()

    def add_subgraph(group_edges: List[Edge], prefix: str) -> None:
        """Emit the edges of one group, wrapped in a subgraph when needed."""
        # A lone self-loop shares its whole id as "prefix"; it is a plain
        # edge, not a subgraph.
        self_loop = (
            len(group_edges) == 1
            and group_edges[0].source == group_edges[0].target
        )
        wrap_in_subgraph = bool(prefix) and not self_loop
        if wrap_in_subgraph:
            subgraph = prefix.split(":")[-1]
            if subgraph in seen_subgraphs:
                raise ValueError(
                    f"Found duplicate subgraph '{subgraph}' -- this likely means that "
                    "you're reusing a subgraph node with the same name. "
                    "Please adjust your graph to have subgraph nodes with unique names."
                )

            seen_subgraphs.add(subgraph)
            parts.append(f"\tsubgraph {subgraph}\n")

        for edge in group_edges:
            if edge.data is None:
                edge_label = " -.-> " if edge.conditional else " --> "
            else:
                edge_data = edge.data
                words = str(edge_data).split()
                # Add BR every wrap_label_n_words words. A non-positive chunk
                # size means "do not wrap" (and avoids a zero range step).
                if 0 < wrap_label_n_words < len(words):
                    edge_data = " <br> ".join(
                        " ".join(words[i : i + wrap_label_n_words])
                        for i in range(0, len(words), wrap_label_n_words)
                    )
                if edge.conditional:
                    edge_label = f" -. {edge_data} .-> "
                else:
                    edge_label = f" -- {edge_data} --> "

            parts.append(
                f"\t{_escape_node_label(edge.source)}{edge_label}"
                f"{_escape_node_label(edge.target)};\n"
            )

        # Recursively add the groups nested directly below this one
        for child_prefix in child_groups.get(prefix, []):
            add_subgraph(edge_groups[child_prefix], child_prefix)

        if wrap_in_subgraph:
            parts.append("\tend\n")

    # Start with the top-level edges (no common prefix); every other group is
    # reached from here through child_groups.
    add_subgraph(edge_groups.get("", []), "")

    # Add custom styles for nodes
    if with_styles:
        parts.append(_generate_mermaid_graph_styles(node_styles or NodeStyles()))
    return "".join(parts)
def _escape_node_label(node_label: str) -> str:
    """Escapes the node label for Mermaid syntax.

    Every character other than ASCII letters, digits, ``-`` and ``_`` is
    replaced with an underscore.

    Args:
        node_label (str): Raw node id / label.

    Returns:
        str: The label with all other characters replaced by ``_``.
    """
    return re.sub(r"[^a-zA-Z-_0-9]", "_", node_label)


def _generate_mermaid_graph_styles(node_colors: "NodeStyles") -> str:
    """Generates Mermaid graph styles for different node types.

    Args:
        node_colors (NodeStyles): Dataclass instance whose field names are
            Mermaid class names and whose values are the style strings.

    Returns:
        str: One tab-indented ``classDef <name> <style>`` line per field, each
        terminated by a newline (empty string when there are no fields).
    """
    # Mermaid requires whitespace between the class name and its style list.
    return "".join(
        f"\tclassDef {class_name} {style}\n"
        for class_name, style in asdict(node_colors).items()
    )
def draw_mermaid_png(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
    background_color: Optional[str] = "white",
    padding: int = 10,
) -> bytes:
    """Draws a Mermaid graph as PNG using provided syntax.

    Dispatches to one of the two rendering helpers depending on
    ``draw_method`` and returns whatever image bytes that helper produces.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): Path to save the PNG image.
            Defaults to None.
        draw_method (MermaidDrawMethod, optional): Method to draw the graph.
            Defaults to MermaidDrawMethod.API.
        background_color (str, optional): Background color of the image.
            Defaults to "white".
        padding (int, optional): Padding around the image. Only forwarded to
            the Pyppeteer renderer. Defaults to 10.

    Returns:
        bytes: PNG image bytes.

    Raises:
        ValueError: If an invalid draw method is provided.
    """
    if draw_method == MermaidDrawMethod.API:
        return _render_mermaid_using_api(
            mermaid_syntax, output_file_path, background_color
        )

    if draw_method == MermaidDrawMethod.PYPPETEER:
        # Imported lazily: only the Pyppeteer path needs an event loop.
        import asyncio

        render_job = _render_mermaid_using_pyppeteer(
            mermaid_syntax, output_file_path, background_color, padding
        )
        return asyncio.run(render_job)

    # Neither known method matched: report the accepted values.
    supported_methods = ", ".join([method.value for method in MermaidDrawMethod])
    raise ValueError(
        f"Invalid draw method: {draw_method}. "
        f"Supported draw methods are: {supported_methods}"
    )
async def _render_mermaid_using_pyppeteer(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
    padding: int = 10,
    device_scale_factor: int = 3,
) -> bytes:
    """Renders Mermaid graph using Pyppeteer.

    Launches a browser, loads Mermaid JS from the jsDelivr CDN, renders the
    graph to SVG inside a blank page and screenshots it.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): If given, the PNG bytes are also
            written to this path. Defaults to None.
        background_color (str, optional): Value assigned to the page body's
            CSS ``background``. Defaults to "white".
        padding (int, optional): Amount added to both the SVG width and height
            when sizing the viewport. Defaults to 10.
        device_scale_factor (int, optional): ``deviceScaleFactor`` passed to
            the viewport. Defaults to 3.

    Returns:
        bytes: Screenshot image bytes.

    Raises:
        ImportError: If ``pyppeteer`` is not installed.
    """
    try:
        from pyppeteer import launch  # type: ignore[import]
    except ImportError as e:
        raise ImportError(
            "Install Pyppeteer to use the Pyppeteer method: `pip install pyppeteer`."
        ) from e

    browser = await launch()
    # The browser must be closed even when rendering fails, otherwise the
    # launched browser is left running.
    try:
        page = await browser.newPage()

        # Setup Mermaid JS
        await page.goto("about:blank")
        await page.addScriptTag(
            {"url": "https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"}
        )
        await page.evaluate(
            """() => {
                mermaid.initialize({startOnLoad:true});
            }"""
        )

        # Render SVG
        svg_code = await page.evaluate(
            """(mermaidGraph) => {
                return mermaid.mermaidAPI.render('mermaid', mermaidGraph);
            }""",
            mermaid_syntax,
        )

        # Put the SVG on the page and apply the requested background color
        await page.evaluate(
            """(svg, background_color) => {
                document.body.innerHTML = svg;
                document.body.style.background = background_color;
            }""",
            svg_code["svg"],
            background_color,
        )

        # Size the viewport to the rendered SVG (plus padding) and take a
        # screenshot
        dimensions = await page.evaluate(
            """() => {
                const svgElement = document.querySelector('svg');
                const rect = svgElement.getBoundingClientRect();
                return { width: rect.width, height: rect.height };
            }"""
        )
        await page.setViewport(
            {
                "width": int(dimensions["width"] + padding),
                "height": int(dimensions["height"] + padding),
                "deviceScaleFactor": device_scale_factor,
            }
        )

        img_bytes = await page.screenshot({"fullPage": False})
    finally:
        await browser.close()

    if output_file_path is not None:
        with open(output_file_path, "wb") as file:
            file.write(img_bytes)

    return img_bytes


def _render_mermaid_using_api(
    mermaid_syntax: str,
    output_file_path: Optional[str] = None,
    background_color: Optional[str] = "white",
) -> bytes:
    """Renders Mermaid graph using the Mermaid.INK API.

    Args:
        mermaid_syntax (str): Mermaid graph syntax.
        output_file_path (str, optional): If given, the image bytes are also
            written to this path. Defaults to None.
        background_color (str, optional): Either a ``#rgb`` / ``#rrggbb`` hex
            code or a color name. ``None`` omits the ``bgColor`` parameter.
            Defaults to "white".

    Returns:
        bytes: Image bytes returned by the API.

    Raises:
        ImportError: If ``requests`` is not installed.
        ValueError: If the API responds with a status code other than 200.
    """
    try:
        import requests  # type: ignore[import]
    except ImportError as e:
        raise ImportError(
            "Install the `requests` module to use the Mermaid.INK API: "
            "`pip install requests`."
        ) from e

    # The encoded graph becomes a URL path segment, so use the URL-safe
    # alphabet: standard base64 may contain "/" which would split the path.
    mermaid_syntax_encoded = base64.urlsafe_b64encode(
        mermaid_syntax.encode("utf8")
    ).decode("ascii")
    image_url = f"https://mermaid.ink/img/{mermaid_syntax_encoded}"

    if background_color is not None:
        hex_color_pattern = re.compile(r"^#(?:[0-9a-fA-F]{3}){1,2}$")
        if hex_color_pattern.match(background_color):
            # A raw "#" would start the URL fragment and never reach the
            # server, so hex codes are sent without it.
            # NOTE(review): bare-hex / "!name" convention is taken from the
            # mermaid.ink docs -- confirm.
            bg_color_param = background_color[1:]
        else:
            # Named colors are marked with a leading "!".
            bg_color_param = f"!{background_color}"
        image_url += f"?bgColor={bg_color_param}"

    # Bound the wait so an unresponsive service cannot hang the caller.
    response = requests.get(image_url, timeout=10)
    if response.status_code != 200:
        raise ValueError(
            f"Failed to render the graph using the Mermaid.INK API. "
            f"Status code: {response.status_code}."
        )

    img_bytes = response.content
    if output_file_path is not None:
        with open(output_file_path, "wb") as file:
            file.write(img_bytes)

    return img_bytes