Drawing multi-way tree diagrams with matplotlib

When solving practice problems and the like, it is sometimes useful to visualize a tree structure so it can be debugged more intuitively. Here I use Python's matplotlib library to implement a tree visualization function, plotTree().

The function takes the root node root and a list of edges edges (pairs of node labels). It is intended for trees: if the edges contain a cycle, one edge of the cycle is cut and drawn as a dotted line, and the layout may be unstable.
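Because the edges are undirected, the parent/child direction has to be recovered by a traversal from root. The following is a minimal sketch of that step, assuming a plain breadth-first search that marks a node as discovered when it is enqueued and node labels that can be sorted; `split_tree_and_extra_edges` is a hypothetical name, not a function from the code below. Any edge that reaches an already-discovered node (other than the current node's parent) closes a cycle and is set aside instead of becoming a tree edge.

```python
from collections import defaultdict, deque


def split_tree_and_extra_edges(root, edges):
    """Hypothetical helper: BFS from root over undirected edges.

    Returns (children, extra_edges): children maps each node to the list of
    its children, and extra_edges holds edges that would close a cycle.
    """
    # Deduplicate undirected edges: [a, b] and [b, a] are the same edge
    unique = list(dict.fromkeys(tuple(sorted(e)) for e in edges))
    neighbors = defaultdict(list)
    for a, b in unique:
        neighbors[a].append(b)
        neighbors[b].append(a)

    children = defaultdict(list)
    parent = {root: None}  # a node counts as discovered once it is enqueued
    extra, seen_extra = [], set()
    queue = deque([root])
    while queue:
        node = queue.popleft()
        for nb in neighbors[node]:
            if nb not in parent:  # first time nb is reached: tree edge
                parent[nb] = node
                children[node].append(nb)
                queue.append(nb)
            elif nb != parent[node]:  # nb already discovered: cycle edge
                key = tuple(sorted((node, nb)))
                if key not in seen_extra:  # record each extra edge once
                    seen_extra.add(key)
                    extra.append([node, nb])
    return children, extra
```

For the 4-cycle 1-2-3-4-1 rooted at 1, this yields the children {1: [2, 4], 2: [3]} and sets aside the edge [4, 3]. Marking nodes when they are enqueued (rather than when they are dequeued) is what guarantees that every node gets exactly one parent.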
The function adapts the spacing between nodes to the length of each node's label. However, label centering is imperfect, so the labels may still not look centered or evenly spaced when displayed; adjust the position passed to plt.text(x, y, str(i)) to suit your data.
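The adaptive spacing comes from giving every subtree a width of at least its root's label length. Below is a simplified sketch of that two-pass idea (widths bottom-up, then positions top-down), modelled on binaryTreeForm() in the code below but without its label-offset adjustments; `layout_by_label_width`, `height_diff` and `parent_ratio` are hypothetical names.

```python
def layout_by_label_width(root, children, height_diff=1, parent_ratio=0.5):
    """Hypothetical helper: label-aware layered tree layout.

    width[n] = max(sum of the children's widths, len(str(n))), so a long
    label reserves horizontal room for itself. Children are then placed
    left to right starting at x0, and the parent at x0 + width * parent_ratio.
    """
    width = {}

    def get_width(node):  # bottom-up pass
        total = sum(get_width(c) for c in children.get(node, []))
        width[node] = max(total, len(str(node)))
        return width[node]

    get_width(root)
    pos = {}

    def place(node, x0, y0):  # top-down pass
        pos[node] = (x0 + width[node] * parent_ratio, y0)
        x = x0
        for c in children.get(node, []):
            place(c, x, y0 + height_diff)
            x += width[c]  # the next sibling starts after this subtree

    place(root, 0, 0)
    return width, pos
```

With children {1: [2, 100], 2: [3, 4]}, node 100 reserves a width of 3 because its label has three characters, so the root's width is 5 and, with parent_ratio=0.5, the root is drawn at x = 2.5.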

Code

from matplotlib import pyplot as plt
from collections import defaultdict

#core function
def plotTree(root:int,edges:list[list[int]],nodeGroups=[],groupLabels:list[str]=[],figureTitle="",labelOffset=0,heightDiff=5,fatherNodePosAdjustProportion=0.5,plotFormIdx =0,fieldAngle=180)->None:
    """
    Required parameters:
        root: the root node, where the traversal starts.
        edges=[[a,b],[a2,b2],...]: undirected edges between nodes. The program traverses from the root node to decide which of a and b is the parent.
    Optional parameters:
        nodeGroups: groups of nodes; each group is drawn in a different color.
        groupLabels: a description of each group of nodes, shown in the legend.
        figureTitle: the figure title.
        labelOffset: an extra horizontal label offset, proportional to the label length (binary tree form only).
        heightDiff: the distance between layers; it must be non-zero, and its sign sets the orientation of the tree.
        fatherNodePosAdjustProportion: adjusts the position of a parent relative to its children; 0.5 centers it.
        plotFormIdx: the display form; 0, 1 and 2 select the binary tree, directory tree and fan chart forms (other values are taken modulo 3).
        fieldAngle: the opening angle of the fan chart, in degrees.
    Notes:
        If edges contain a cycle, the edge of the cycle reached last in traversal order is cut (and drawn as a dotted line), so the display depends on the traversal order.
        Branches that end early still occupy empty space in later layers.
    """
    if len(nodeGroups)<1:
        nodeGroups=[i for j in edges for i in j]
        nodeGroups=[list(set(nodeGroups))]
    elif not isinstance(nodeGroups[0],(list,tuple,set)):
        nodeGroups=[nodeGroups]#Unify into a list of groups
    if len(groupLabels)<len(nodeGroups):
        #build a new list so the shared default argument is never mutated
        groupLabels=list(groupLabels)+[f"Class {i + 1}" for i in range(len(groupLabels),len(nodeGroups))]
    if heightDiff==0:
        print('heightDiff must be non-zero; it has been reset to 1')
        heightDiff=1
    childrens,ringEdges=getTreeAndRingEdgeFromGraph(root,edges)
    depth=depthOfTree(root,childrens)
    figureTitle+=f", depth: {depth}, number of nodes: {nodeNumsOfTree(nodeGroups)}"
    plotFormIdx=int(plotFormIdx)%3
    if plotFormIdx==0:
        pos,labelPos=binaryTreeForm(root,childrens,heightDiff,fatherNodePosAdjustProportion,labelOffset)
    elif plotFormIdx==1:
        pos,labelPos=ContentForm(root,childrens,heightDiff,fatherNodePosAdjustProportion)
    else:
        pos,labelPos=roundForm(root,childrens,heightDiff,fatherNodePosAdjustProportion,fieldAngle=fieldAngle)
    """
    The parameters required by the three modes are different.
    Binary tree mode requires parent node deviation proportion, height orientation, alternative: label x-axis additional deviation value, for
    Directory mode requires the parent node to deviate from the proportion and height orientation. Alternative: connect the line in front, change it to two segments, but it will affect the drawing, so give up temporarily.
    Sector mode requires parent node deviation proportion, height orientation, and total angle
    """
    edges=[[a,b] for a in childrens for b in childrens[a]]
    for a,b in edges:
        if plotFormIdx!=1:
            line=[pos[a],pos[b]]
        else:
            if b!=childrens[a][0]:
                line=[pos[childrens[a][0]],pos[b]]
            else:
                line=[pos[a],[pos[b][0],pos[a][1]]]
        line=list(zip(*line))
        if plotFormIdx!=1:#binary tree and fan chart forms: labels may overlap edges, so draw edges in a light color
            plt.plot(*line,color="lightgray")
        else:
            plt.plot(*line,color="gray")
    for idx,nodes in enumerate(nodeGroups):#Draw nodes
        if len(nodes)==0:
            continue
        points=[pos[i] for i in nodes]
        points=list(zip(*points))
        plt.scatter(*points,marker='o',label=groupLabels[idx])
    for i in labelPos:
        x,y=labelPos[i]
        plt.text(x,y,str(i))#number
    
    if plotFormIdx==1: #labels at the right edge get cut off, so add an invisible (background-colored) point on the far right
        rx,ry=labelPos[root]
        for i in labelPos:
            rx=max(rx,labelPos[i][0] + len(str(i))/2)
        bgc=plt.rcParams['axes.facecolor']
        plt.scatter([rx],[ry],c=bgc)

    for a,b in ringEdges:
        line=[pos[a],pos[b]]
        line=list(zip(*line))
        plt.plot(*line,linestyle=":",color='orange')
    
    plt.subplots_adjust(left=0,bottom=0,right=0.95,top=0.95)#Leave some space for label numbers on the right side
    plt.axis('off')
    plt.legend()
    plt.title(figureTitle)
    plt.show()

def binaryTreeForm(root,childrens,heightDiff=1,fatherNodePosAdjustProportion=0.5,labelOffset=0):
    """
    Output the node position coordinates in the form of a binary tree. The height of the nodes in each layer is the same.
    First get the width of each subtree from bottom to top, and then set the specific position from top to bottom.
    labelOffset: additionally offset the label position proportionally to the length of the number
    """
    width={}#The width of the subtree rooted at each node
    visited={}
    def dfsGetWidth(node):
        if node in visited:
            print(f"There is a cycle: {node}, ignored")
            return 0
        visited[node]=True
        res=0#No padding between subtrees; gaps appear naturally when the figure is enlarged
        for i in childrens[node]:
            res+=dfsGetWidth(i)
        res=max(res,len(str(node)))#Reserve at least the label length so that long numbers do not overlap
        width[node]=res
        return res
    dfsGetWidth(root)

    pos={}#Stores the drawing coordinates of each node
    visited={}
    def dfsGetPos(node,x0,y0):
        if node in visited:
            return
        visited[node]=True
        pos[node]=[x0 + width[node]*fatherNodePosAdjustProportion,y0]
        x=x0
        for i in childrens[node]:
            dfsGetPos(i,x,y0 + heightDiff)
            x+=width[i]
    dfsGetPos(root,0,0)
    labelPos={}
    for i in pos:
        x,y=pos[i]
        labelOffsetProportion=max(0,fatherNodePosAdjustProportion)
        labelOffsetProportion=min(1,labelOffsetProportion)#limited to between 0 and 1
        labelOffsetProportion*=-0.8
        labelOffsetProportion+=labelOffset#proportional compensation; a constant offset would just shift every label equally, which does not help
        x+=labelOffsetProportion*len(str(i))#Try to center the label
        y+=abs(heightDiff)/6
        labelPos[i]=[x,y]
    return pos,labelPos

def ContentForm(root,childrens,heightDiff=1,fatherNodePosAdjustProportion=0):
    """
    Obtain the coordinates of the nodes in directory tree form; horizontally, each level is indented by the maximum label width on the level above it.
    """
    maxLen=[]
    visited={}
    def dfsGetMaxLengthOfEachLevel(node,curlevel=1):
        if node in visited:
            print(f"There is a cycle: {node}, ignored")
            return
        visited[node]=True
        if len(maxLen)<curlevel:
            maxLen.append(len(str(node)))
        elif len(str(node))>maxLen[curlevel-1]:
            maxLen[curlevel-1]=len(str(node))
        for i in childrens[node]:
            dfsGetMaxLengthOfEachLevel(i,curlevel + 1)
    dfsGetMaxLengthOfEachLevel(root,1)
    """
    First get the height of each subtree from bottom to top, and then set the specific position from top to bottom.
    """
    height={}#The height of the subtree rooted at each node
    visited={}
    def dfsGetHeight(node):
        if node in visited:
            return 0
        visited[node]=True
        res=0
        for i in childrens[node]:
            res+=dfsGetHeight(i)
        if res==0:
            res=heightDiff
        height[node]=res
        return res
    dfsGetHeight(root)

    pos={}#Stores the drawing coordinates of each node
    visited={}
    def dfsGetPos(node,x0,y0,curlevel=1):
        if node in visited:
            return
        visited[node]=True
        pos[node]=[x0,y0 + fatherNodePosAdjustProportion*height[node]]
        x=x0 + maxLen[curlevel-1]
        y=y0
        for i in childrens[node]:
            dfsGetPos(i,x,y,curlevel + 1)
            y+=height[i]
    dfsGetPos(root,0,0)
    return pos,pos

def roundForm(root:int,childrens:dict[int,list[int]],heightDiff=1,fatherNodePosAdjustProportion=0.5,fieldAngle=120):
    """
    Draw a fan chart: output the positions of the fan chart nodes and of their labels
    fieldAngle: the opening angle of the fan, in degrees
    heightDiff: its sign sets the direction of the fan, upward (positive) or downward (negative)
    fatherNodePosAdjustProportion: where a parent sits within the angular range of its children (0.5 = centered)
    """
    from math import sin,cos,pi
    fieldAngle=(fieldAngle/2)*pi/180#half of the opening angle, in radians
    radian={}#Plays the role of the subtree width in the binary tree form
    def dfsGetRadian(node,level):#Arc widths shrink proportionally with depth; level plays the role of the radius
        res=0
        for i in childrens[node]:
            dfsGetRadian(i,level + 1)
            res + =radian[i]
        res=max(res*level/(level + 1),1)
        radian[node]=res
        return res
    dfsGetRadian(root,0)#The root node does not occupy the width
    pos={<!-- -->}
    
    def dfsGetPos(node,level,angleStart,angleEnd):
        angle=angleStart*(1-fatherNodePosAdjustProportion) + angleEnd*fatherNodePosAdjustProportion
        pos[node]=[level*cos(angle),level*sin(angle)]#The fan center is located at the origin, and the coordinates are obtained directly using trigonometric functions
        if radian[node] and level:
            radian[node]*=(level + 1)/level
        else:
            radian[node]=sum(radian[i] for i in childrens[node])
        angle=angleEnd-angleStart
        for i in childrens[node]:
            angleEnd=angleStart + angle*radian[i]/radian[node]
            dfsGetPos(i,level + 1,angleStart,angleEnd)
            angleStart=angleEnd
    if heightDiff>0:
        angleSt,angleEd=pi/2 + fieldAngle,pi/2-fieldAngle#Start on the left so children are laid out from left to right
    else:
        angleSt,angleEd=3*pi/2-fieldAngle,3*pi/2 + fieldAngle
    dfsGetPos(root,0,angleSt,angleEd)
    return pos,pos

def nodeNumsOfTree(nodeGroups):
    """
    Count the number of nodes in the tree
    """
    if not isinstance(nodeGroups[0],(list,tuple,set)):
        return len(set(nodeGroups))
    res=set()
    for i in nodeGroups:
        res|=set(i)
    return len(res)

def depthOfTree(root:int,childrens:dict[int,list[int]] | list[list[int]]):
    """
    Calculate the depth of the tree with DFS
    """
    visited={}
    def dfs(node):
        if node in visited:
            print(f"There is a cycle: {node}, ignored")
            return 0
        visited[node]=True
        res=0
        for i in childrens[node]:
            res=max(res,dfs(i))
        return res + 1
    depth=dfs(root)
    return depth

def getTreeAndRingEdgeFromGraph(root:int,edges:list[list[int]])->tuple[dict[int,list[int]],list[list[int]]]:
    """
    Remove the cycles in the graph and return the removed edges for later use. The method runs BFS from the root node root; an edge that reaches an already-discovered node other than the current node's parent closes a cycle and is removed.
    Output children, not edges.
    Whether the input is a tree or several disconnected graphs, only the tree/graph containing root is processed.
    """
    edges=list(dict.fromkeys(tuple(sorted(i)) for i in edges))#remove duplicate undirected edges, keeping the original order
    neighbors=defaultdict(list)
    for a,b in edges:
        neighbors[a].append(b)
        neighbors[b].append(a)
    
    childrens=defaultdict(list)
    ringsEdge=[]#The edges removed from cycles
    parent={root:None}#a node counts as discovered as soon as it is enqueued, so it gets exactly one parent
    curNodes=[root]
    while len(curNodes):
        nextNodes=[]
        for i in curNodes:
            for j in neighbors[i]:
                if j not in parent:#first time j is reached: tree edge
                    parent[j]=i
                    childrens[i].append(j)
                    nextNodes.append(j)
                elif j!=parent[i] and [j,i] not in ringsEdge:#non-tree edge, recorded once
                    ringsEdge.append([i,j])
        curNodes=nextNodes
    return childrens,ringsEdge

#Example: draw the 3x+1 tree of the Collatz (hailstone) conjecture
def plot3xAdd1Tree(level=10):
    edges=[]
    curNode=[1]
    nodes=[1]
    for _ in range(1,level):
        newNode=[2*j for j in curNode]
        newNode+=[(j-1)//3 for j in curNode if j%6==4 and j>4]#j is even and j%3==1, so (j-1)/3 is an odd integer; j>4 skips the cycle back to 1
        edges+=[[j,(j-1)//3] for j in curNode if j%6==4 and j>4]
        edges+=[[j,2*j] for j in curNode]
        curNode=newNode
        nodes+=curNode
    nodes=[[i for i in nodes if i&1],[i for i in nodes if (i&1)==0]]
    nodeLabel=["odd number","even number"]
    for i in range(3):#draw once in each of the 3 display forms
        plotTree(1,edges,nodes,figureTitle="3x + 1 tree",groupLabels=nodeLabel,heightDiff=-1,plotFormIdx=i,fatherNodePosAdjustProportion=0.5)

if __name__=="__main__":
    plot3xAdd1Tree(15)
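
For comparison with the binary tree form, the directory tree form in ContentForm() swaps the roles of the axes: columns are indented by the longest label on each level, and each leaf takes one row. A minimal sketch of that layout, assuming rows advance in the positive y direction and using hypothetical names (`directory_layout`, `row_height`, `parent_ratio`):

```python
def directory_layout(root, children, row_height=1, parent_ratio=0.0):
    """Hypothetical helper sketching the directory-tree layout.

    x: each level is indented by the longest label on the level above.
    y: a leaf takes one row; a subtree takes as many rows as its leaves,
    and a parent sits parent_ratio of the way into its block of rows.
    """
    max_len = []  # longest label length per level (level 0 = root)

    def scan(node, level):
        if len(max_len) <= level:
            max_len.append(0)
        max_len[level] = max(max_len[level], len(str(node)))
        for c in children.get(node, []):
            scan(c, level + 1)

    scan(root, 0)

    height = {}  # rows occupied by the subtree rooted at each node

    def get_height(node):
        kids = children.get(node, [])
        height[node] = sum(get_height(c) for c in kids) if kids else row_height
        return height[node]

    get_height(root)

    pos = {}

    def place(node, x0, y0, level):
        pos[node] = (x0, y0 + parent_ratio * height[node])
        x, y = x0 + max_len[level], y0
        for c in children.get(node, []):
            place(c, x, y, level + 1)
            y += height[c]  # the next sibling starts after this subtree's rows

    place(root, 0, 0, 0)
    return pos
```

With children {1: [10, 2], 10: [3]}, node 3 ends up at x = 3: one column for the root's label plus two for the longest label on the level above it (10). Passing a negative row_height, like heightDiff=-1 in the example above, makes the rows run downward.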

Effect

Three display forms are currently supported: binary tree, directory tree, and fan chart. The results are as follows:
Binary tree form
Directory tree form
Fan chart form
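
The fan (sector) chart form places level k on a circle of radius k around the origin and gives each node an angular sector that its children split among themselves. A minimal sketch, assuming children share a sector in proportion to their leaf counts; this is a simpler weighting than the radius-scaled one in roundForm(), and `fan_layout`, `field_angle` and `parent_ratio` are hypothetical names:

```python
from math import cos, radians, sin


def fan_layout(root, children, field_angle=180, parent_ratio=0.5):
    """Hypothetical helper sketching a fan-chart layout.

    Level k lies on a circle of radius k around the origin. Each node owns an
    angular sector; its children split that sector in proportion to their
    leaf counts, and the node is drawn parent_ratio of the way across it.
    """
    leaves = {}  # number of leaves in the subtree rooted at each node

    def count_leaves(node):
        kids = children.get(node, [])
        leaves[node] = sum(count_leaves(c) for c in kids) if kids else 1
        return leaves[node]

    count_leaves(root)
    pos = {}

    def place(node, level, start, end):
        angle = start * (1 - parent_ratio) + end * parent_ratio
        pos[node] = (level * cos(angle), level * sin(angle))
        span = end - start
        for c in children.get(node, []):
            stop = start + span * leaves[c] / leaves[node]
            place(c, level + 1, start, stop)
            start = stop  # the next child's sector begins where this one ends

    half = radians(field_angle) / 2
    # Fan opens upward and runs from its left edge to its right edge
    place(root, 0, radians(90) + half, radians(90) - half)
    return pos
```

With field_angle=180 and two leaf children, the root sits at the origin and the children land on the unit circle at 135° and 45°. Weighting by leaf count keeps leaves evenly spaced on each ring, at the cost of ignoring label lengths.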