Module omniplot.networkplot

Expand source code
import networkx as nx
import matplotlib.pyplot as plt
import igraph
from omniplot import igraph_classes
import numpy as np
from natsort import natsorted as nts
from matplotlib.lines import Line2D
import sys
import seaborn as sns
from typing import Union, List, Dict, Optional
from matplotlib.collections import PatchCollection
from matplotlib.patches import Rectangle
import pandas as pd
from omniplot.chipseq_utils import _calc_pearson
from omniplot.utils import _baumkuchen_xy, _save, _separate_data
from scipy.stats import zscore
from joblib import Parallel, delayed
from scipy.spatial.distance import pdist, squareform
import itertools as it
from datashader.bundling import hammer_bundle
import time
# Global plotting style applied as a side effect of importing this module:
# sans-serif/Arial fonts for matplotlib and seaborn, and 'svg.fonttype' set to
# 'none' so SVG output keeps text as text elements.
plt.rcParams['font.family']= 'sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial']
plt.rcParams['svg.fonttype'] = 'none'
sns.set_theme(font="Arial")
# Names of matplotlib colormaps (not referenced by the functions visible in this module).
colormap_list: list=["nipy_spectral", "terrain","tab20b","tab20c","gist_rainbow","hsv","CMRmap","coolwarm","gnuplot","gist_stern","brg","rainbow","jet"]
# Hatch patterns (30 entries). sankey_category assigns them to labels by position.
hatch_list: list = ['//', '\\\\', '||', '--', '++', 'xx', 'oo', 'OO', '..', '**',
                    '/o', '\\|', '|*', '-\\', '+o', 'x*', 'o-', 'O|', 'O.', '*-',
                    'o\\','*\\','+\\','.\\','x\\',
                    '*/','./','x/','-/','+/']
# Marker symbols. NOTE: the variable name is spelled 'maker_list', and '|' and '_'
# each appear twice in the list.
maker_list: list=['.', '_' , '+','|', 'x', 'v', '^', '<', '>', 's', 'p', '*', 'h', 'D', 'd', 'P', 'X','o', '1', '2', '3', '4','|', '_']
# (name, linestyle) pairs; sankey_category reads element [1] of each pair as the
# matplotlib linestyle. NOTE: the names 'dotted' and 'dashed' each occur twice, and
# the second 'dotted' and 'densely dotted' share the same dash pattern (0, (1, 1)).
linestyles = [
     ('solid', 'solid'),      # Same as (0, ()) or '-'
     ('dotted', 'dotted'),    # Same as (0, (1, 1)) or ':'
     ('dashed', 'dashed'),    # Same as '--'
     ('dashdot', 'dashdot'),  # Same as '-.'
    ('loosely dotted',        (0, (1, 10))),
     ('dotted',                (0, (1, 1))),
     ('densely dotted',        (0, (1, 1))),
     ('long dash with offset', (5, (10, 3))),
     ('loosely dashed',        (0, (5, 10))),
     ('dashed',                (0, (5, 5))),
     ('densely dashed',        (0, (5, 1))),

     ('loosely dashdotted',    (0, (3, 10, 1, 10))),
     ('dashdotted',            (0, (3, 5, 1, 5))),
     ('densely dashdotted',    (0, (3, 1, 1, 1))),

     ('dashdotdotted',         (0, (3, 5, 1, 5, 1, 5))),
     ('loosely dashdotdotted', (0, (3, 10, 1, 10, 1, 10))),
     ('densely dashdotdotted', (0, (3, 1, 1, 1, 1, 1)))]


def sankey_category(df, 
                    category: list=[], 
                    palette: str="tab20c", 
                    colormode: str="independent",
                    altcat: str="",
                    show_percentage=False,
                    show_percentage_target=False,
                    fontsize: int=12,
                    hatch: bool=False) -> plt.Axes:
    """
    Drawing a sankey plot to compare multiple categories in a data. The usage example may be 
    to compare between clustering results.
    
    Parameters
    ----------
    df : pandas dataframe
        it has to contain categorical columns specified by the 'category' option.
    category: list
        List of column names to compare. One stacked bar is drawn per column, in this order,
        and links are drawn between consecutive columns.
    palette: str, optional default("tab20c")
        Colormap name. Default: tab20c
    colormode: str ['shared', 'independent', 'alternative', 'trace'], optional default("independent")
        The way to color categories. 'independent' will give each category a distinct (unrelated) colorset. 
        'shared' will give a shared color if categories share the same names of labels. 'alternative' will
        color bars based on additional category specified by altcat. 'trace' will color all bars according
        to the first category.  
    altcat : str, optional (but required when colormode='alternative')
        Column used to color the individual rows inside each block. When given, it is also
        used as the last sort key. It is overridden by category[0] when colormode='trace'.
    show_percentage : bool, optional
        Annotate each block with its share (in percent) of the rows of its column.
    show_percentage_target : bool, optional
        Annotate each link, on its source side, with its share (in percent) of the rows.
    fontsize : int, optional
        Font size of the x tick labels and of the legend.
    hatch : bool, optional
        For 'shared'/'independent', fill blocks with hatch patterns; for
        'alternative'/'trace', draw the per-row lines with distinct line styles.

    Returns
    -------
    axis
    
    Raises
    ------
    ValueError
        If colormode is not one of the supported modes.
    Exception
        If colormode is 'alternative' and altcat is not given.
    """
    valid_colormodes = ("shared", "independent", "alternative", "trace")
    if colormode not in valid_colormodes:
        # Previously an unknown mode surfaced much later as a NameError on 'cmap'.
        raise ValueError("colormode must be one of " + ", ".join(valid_colormodes))
    if colormode == "alternative" and altcat == "":
        raise Exception("If colormode is 'alternative', altcat must be specified")

    category = list(category)
    sort_keys = category + [altcat] if altcat != "" else category
    df = df.sort_values(sort_keys).reset_index(drop=True)

    blockwidth = 0.2
    xinterval = 0.5
    space = 0.02

    # Convert labels to strings column by column. Converting the whole
    # df[category] block at once upcasts mixed int/float columns, so link labels
    # ('1.0') would no longer match the block labels ('1') computed below.
    labels_by_cat = {cat: np.array(df[cat], dtype=str) for cat in category}

    # Count every (source label, target label) pair between consecutive columns.
    link_counts = []
    for scat, tcat in zip(category[:-1], category[1:]):
        links = ["-->>".join(pair) for pair in zip(labels_by_cat[scat], labels_by_cat[tcat])]
        link_counts.append(np.unique(links, return_counts=True))

    # heights[cat] = [unique labels (as str), row counts]
    heights = {}
    for cat in category:
        u, c = np.unique(df[cat], return_counts=True)
        heights[cat] = [np.array(u, dtype=str), c]

    # 'alternative'/'trace' color individual rows; the other modes color whole blocks.
    per_row = colormode in ("alternative", "trace")
    cmap = {}
    hatchmap = {}
    altcat_dict = {}
    if colormode == "shared":
        unique_cat = sorted({lab for u, _ in heights.values() for lab in u})
        _tmp = plt.get_cmap(palette, len(unique_cat))
        _cmap = {lab: _tmp(i) for i, lab in enumerate(unique_cat)}
        # Patterns are cycled so that more labels than patterns no longer raises IndexError.
        _hatchmap = {lab: hatch_list[i % len(hatch_list)] for i, lab in enumerate(unique_cat)}
        cmap = {cat: _cmap for cat in category}
        hatchmap = {cat: _hatchmap for cat in category}
    elif colormode == "independent":
        for cat, (u, _) in heights.items():
            _tmp = plt.get_cmap(palette, u.shape[0])
            cmap[cat] = {lab: _tmp(i) for i, lab in enumerate(u)}
            hatchmap[cat] = {lab: hatch_list[i % len(hatch_list)] for i, lab in enumerate(u)}
    else:
        if colormode == "trace":
            altcat = category[0]
        altcat_unique = list(np.unique(list(df[altcat])))
        _tmp = plt.get_cmap(palette, len(altcat_unique))
        altcat_dict = {a: _tmp(i) for i, a in enumerate(altcat_unique)}
        hatchmap = {a: linestyles[i % len(linestyles)][1] for i, a in enumerate(altcat_unique)}

    fig, ax = plt.subplots(figsize=[2 + len(category), 7])
    blocks = []
    facecolors = []
    facehatches = []
    hs = []
    # xyh[cat][label] = [x of the block, y of block top + space, relative height, row count]
    xyh = {}
    for i, (cat, (u, ac)) in enumerate(heights.items()):
        xyh[cat] = {}
        c = ac / np.sum(ac)
        h = 0
        for _u, _c, _ac in zip(u, c, ac):
            blocks.append(Rectangle([i * xinterval, h], blockwidth, _c))
            if per_row:
                facecolors.append([1, 1, 1, 0])
                # Keep facehatches aligned with blocks; leaving it empty made the
                # hatch=True drawing loop below skip every block outline.
                facehatches.append(None)
            else:
                facecolors.append(cmap[cat][_u])
                facehatches.append(hatchmap[cat][_u])
            h += _c + space
            xyh[cat][str(_u)] = [i * xinterval, h, _c, _ac]
            ax.text(i * xinterval + blockwidth / 2, h - space - _c / 2, _u,
                    bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="b", lw=1, alpha=0.7))
        hs.append(h)

    if per_row:
        # One horizontal line per row inside each block, colored by the row's altcat value.
        for cat, (u, ac) in heights.items():
            dfcat = labels_by_cat[cat]
            for _u in u:
                _df = df.loc[dfcat == _u].sort_values(altcat)
                x, h, _c, _ac = xyh[cat][_u]
                bottom = h - _c - space
                for l, label in enumerate(_df[altcat]):
                    y = bottom + _c * l / _ac
                    if hatch:
                        plt.plot([x, x + blockwidth], [y, y],
                                 color=altcat_dict[label],
                                 linestyle=hatchmap[label])
                    else:
                        plt.plot([x, x + blockwidth], [y, y], color=altcat_dict[label])
        if hatch:
            handles = [Line2D([0], [0], color=altcat_dict[label], linewidth=2,
                              linestyle=hatchmap[label]) for label in altcat_dict]
        else:
            handles = [Line2D([0], [0], color=altcat_dict[label]) for label in altcat_dict]
        plt.legend(handles, [str(label) for label in altcat_dict],
                   loc=[1.01, 0.9], fontsize=fontsize)
        plt.subplots_adjust(right=0.7)

    # draw links between categories
    resolution = 100  # resolution of lines for the links
    half = resolution // 2
    kernel = 0.05 * np.ones(20)
    # Each 'valid' convolution shortens the curve by len(kernel)-1 samples.
    n_points = 2 * half - 2 * (kernel.size - 1)

    def _ribbon_edge(start, end):
        # Step from start to end, smoothed twice with a moving average.
        edge = np.array(half * [start] + half * [end])
        for _ in range(2):
            edge = np.convolve(edge, kernel, mode="valid")
        return edge

    pct_box = dict(boxstyle="round,pad=0.3", fc="white", ec="y", lw=1, alpha=0.8)
    for i, (ul, cl) in enumerate(link_counts):
        scat = category[i]
        tcat = category[i + 1]
        # Running y offset of the next link inside each source/target block.
        sbottom = {}
        tbottom = {}
        for _ul, _cl in zip(ul, cl):
            s, t = _ul.split("-->>")
            sx, sy, sh, sn = xyh[scat][s]
            tx, ty, th, tn = xyh[tcat][t]

            new_source = s not in sbottom
            new_target = t not in tbottom
            if new_source:
                sbottom[s] = sy - space - sh
            if new_target:
                tbottom[t] = ty - space - th
            # Link thickness on each side, proportional to its share of the block.
            _scl = sh * _cl / sn
            _tcl = th * _cl / tn

            xconv = np.linspace(sx + blockwidth, tx, n_points)
            byconv = _ribbon_edge(sbottom[s], tbottom[t])
            tyconv = _ribbon_edge(_scl + sbottom[s], tbottom[t] + _tcl)
            plt.fill_between(
                xconv, byconv, tyconv, alpha=0.65,
                color="b"
            )

            if show_percentage:
                # Annotate each block once rather than once per outgoing link.
                if new_source:
                    plt.text(sx, sh / 2 + sy - space - sh, str(np.round(100 * sh, 1)) + "%", ha="right",
                             va="center",
                             rotation=90,
                             bbox=pct_box)
                if new_target and i == len(link_counts) - 1:
                    plt.text(tx, th / 2 + ty - space - th, str(np.round(100 * th, 1)) + "%", ha="right",
                             va="center",
                             rotation=90,
                             bbox=pct_box)
            if show_percentage_target:
                plt.text(sx + blockwidth, _scl / 2 + sbottom[s],
                         str(np.round(100 * _scl, 1)) + "%",
                         ha="left", va="center",
                         rotation=90,
                         bbox=pct_box)

            sbottom[s] += _scl
            tbottom[t] += _tcl

    if hatch:
        # A PatchCollection takes a single hatch, so hatched blocks are added one by one.
        for b, c, h in zip(blocks, facecolors, facehatches):
            pc = PatchCollection([b], facecolor=[c], edgecolor="black", linewidth=2, hatch=h)
            ax.add_collection(pc)
    else:
        pc = PatchCollection(blocks, facecolor=facecolors, edgecolor="black", linewidth=2)
        ax.add_collection(pc)

    plt.xlim(-blockwidth * 0.5, xinterval * (len(heights) - 1) + blockwidth * 1.5)
    plt.ylim(-0.01, np.amax(hs) * 1.1)
    plt.xticks([i * xinterval + blockwidth / 2 for i in range(len(category))], category, rotation=90,
               fontsize=fontsize)
    plt.yticks([])
    plt.subplots_adjust(bottom=0.2)
    return ax


def pienodes(g: igraph.Graph,  
             node_features: Union[pd.DataFrame, Dict],
             vertex_label: list=[], 
             pie_frac: str="frac",
             pie_label: str="label",
             pie_palette: Union[str , dict]="tab20c", 
             node_label: str="all",
             piesize: Optional[float]=None,
             label_color: str="black",
             figsize: list=[],
             save: str="",
             **kwargs) -> Dict: 
    """
    Drawing a network whose nodes are pie charts.
    
    Parameters
    ----------
    g : igraph object
    node_features: dict or pandas DataFrame
        A dictionary containing fractions and labels of the pie charts.
        e.g.:
        pie_features={"A":{"frac":np.array([50,50]),"label":np.array(["a","b"])},
                  "B":{"frac":np.array([90,5,5]),"label":np.array(["a","b","c"])},
                  "C":{"frac":np.array([100]),"label":np.array(["c"])},
                  "D":{"frac":np.array([100]),"label":np.array(["b"])},
                  "E":{"frac":np.array([100]),"label":np.array(["a"])}}
        Alternatively a long-format DataFrame with a "node" column plus the columns
        named by pie_label and pie_frac (one row per node/label pair).
    vertex_label: list
        The list of node labels.
        e.g.: nodes=["A","B","C","D","E"]
    pie_frac : str
        The key value for the fractions of the pie charts. Default: "frac" (as the example of the above pie_features).
    pie_label : str
        The key value for the labels of the pie charts. Default: "label" (as the example of the above pie_features). 
    pie_palette: str or dict
        If string is provided, it must be one of the matplotlib colormap names for the pie charts. If dict, then
        it is used directly as the mapping from pie label to color.
    node_label: str
        Whether to label all nodes or not ("all", "partial" or "none"). If partial, labels show up only for
        0.05 upper quantile of nodes with a high degree.
    piesize : float
        Scaling pie chart sizes if they are too large/small. If None, the size is derived from the node degree.
    label_color : str
        Color of the node label text.
    figsize : list
        Figure size. If empty, it is derived from the extent of the layout coordinates.
    save : str
        Passed to the module's _save helper together with the suffix "pienodes".
    **kwargs
        Must contain "layout" (an object with a .coords attribute). All keyword arguments
        are forwarded to igraph_classes.MatplotlibGraphDrawer.draw.

    Returns
    -------
    dict
        {"axes": the matplotlib axes}
    
    Raises
    ------
    KeyError
        If no "layout" keyword argument is given.
    Exception
        If pie_palette is neither a str nor a dict.
    """
    if "layout" not in kwargs:
        raise KeyError("pienodes requires a 'layout' keyword argument (an igraph layout)")

    # isinstance (rather than type(...) == dict) also accepts dict subclasses
    # such as OrderedDict/defaultdict.
    features_are_dict = isinstance(node_features, dict)
    if not features_are_dict:
        # Pivot regardless of the palette type: the per-node lookup below always expects
        # the wide table. Restricting 'values' to pie_frac (as a list, so the column
        # MultiIndex is kept) stops unrelated columns from duplicating labels.
        node_features = node_features.pivot_table(index="node", columns=pie_label, values=[pie_frac])

    if isinstance(pie_palette, str):
        if features_are_dict:
            unique_labels = {la for v in node_features.values() for la in v[pie_label]}
        else:
            unique_labels = [col[1] for col in node_features.columns]
        cmap = plt.get_cmap(pie_palette)
        unique_labels = nts(unique_labels)
        labelnum = len(unique_labels)
        colors = {ul: cmap(i / labelnum) for i, ul in enumerate(unique_labels)}
    elif isinstance(pie_palette, dict):
        colors = pie_palette
        unique_labels = nts(colors.keys())
    else:
        raise Exception("Unknown pie_palette type.")

    pos = np.array(kwargs["layout"].coords)
    posmax = np.amax(pos)
    if len(figsize) == 0:
        figsize = [5, 5 * np.abs(np.amax(pos[:, 1])) / np.abs(np.amax(pos[:, 0]))]
    fig, ax = plt.subplots(figsize=figsize)
    plt.subplots_adjust(right=0.8)
    mgd = igraph_classes.MatplotlibGraphDrawer(ax)

    # Pie radius grows with log(degree) and is scaled to 5% of the largest coordinate.
    _deg = np.log(np.array(g.degree()) + 2)
    _deg = 0.05 * posmax * (_deg / _deg.max())
    # Process nodes in order of increasing degree-based size.
    degsort = np.argsort(_deg)
    deg = _deg[degsort]
    nodes = np.array(vertex_label)[degsort]
    pos = pos[degsort]

    # The graph itself is drawn fully transparent (alpha=0); the pies are drawn below.
    if piesize is not None:
        mgd.draw(g, vertex_size=piesize, alpha=0, **kwargs)
    else:
        mgd.draw(g, vertex_size=1.8 * _deg, alpha=0, **kwargs)

    # Hoisted out of the loop: the cutoff does not depend on the node.
    label_cutoff = np.quantile(deg, 0.95) if node_label == "partial" else None

    text_pos = []
    for i, n in enumerate(nodes):
        x, y = pos[i]
        if features_are_dict:
            frac = node_features[n][pie_frac]
            _colors = [colors[f] for f in node_features[n][pie_label]]
        else:
            frac = list(node_features.loc[n].loc[pie_frac, unique_labels])
            _colors = [colors[f] for f in unique_labels]
        # Convert fractions to wedge angles in radians.
        frac = 2 * np.pi * np.array(frac) / np.sum(frac)

        radius = piesize if piesize is not None else deg[i]
        angle = 0
        for fr, co in zip(frac, _colors):
            _baumkuchen_xy(ax, x, y, angle, fr, 0, radius, 20, co)
            angle += fr

        if node_label == "all":
            text_pos.append([x, y, n])
        elif node_label == "partial":
            if deg[i] > label_cutoff:
                text_pos.append([x, y, n])

    for j, (xa, ya, n) in enumerate(text_pos):
        ax.text(xa, ya, n, color=label_color)
        # NOTE(review): kept from the original, which assigned the axes' zorder
        # (last node index + j) here; possibly the text zorder was intended.
        ax.zorder = (len(nodes) - 1) + j
    legend_elements = [Line2D([0], [0], marker='o', color='lavender', label=ul, markerfacecolor=colors[ul], markersize=10)
                      for ul in unique_labels]

    ax.legend(handles=legend_elements, bbox_to_anchor=(0.95, 1))
    _save(save, "pienodes")
    return {"axes": ax}


def sankey_flow(df):
    """Placeholder for a flow-style sankey plot; not implemented, returns None."""
    return None

def _bundlle_edges(G, pos):
    """
    Bundle the edges of a graph with datashader's hammer_bundle.

    Parameters
    ----------
    G : graph object exposing .nodes and .edges (a networkx graph in this module)
    pos : mapping from node to a sequence whose first two items are the x and y coordinates

    Returns
    -------
    The object returned by hammer_bundle for the node and edge tables built here.
    """
    # hammer_bundle is given integer node ids, so number the nodes in iteration order.
    node_ids = {node: number for number, node in enumerate(G.nodes)}

    node_rows = []
    for node in G.nodes:
        xy = pos[node]
        node_rows.append([node_ids[node], xy[0], xy[1]])
    node_table = pd.DataFrame(node_rows, columns=['name', 'x', 'y'])

    edge_rows = []
    for source, target in G.edges:
        edge_rows.append([node_ids[source], node_ids[target]])
    edge_table = pd.DataFrame(edge_rows, columns=['source', 'target'])

    return hammer_bundle(node_table, edge_table)
    

def correlation(df: pd.DataFrame, 
                variables: list=[],
                category: Union[str, list]=[],
                method="pearson",
                layout: str="spring_layout",
                palette: str="tab20c",
                clustering: str="louvain",
                figsize: list=[],
                ztransform: bool=True,
                threshold: Union[float, str]=0.5,
                layout_param: dict={},
                node_edge_color: str="black",
                edge_color: str="weight",
                edge_cmap: str="hot",
                edge_width: Union[str, float]="weight",
                node_size: float=50,
                node_alpha: float=0.85,
                linewidths: float=0.5,
                n_jobs: int=-1,
                edges_alpha: float=0.7,
                edge_width_scaling: float=4,
                rows_cols: list=[],
                node_color="b",
                bundle: bool=True,
                show_edges: bool=True,
                save: str="",
                clustering_param={}) -> Dict:
    """
    Drawing a network based on correlations or distances between observations.

    Parameters
    ----------
    df : pandas DataFrame
    variables: list, optional (default: [])
        the names of values to calculate correlations
    category: str or list, optional (default: [])
        the names of categorical values to display as color labels
    method: str, optional (default: "pearson")
        method for correlation/distance calculation. Availables: "pearson", "euclidean", "cosine",
        if a distance method is chosen, sigmoidal function is used to convert distances into edge weights.
    layout: str, optional
        Networkx layouts include: pydot_layout, spring_layout, random_layout, circular_layout and so on. Please see https://networkx.org/documentation/stable/reference/drawing.html
    palette : str, optional (default: "tab20c")
        A colormap name.
    clustering: str, optional (default: "louvain")
        Networkx clustering methods include:  "louvain", "greedy_modularity", "kernighan_lin_bisection", "asyn_lpa_communities","asyn_fluidc".
        An empty string disables clustering. "asyn_fluidc" needs the number of communities
        as clustering_param={"k": ...}; without it, asyn_lpa_communities is used instead.
    figsize : List[int], optional
        The figure size, e.g., [4, 6].
    ztransform : bool, optional
        Whether to transform values to z-score
    threshold : float, optional (default: 0.5)
        A cutoff value: only pairs whose weight is greater than or equal to it become edges.
    layout_param: dict, optional
        Networkx layout parameters related to the layout option
    node_edge_color: str, optional (default: "black")
        The colors of node edges.
    edge_color: str, optional (default: "weight")
        The color of edges. The default will color edges based on the edge weights calculated based on pearson/distance methods.
    edge_cmap: str, optional (default: "hot")
        Colormap name used when edges are colored by weight.
    edge_width: str or float, optional (default: "weight")
        Edge width. "weight" uses edge_width_scaling * edge weight.
    node_size: float, optional (default: 50)
    node_alpha: float, optional (default: 0.85)
    linewidths: float, optional (default: 0.5)
        Line width of the node borders.
    n_jobs: int, optional (default: -1)
        Currently unused.
    edges_alpha: float, optional (default: 0.7)
    edge_width_scaling: float, optional (default: 4)
    rows_cols: list, optional
        [nrows, ncols] of the panel grid. Default: a single row with one panel per
        category plus one for the clustering result.
    node_color: optional (default: "b")
        Node color used when neither categories nor clustering are shown.
    bundle: bool, optional (default: True)
        Overlay bundled edges (only in the multi-panel layout).
    show_edges: bool, optional (default: True)
        Whether to draw the straight edges.
    save: str, optional
        Passed to the module's _save helper together with the suffix "network_correlation".
    clustering_param: dict, optional
        Keyword arguments forwarded to the clustering function.

    Returns
    -------
    dict
        {"axes": flat array of axes, "networkx": the graph, "distance_mat": the weight matrix}

    Raises
    ------
    Exception
        If clustering is not one of the available options, or python-louvain is missing.
    """
    clustering_options=["","louvain", "greedy_modularity", "kernighan_lin_bisection", "asyn_lpa_communities","asyn_fluidc"]
    if clustering not in clustering_options:
        raise Exception("Available clustering methods are "+", ".join(clustering_options))

    original_index = list(df.index)
    X, category = _separate_data(df, variables=variables, category=category)
    if ztransform:
        X = zscore(X, axis=0)
    if method == "pearson":
        starttime = time.time()
        dmat = np.corrcoef(X)
        print("correlation calc: ", time.time()-starttime)
    else:
        # Standardize the distances and squash them into (0, 1) with a sigmoid.
        # NOTE(review): larger distances map to larger weights here, so distant pairs
        # are the ones kept by the threshold — confirm this is intended.
        dmat = squareform(pdist(X, method))
        dmat = (dmat-np.mean(dmat))/np.std(dmat)
        dmat = (1+np.exp(-dmat))**-1

    G = nx.Graph()
    G.add_nodes_from(original_index)
    # Upper-triangle pairs in the same row-major order as itertools.combinations,
    # filtered without materializing every pair as Python tuples.
    rows, cols = np.triu_indices(dmat.shape[0], k=1)
    keep = dmat[rows, cols] >= threshold
    for i, j in zip(rows[keep], cols[keep]):
        G.add_edge(original_index[i], original_index[j], weight=dmat[i, j])

    if clustering == "louvain":
        try:
            import community
        except ImportError as e:
            raise Exception("can not import community. Try 'pip install python-louvain'") from e
        comm = community.best_partition(G, **clustering_param)
    elif clustering == "":
        comm = []
    else:
        from networkx.algorithms import community as nx_community
        algorithm = clustering
        if clustering == "asyn_fluidc" and "k" not in clustering_param:
            # asyn_fluidc requires the number of communities 'k'. Without it, keep the
            # behavior this option always had (label propagation) instead of failing.
            import warnings
            warnings.warn("asyn_fluidc needs clustering_param={'k': ...}; "
                          "falling back to asyn_lpa_communities")
            algorithm = "asyn_lpa_communities"
        comm = list(getattr(nx_community, algorithm)(G, **clustering_param))

    weights = np.array([w['weight'] for _, _, w in G.edges(data=True)])

    layoutfunction = getattr(nx, layout)
    if layout == "spring_layout" and len(layout_param) == 0:
        layout_param = dict(k=0.75,
                seed=0,
                scale=1, weight='weight')
    pos = layoutfunction(G, **layout_param)

    # Row position of every node label (first occurrence), so category values are read
    # by position rather than via a label lookup with a positional integer.
    first_position = {}
    for position, label in enumerate(original_index):
        first_position.setdefault(label, position)
    node_positions = [first_position[g] for g in G.nodes]

    colors = {}
    colorlut = {}
    for cat in category:
        _cats = df[cat].to_numpy()
        # Order of first appearance keeps the label-to-color mapping reproducible.
        u = list(pd.unique(_cats))
        _cmp = plt.get_cmap(palette, len(u))
        _cmap_dict = {k: _cmp(i) for i, k in enumerate(u)}
        colorlut[cat] = _cmap_dict
        colors[cat] = [_cmap_dict[_cats[p]] for p in node_positions]

    if clustering != "":
        if clustering == "louvain":
            # best_partition returns {node: community id}
            u = sorted(set(comm.values()))
            _cmp = plt.get_cmap(palette, len(u))
            _cmap_dict = {k: _cmp(i) for i, k in enumerate(u)}
            membership = comm
        else:
            # networkx returns a collection of node sets
            _cmp = plt.get_cmap(palette, len(comm))
            _cmap_dict = {i: _cmp(i) for i in range(len(comm))}
            membership = {}
            for i, com in enumerate(comm):
                for node in com:
                    membership.setdefault(node, i)
        colorlut[clustering] = _cmap_dict
        colors[clustering] = [_cmap_dict[membership[g]] for g in G.nodes]

    # isinstance guards: comparing an array-like edge_color/edge_width with a string
    # would be an elementwise comparison.
    if isinstance(edge_color, str) and edge_color == "weight":
        edge_color = weights
    if isinstance(edge_width, str) and edge_width == "weight":
        edge_width = edge_width_scaling*weights

    def _draw_edges(ax):
        nx.draw_networkx_edges(G=G, pos=pos, edge_color=edge_color, edge_cmap=plt.get_cmap(edge_cmap),
                               alpha=edges_alpha,
                               width=edge_width, ax=ax)

    def _draw_panel(ax, label, node_colors, lut, hb):
        # One panel: nodes colored by 'label', optional edges, legend, optional bundles.
        nx.draw_networkx_nodes(G=G, node_color=node_colors,
                               node_size=node_size,
                               pos=pos, linewidths=linewidths,
                               ax=ax, edgecolors=node_edge_color,
                               alpha=node_alpha)
        if show_edges:
            _draw_edges(ax)
        ax.set_title(method+", colored by "+label)
        legendhandles = [Line2D([0], [0], color=color, linewidth=5, label=name)
                         for name, color in lut.items()]
        ax.legend(handles=legendhandles, loc='best', title=label)
        if hb is not None:
            ax.plot(hb.x, hb.y, "y", zorder=1, linewidth=3)

    # Number of panels: one per category plus one for the clustering result.
    # (The original used len(cat), the length of the last category *name*.)
    n_panels = len(category) + int(clustering != "")
    if n_panels == 0:
        fig, ax = plt.subplots(**({"figsize": figsize} if len(figsize) else {}))
        nx.draw_networkx_nodes(G=G, node_color=node_color,
                               node_size=node_size,
                               pos=pos, linewidths=linewidths,
                               ax=ax, edgecolors=node_edge_color,
                               alpha=node_alpha)
        if show_edges:
            _draw_edges(ax)
        # Return the same shape as the multi-panel branch ('axes' was undefined here).
        axes = np.array([ax])
    else:
        if len(figsize) == 0:
            figsize = [4*n_panels, 4]
        if len(rows_cols) == 0:
            rows_cols = [1, n_panels]
        # squeeze=False: always get an array, even for a single panel.
        fig, axes = plt.subplots(figsize=figsize, ncols=rows_cols[1], nrows=rows_cols[0],
                                 squeeze=False)
        axes = axes.flatten()

        hb = _bundlle_edges(G, pos) if bundle else None

        for ax, cat in zip(axes, category):
            _draw_panel(ax, cat, colors[cat], colorlut[cat], hb)
        if clustering != "":
            _draw_panel(axes[-1], clustering, colors[clustering], colorlut[clustering], hb)
    _save(save, "network_correlation")
    return {"axes": axes, "networkx": G, "distance_mat": dmat}


if __name__=="__main__":
    # Manual demo. Set `test` to "sankey_category", "pienode" or "correlation".
    test = "correlation"

    if test == "sankey_category":
        # Compare clustering results stored in a CSV, colored by species.
        df = pd.read_csv("../data/kmeans_result.csv")
        sankey_category(df, ["kmeans2", "kmeans3", "sex"],
                        colormode="alternative",
                        altcat="species",
                        show_percentage=False,
                        show_percentage_target=False)
        plt.show()

    elif test == "pienode":
        edges = [[0, 0], [0, 1], [0, 2], [2, 1], [2, 3], [3, 4], [0, 5]]
        edge_width = [1] * len(edges)
        nodes = ["A", "B", "C", "D", "E", "F"]
        # Dict-style input accepted by pienodes; unused here, the DataFrame form
        # built below replaces it.
        pie_features = {"A": {"frac": np.array([50, 50]), "label": np.array(["a", "b"])},
                        "B": {"frac": np.array([90, 5, 5]), "label": np.array(["a", "b", "c"])},
                        "C": {"frac": np.array([100]), "label": np.array(["c"])},
                        "D": {"frac": np.array([100]), "label": np.array(["b"])},
                        "E": {"frac": np.array([100]), "label": np.array(["a"])},
                        "F": {"frac": np.array([10, 20, 30]), "label": np.array(["a", "b", "c"])}}

        g = igraph.Graph(edges=edges)
        layout = g.layout("fr")

        # Long-format table: one random fraction per (node, label) pair.
        fraclist = [100, 50, 20, 0]
        labels = ["a", "b", "c", "d"]
        records = []
        for n in nodes:
            tmp = np.random.choice(fraclist, 4)
            records.extend({"node": n, "label": l, "frac": f} for l, f in zip(labels, tmp))
        pie_features = pd.DataFrame(records, columns=["node", "label", "frac"])

        pienodes(g,
                 vertex_label=nodes,
                 node_features=pie_features,
                 piesize=0.1,
                 layout=layout,
                 vertex_color="lightblue",
                 edge_color="gray",
                 edge_arrow_size=0.03,
                 edge_width=edge_width,
                 keep_aspect_ratio=True)
        plt.show()

    elif test == "correlation":
        # Penguin measurements: one panel per category plus the clustering result.
        df = sns.load_dataset("penguins")
        df = df.dropna(axis=0)
        df = df.reset_index()
        correlation(df, category=["species", "island", "sex"],
                    method="pearson",
                    ztransform=True,
                    clustering="asyn_fluidc", show_edges=True, bundle=False)
        plt.show()
        

Functions

def correlation(df: pandas.core.frame.DataFrame, variables: list = [], category: Union[str, list] = [], method='pearson', layout: str = 'spring_layout', palette: str = 'tab20c', clustering: str = 'louvain', figsize: list = [], ztransform: bool = True, threshold: Union[float, str] = 0.5, layout_param: dict = {}, node_edge_color: str = 'black', edge_color: str = 'weight', edge_cmap: str = 'hot', edge_width: Union[float, str] = 'weight', node_size: float = 50, node_alpha: float = 0.85, linewidths: float = 0.5, n_jobs: int = -1, edges_alpha: float = 0.7, edge_width_scaling: float = 4, rows_cols: list = [], node_color='b', bundle: bool = True, show_edges: bool = True, save: str = '', clustering_param={}) ‑> Dict[~KT, ~VT]

Draw a network based on correlations or distances between observations.

Parameters


df : pandas DataFrame variables: list, optional (default: []) the names of values to calculate correlations

category: str or list, optional (default: []) the names of categorical values to display as color labels

method: str, optional (default: "pearson") method for correlation/distance calculation. Available options: "pearson", "euclidean", "cosine". If a distance method is chosen, a sigmoidal function is used to convert distances into edge weights.

layout: str, optional Networkx layouts include: pydot_layout, spring_layout, random_layout, circular_layout and so on. Please see https://networkx.org/documentation/stable/reference/drawing.html

palette : str, optional (default: "tab20c") A colormap name.

clustering: str, optional (default: "louvain") Networkx clustering methods include: "louvain", "greedy_modularity", "kernighan_lin_bisection", "asyn_lpa_communities", "asyn_fluidc"

figsize : List[int], optional The figure size, e.g., [4, 6].

ztransform : bool, optional Whether to transform values to z-scores.

threshold : float, optional (default: 0.5) A cutoff value; edges whose correlations/distances are less than this value are removed.

layout_param: dict, optional Networkx layout parameters related to the layout option.

node_edge_color: str, optional (default: "black") The color of node edges.

edge_color: str, optional (default: "weight") The color of edges. The default will color edges based on the edge weights calculated based on pearson/distance methods. edge_cmap: str, optional (default: "hot")

edge_width: Union[str, float]="weight", node_size: float=50, node_alpha: float=0.85, linewidths: float=0.5, n_jobs: int=-1, edges_alpha: float=0.7, edge_width_scaling: float=4, rows_cols: list=[], node_color="b", bundle: bool=True, show_edges: bool=True

Returns

dict
 

Raises

Notes

References

See Also

Examples

Expand source code
def correlation(df: pd.DataFrame, 
                variables: list=[],
                category: Union[str, list]=[],
                method="pearson",
                layout: str="spring_layout",
                palette: str="tab20c",
                clustering: str="louvain",
                figsize: list=[],
                ztransform: bool=True,
                threshold: Union[float, str]=0.5,
                layout_param: dict={},
                node_edge_color: str="black",
                edge_color: str="weight",
                edge_cmap: str="hot",
                edge_width: Union[str, float]="weight",
                node_size: float=50,
                node_alpha: float=0.85,
                linewidths: float=0.5,
                n_jobs: int=-1,
                edges_alpha: float=0.7,
                edge_width_scaling: float=4,
                rows_cols: list=[],
                node_color="b",
                bundle: bool=True,
                show_edges: bool=True,
                save: str="",
                clustering_param={}) -> Dict:
    """
    Draw a network based on correlations or distances between observations.

    Each row of ``df`` becomes a node. An edge is added between two rows when
    their pairwise weight is greater than or equal to ``threshold``. One panel
    is drawn per categorical variable, plus one panel for the clustering
    result when ``clustering`` is not "".

    Parameters
    ----------
    df : pandas DataFrame
    variables: list, optional (default: [])
        the names of values to calculate correlations
    category: str or list, optional (default: [])
        the names of categorical values to display as color labels
    method: str, optional (default: "pearson")
        method for correlation/distance calculation. "pearson" uses
        numpy.corrcoef; any other value is passed to scipy's pdist as the
        metric (e.g. "euclidean", "cosine"). For distance methods the distance
        matrix is standardized and passed through a sigmoid to obtain the
        edge weights.
    layout: str, optional
        Name of a networkx layout function (looked up with getattr on the
        networkx module), e.g. pydot_layout, spring_layout, random_layout,
        circular_layout. Please see
        https://networkx.org/documentation/stable/reference/drawing.html
    palette : str, optional (default: "tab20c")
        A colormap name used for the node colors.
    clustering: str, optional (default: "louvain")
        One of "", "louvain", "greedy_modularity", "kernighan_lin_bisection",
        "asyn_lpa_communities", "asyn_fluidc". "" disables clustering.
        "asyn_fluidc" needs the number of communities: pass it as
        clustering_param={"k": ...}. When "k" is not given, label propagation
        (asyn_lpa_communities) is used instead, which is what this option
        historically ran.
    figsize : List[int], optional
        The figure size, e.g., [4, 6]. Defaults to 4 inches per panel.
    ztransform : bool, optional
        Whether to transform values to z-score (column-wise).
    threshold : float, optional (default: 0.5)
        Pairs whose weight is less than this value get no edge.
    layout_param: dict, optional
        Keyword arguments for the layout function. If empty and layout is
        "spring_layout", k=0.75, seed=0, scale=1, weight="weight" are used.
    node_edge_color: str, optional (default: "black")
        The color of node borders.
    edge_color: str, optional (default: "weight")
        The color of edges. "weight" colors edges by their weights through
        edge_cmap.
    edge_cmap: str, optional (default: "hot")
        Colormap name used when edges are colored by weight.
    edge_width: str or float, optional (default: "weight")
        Edge line width. "weight" uses edge_width_scaling * edge weight.
    node_size: float, optional (default: 50)
    node_alpha: float, optional (default: 0.85)
    linewidths: float, optional (default: 0.5)
        Width of node borders.
    n_jobs: int, optional (default: -1)
        Currently unused; kept for backward compatibility.
    edges_alpha: float, optional (default: 0.7)
    edge_width_scaling: float, optional (default: 4)
    rows_cols: list, optional
        [nrows, ncols] of the panel grid. Defaults to a single row.
    node_color: optional (default: "b")
        Node color used when neither category nor clustering is given.
    bundle: bool, optional (default: True)
        Whether to overlay bundled edges on the category/clustering panels.
    show_edges: bool, optional (default: True)
        Whether to draw the straight edges.
    save: str, optional
        Passed to the module's _save helper together with the suffix
        "network_correlation".
    clustering_param: dict, optional
        Keyword arguments forwarded to the clustering function.

    Returns
    -------
    dict
        {"axes": array of matplotlib Axes, "networkx": the graph,
        "distance_mat": the weight matrix used to build the edges}

    Raises
    ------
    Exception
        If clustering is not one of the available options, or if "louvain"
        is requested and the python-louvain package cannot be imported.
    """
    clustering_options=["","louvain", "greedy_modularity", "kernighan_lin_bisection", "asyn_lpa_communities","asyn_fluidc"]
    if clustering not in clustering_options:
        raise Exception("Available clustering methods are "+", ".join(clustering_options))

    original_index=list(df.index)
    X, category=_separate_data(df, variables=variables, category=category)
    if ztransform:
        X=zscore(X, axis=0)
    if method=="pearson":
        starttime=time.time()
        dmat=np.corrcoef(X)
        print("correlation calc: ", time.time()-starttime)
    else:
        dmat=squareform(pdist(X, method))
        dmat=(dmat-np.mean(dmat))/np.std(dmat)
        # NOTE(review): the sigmoid is increasing, so larger distances yield
        # larger weights -- confirm this is the intended direction.
        dmat=(1+np.exp(-dmat))**-1

    # Build the graph: every observation is a node, edges only above threshold.
    G = nx.Graph()
    G.add_nodes_from(original_index)
    iu, ju=np.triu_indices(dmat.shape[0], k=1)
    keep=dmat[iu, ju]>=threshold
    for i, j in zip(iu[keep], ju[keep]):
        G.add_edge(original_index[i], original_index[j], weight=dmat[i,j])

    # Community detection.
    if clustering=="louvain":
        try:
            import community as community_louvain
        except ImportError as e:
            raise Exception("can not import community. Try 'pip install python-louvain'") from e
        comm=community_louvain.best_partition(G, **clustering_param)
    elif clustering=="":
        comm=[]
    else:
        from networkx.algorithms import community as nx_community
        if clustering=="greedy_modularity":
            comm=nx_community.greedy_modularity_communities(G, **clustering_param)
        elif clustering=="kernighan_lin_bisection":
            comm=nx_community.kernighan_lin_bisection(G, **clustering_param)
        elif clustering=="asyn_fluidc" and "k" in clustering_param:
            # asyn_fluidc requires the number of communities k.
            comm=nx_community.asyn_fluidc(G, **clustering_param)
        else:
            # "asyn_lpa_communities", and "asyn_fluidc" without k (the
            # behavior this option had before it was wired to asyn_fluidc).
            comm=nx_community.asyn_lpa_communities(G, **clustering_param)
        comm=list(comm)

    weights=np.array([w['weight'] for _, _, w in G.edges(data=True)])

    layoutfunction = getattr(nx, layout)
    if layout=="spring_layout" and len(layout_param)==0:
        layout_param=dict(k=0.75,
                seed=0,
                scale=1,weight = 'weight')
    pos = layoutfunction(G, **layout_param)

    # Row position of each node label (first occurrence), so category values
    # can be looked up positionally in O(1) per node.
    first_position={}
    for position, label in enumerate(original_index):
        first_position.setdefault(label, position)

    colors={}
    colorlut={}
    for cat in category:
        _cats=df[cat]
        u=list(pd.unique(_cats))
        _cmp=plt.get_cmap(palette, len(u))
        _cmap_dict={k: _cmp(i) for i, k in enumerate(u)}
        colorlut[cat]=_cmap_dict
        colors[cat]=[_cmap_dict[_cats.iloc[first_position[g]]] for g in G.nodes]

    if clustering!="":
        if clustering=="louvain":
            # best_partition returns {node: community id}
            u=sorted(set(comm.values()))
            _cmp=plt.get_cmap(palette, len(u))
            _cmap_dict={k: _cmp(i) for i, k in enumerate(u)}
            membership=comm
        else:
            # the networkx methods return an iterable of node sets
            _cmp=plt.get_cmap(palette, len(comm))
            _cmap_dict={i: _cmp(i) for i in range(len(comm))}
            membership={node: i for i, com in enumerate(comm) for node in com}
        colorlut[clustering]=_cmap_dict
        colors[clustering]=[_cmap_dict[membership[g]] for g in G.nodes]

    if edge_color=="weight":
        edge_color=weights
    if edge_width=="weight":
        edge_width=edge_width_scaling*weights

    def _draw_edges(ax):
        # Straight edges colored/sized according to edge_color/edge_width.
        nx.draw_networkx_edges(G = G,pos=pos, edge_color=edge_color, edge_cmap=plt.get_cmap(edge_cmap),
                               alpha=edges_alpha,
                               width=edge_width, ax=ax)

    def _draw_nodes(ax, node_colors):
        nx.draw_networkx_nodes(G = G, node_color=node_colors,
                               node_size=node_size,
                               pos=pos,linewidths=linewidths,
                               ax=ax,edgecolors=node_edge_color,
                               alpha=node_alpha)

    def _draw_panel(ax, key, hb):
        # One panel: nodes colored by `key` (a category or the clustering
        # name), optional edges, legend, optional bundled edges.
        _draw_nodes(ax, colors[key])
        if show_edges:
            _draw_edges(ax)
        ax.set_title(method+", colored by "+key)
        legendhandles=[Line2D([0], [0], color=color,linewidth=5, label=label)
                       for label, color in colorlut[key].items()]
        ax.legend(handles=legendhandles, loc='best', title=key)
        if hb is not None:
            ax.plot(hb.x, hb.y, "y", zorder=1, linewidth=3)

    if len(category)==0 and clustering=="":
        if len(figsize)==0:
            fig, ax=plt.subplots()
        else:
            fig, ax=plt.subplots(figsize=figsize)
        _draw_nodes(ax, node_color)
        if show_edges:
            _draw_edges(ax)
        # Keep the return value uniform with the multi-panel case.
        axes=np.array([ax])
    else:
        npanels=len(category)+int(clustering!="")
        if len(figsize)==0:
            figsize=[4*npanels,4]
        if len(rows_cols)==0:
            rows_cols=[1, npanels]
        # squeeze=False guarantees an array even for a single panel.
        fig, axes=plt.subplots(figsize=figsize, ncols=rows_cols[1], nrows=rows_cols[0], squeeze=False)
        axes=axes.flatten()

        hb=_bundlle_edges(G, pos) if bundle else None

        for ax, cat in zip(axes, category):
            _draw_panel(ax, cat, hb)
        if clustering!="":
            _draw_panel(axes[-1], clustering, hb)
    _save(save, "network_correlation")
    return {"axes":axes,"networkx":G, "distance_mat":dmat}
def pienodes(g: igraph.Graph, node_features: Union[pandas.core.frame.DataFrame, Dict[~KT, ~VT]], vertex_label: list = [], pie_frac: str = 'frac', pie_label: str = 'label', pie_palette: Union[str, dict] = 'tab20c', node_label: str = 'all', piesize: Optional[float] = None, label_color: str = 'black', figsize: list = [], save: str = '', **kwargs) ‑> Dict[~KT, ~VT]

Drawing a network whose nodes are pie charts.

Parameters

g : igraph object
 
vertex_label : list
The list of node labels. e.g.: nodes=["A","B","C","D","E"]
node_features : dict
A dictionary containing fractions and labels of the pie charts. e.g.: pie_features={"A":{"frac":np.array([50,50]),"label":np.array(["a","b"])}, "B":{"frac":np.array([90,5,5]),"label":np.array(["a","b","c"])}, "C":{"frac":np.array([100]),"label":np.array(["c"])}, "D":{"frac":np.array([100]),"label":np.array(["b"])}, "E":{"frac":np.array([100]),"label":np.array(["a"])}}
pie_frac : str
The key value for the fractions of the pie charts. Default: "frac" (as the example of the above pie_features).
pie_label : str
The key value for the labels of the pie charts. Default: "label" (as the example of the above pie_features).
pie_palette : str or dict
If a string is provided, it must be one of the matplotlib colormap names for the pie charts. If a dict is provided, it must map each pie label to a color.
node_label : str
Whether to label all nodes or not. If partial, labels show up only for 0.05 upper quantile of nodes with a high degree.
piesize : float
Scaling pie chart sizes if they are too large/small.

Returns

axis
 

Raises

Notes

References

See Also

Examples

Expand source code
def pienodes(g: igraph.Graph,  
             node_features: Union[pd.DataFrame, Dict],
             vertex_label: list=[], 
             pie_frac: str="frac",
             pie_label: str="label",
             pie_palette: Union[str , dict]="tab20c", 
             node_label: str="all",
             piesize: Optional[float]=None,
             label_color: str="black",
             figsize: list=[],
             save: str="",
             **kwargs) -> Dict: 
    """
    Draw a network whose nodes are pie charts.
    
    Parameters
    ----------
    g : igraph object
    node_features: dict or pandas DataFrame
        If a dict, it maps each node label to the fractions and labels of its
        pie chart, e.g.:
        pie_features={"A":{"frac":np.array([50,50]),"label":np.array(["a","b"])},
                  "B":{"frac":np.array([90,5,5]),"label":np.array(["a","b","c"])},
                  "C":{"frac":np.array([100]),"label":np.array(["c"])},
                  "D":{"frac":np.array([100]),"label":np.array(["b"])},
                  "E":{"frac":np.array([100]),"label":np.array(["a"])}}
        If a DataFrame, it must be in long format with a "node" column, a
        column named by pie_label and a column named by pie_frac; it is
        pivoted to one row per node, and node/label combinations that are
        absent count as 0.
    vertex_label: list
        The list of node labels, in the vertex order of g.
        e.g.: nodes=["A","B","C","D","E"]
    pie_frac : str
        The key value for the fractions of the pie charts. Default: "frac" (as the example of the above pie_features).
    pie_label : str
        The key value for the labels of the pie charts. Default: "label" (as the example of the above pie_features). 
    pie_palette: str or dict
        If a string is provided, it must be one of the matplotlib colormap names for the pie charts.
        If a dict is provided, it must map each pie label to a color.
    node_label: str
        "all" labels every node, "partial" labels only nodes above the 0.95 quantile of the
        (log-scaled) degree, "none" draws no labels.
    piesize : float
        Fixed pie radius in layout coordinates. If None, pies are scaled by node degree.
    label_color : str
        Color of the node label text.
    figsize : list
        Figure size. If empty, the width is 5 and the height follows the layout's aspect.
    save : str
        Passed to the module's _save helper together with the suffix "pienodes".
    **kwargs
        Forwarded to the graph drawer. Must contain "layout" (an igraph layout
        providing the node coordinates via its .coords attribute).

    Returns
    -------
    dict
        {"axes": the matplotlib Axes}
    
    Raises
    ------
    KeyError
        If no "layout" keyword argument is given.
    Exception
        If pie_palette is neither a str nor a dict.
    """
    if "layout" not in kwargs:
        raise KeyError("pienodes requires an igraph layout passed as the 'layout' keyword argument")

    features_are_dict=isinstance(node_features, dict)
    if not features_are_dict:
        # Pivot regardless of the palette type: the drawing loop below always
        # expects one row per node for DataFrame input.
        node_features=node_features.pivot_table(index="node",columns=pie_label, fill_value=0)

    if isinstance(pie_palette, str):
        if features_are_dict:
            unique_labels={la for v in node_features.values() for la in v[pie_label]}
        else:
            # Columns are (value column, label) pairs; keep each label once.
            unique_labels=list(dict.fromkeys(nf[1] for nf in node_features.columns))
        cmap=plt.get_cmap(pie_palette)
        unique_labels=nts(unique_labels)
        labelnum=len(unique_labels)
        colors={ul: cmap(i/labelnum) for i, ul in enumerate(unique_labels)}
    elif isinstance(pie_palette, dict):
        colors=pie_palette
        unique_labels=nts(colors.keys())
    else:
        raise Exception("Unknown pie_palette type.")
    
    coords=np.array(kwargs["layout"].coords)
    posmax=np.amax(coords)
    if len(figsize)==0:
        figsize=[5,5*np.abs(np.amax(coords[:,1]))/np.abs(np.amax(coords[:,0]))]
    fig, ax = plt.subplots(figsize=figsize)
    plt.subplots_adjust(right=0.8)
    mgd=igraph_classes.MatplotlibGraphDrawer(ax)
    
    # Pie radius grows with log(degree + 2), scaled to 5% of the largest coordinate.
    _deg=np.log(np.array(g.degree())+2)
    _deg=0.05*posmax*(_deg/_deg.max())
    # Draw low-degree nodes first so high-degree pies end up on top.
    degsort=np.argsort(_deg)
    deg=_deg[degsort]
    nodes=np.array(vertex_label)[degsort]
    pos=coords[degsort]

    # The graph itself is drawn with transparent vertices (alpha=0); the pies
    # are painted on top at the same coordinates.
    vertex_size=piesize if piesize is not None else 1.8*_deg
    mgd.draw(g, vertex_size=vertex_size,alpha=0,**kwargs)

    # Hoisted out of the loop: the cutoff does not depend on the node.
    label_cutoff=np.quantile(deg, 0.95) if (node_label=="partial" and len(deg)>0) else None

    text_pos=[]
    for i, n in enumerate(nodes):
        x, y=pos[i]
        if features_are_dict:
            frac=np.array(node_features[n][pie_frac])
            _colors=[colors[f] for f in node_features[n][pie_label]]
        else:
            _node_features=node_features.loc[n]
            frac=np.array(list(_node_features.loc[pie_frac,unique_labels]))
            _colors=[colors[f] for f in unique_labels]

        total=np.sum(frac)
        # A node whose fractions sum to zero has no pie to draw; skipping it
        # avoids a division by zero that would turn every wedge into NaN.
        if total!=0:
            radius=piesize if piesize is not None else deg[i]
            angle=0
            for fr, co in zip(2*np.pi*frac/total, _colors):
                _baumkuchen_xy(ax, x, y, angle, fr, 0, radius,20, co)
                angle+=fr

        if node_label=="all":
            text_pos.append([x, y,n])
        elif node_label=="partial":
            if deg[i] > label_cutoff:
                text_pos.append([x, y,n])

    last_index=len(nodes)-1
    for j, (xa,ya,n) in enumerate(text_pos):
        ax.text(xa,ya,n, color=label_color)
        # NOTE(review): this sets the zorder attribute of the Axes itself
        # (kept from the original); it does not affect the text artists.
        ax.zorder=last_index+j
    legend_elements = [Line2D([0], [0], marker='o', color='lavender', label=ul,markerfacecolor=colors[ul], markersize=10)
                      for ul in unique_labels]
    
    ax.legend(handles=legend_elements,bbox_to_anchor=(0.95, 1))
    _save(save, "pienodes")
    return {"axes":ax}
def sankey_category(df, category: list = [], palette: str = 'tab20c', colormode: str = 'independent', altcat: str = '', show_percentage=False, show_percentage_target=False, fontsize: int = 12, hatch: bool = False) ‑> matplotlib.axes._axes.Axes

Drawing a sankey plot to compare multiple categories in a data. The usage example may be to compare between clustering results.

Parameters

df : pandas dataframe
it has to contain categorical columns specified by the 'category' option.
category : list
List of column names to compare.
palette : str, optional default("tab20c")
Colormap name. Default: tab20c
colormode : str ['shared', 'independent', 'alternative', 'trace'], optional default("independent")
The way to color categories. 'independent' will give each category a distinct (unrelated) colorset. 'shared' will give a shared color if categories share the same names of labels. 'alternative' will color bars based on additional category specified by altcat. 'trace' will color all bars according to the first category.
altcat : str, optional (but required when colormode='alternative')
 

Returns

axis
 

Raises

Notes

References

See Also

Examples

Expand source code
def sankey_category(df, 
                    category: list=[], 
                    palette: str="tab20c", 
                    colormode: str="independent",
                    altcat: str="",
                    show_percentage=False,
                    show_percentage_target=False,
                    fontsize: int=12,
                    hatch: bool=False) -> plt.Axes:
    """
    Draw a sankey plot to compare multiple categories in a data. The usage example may be 
    to compare between clustering results.
    
    Parameters
    ----------
    df : pandas dataframe
        it has to contain categorical columns specified by the 'category' option.
    category: list
        List of column names to compare. One stacked bar is drawn per column, in this order,
        and links are drawn between consecutive columns.
    palette: str, optional default("tab20c")
        Colormap name. Default: tab20c
    colormode: str ['shared', 'independent', 'alternative', 'trace'], optional default("independent")
        The way to color categories. 'independent' will give each category a distinct (unrelated) colorset. 
        'shared' will give a shared color if categories share the same names of labels. 'alternative' will
        color bars based on additional category specified by altcat. 'trace' will color all bars according
        to the first category.  
    altcat : str, optional (but required when colormode='alternative')
        Column whose values color the bars in 'alternative' mode. One thin line per row is drawn
        inside each block, colored by this column.
    show_percentage : bool, optional
        Label each block with its share (in percent) of all rows.
    show_percentage_target : bool, optional
        Label each link, at its source side, with its share (in percent) of all rows.
    fontsize : int, optional
        Font size of the x tick labels and of the legend.
    hatch : bool, optional
        Additionally distinguish labels by hatch patterns ('shared'/'independent') or by line
        styles ('alternative'/'trace'). Patterns are reused cyclically when there are more
        labels than available patterns.

    Returns
    -------
    axis
    
    Raises
    ------
    Exception
        If colormode is not one of the supported values, or if colormode is 'alternative'
        and altcat is not specified.
    """
    if colormode not in ("shared", "independent", "alternative", "trace"):
        raise Exception("colormode must be one of 'shared', 'independent', 'alternative', 'trace'")
    if colormode=="alternative" and altcat=="":
        raise Exception("If colormode is 'alternative', altcat must be specified")
    category=list(category)

    sort_keys=category+[altcat] if altcat!="" else category
    df=df.sort_values(sort_keys).reset_index(drop=True)
    
    blockwidth=0.2
    xinterval=0.5
    space=0.02
    
    # Count the rows flowing between each pair of consecutive categories.
    # A link is encoded as "source-->>target".
    catval=np.array(df[category].values,dtype=str)
    link_counts=[]
    for i in range(len(category)-1):
        links=["-->>".join([str(_sv),str(_tv)]) for _sv, _tv in zip(catval[:,i], catval[:,i+1])]
        ul, cl=np.unique(links, return_counts=True)
        link_counts.append([ul,cl])        
    
    # heights[cat] = [labels (as str), row counts]
    heights={}
    for cat in category:
        u, c=np.unique(df[cat], return_counts=True)
        heights[cat]=[np.array(u,dtype=str),c]

    uses_altcat=colormode in ("alternative", "trace")
    if colormode=="shared":
        unique_cat=sorted({c for v in heights.values() for c in v[0]})
        _tmp=plt.get_cmap(palette, len(unique_cat))
        _cmap={uc: _tmp(i) for i, uc in enumerate(unique_cat)}
        # modulo: reuse patterns instead of failing when labels outnumber them
        _hatchmap={uc: hatch_list[i % len(hatch_list)] for i, uc in enumerate(unique_cat)}
        cmap={cat: _cmap for cat in category}
        hatchmap={cat: _hatchmap for cat in category}
    elif colormode=="independent":
        cmap={}
        hatchmap={}
        for cat, v in heights.items():
            _tmp=plt.get_cmap(palette, v[0].shape[0])
            cmap[cat]={lab: _tmp(i) for i, lab in enumerate(v[0])}
            hatchmap[cat]={lab: hatch_list[i % len(hatch_list)] for i, lab in enumerate(v[0])}
    else:
        # 'trace' is 'alternative' with the first category as the coloring column.
        if colormode=="trace":
            altcat=category[0]
        altcat_unique=list(np.unique(list(df[altcat])))
        _tmp=plt.get_cmap(palette, len(altcat_unique))
        altcat_dict={a: _tmp(i) for i, a in enumerate(altcat_unique)}
        hatchmap={a: linestyles[i % len(linestyles)][1] for i, a in enumerate(altcat_unique)}
    
    fig, ax=plt.subplots(figsize=[2+len(category),7])
    blocks=[]
    facecolors=[]
    facehatches=[]
    hs=[]
    # xyh[cat][label] = [x of the bar, top of the block plus spacing, block height, row count]
    xyh={}
    for i,(cat, (u, ac)) in enumerate(heights.items()):
        xyh[cat]={}
        c=ac/np.sum(ac)
        h=0
        for _u, _c, _ac in zip(u,c,ac):
            blocks.append(Rectangle([i*xinterval, h],blockwidth,_c))
            if uses_altcat:
                # Transparent blocks: the per-row lines drawn below carry the color.
                facecolors.append([1,1,1,0])
                # Keep facehatches aligned with blocks so the outlines are
                # still drawn when hatch=True.
                facehatches.append(None)
            else:
                facecolors.append(cmap[cat][_u])
                facehatches.append(hatchmap[cat][_u])
            h+=_c+space
            xyh[cat][str(_u)]=[i*xinterval, h, _c,_ac]
            ax.text(i*xinterval+blockwidth/2,h-space-_c/2, _u,
                    bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="b", lw=1, alpha=0.7))
        hs.append(h)

    if uses_altcat:
        for cat, (u, ac) in heights.items():
            dfcat=np.array(df[cat], dtype=str)
            for _u in u:
                _df=df.loc[dfcat==_u].sort_values(altcat)
                x, h, _c, _ac=xyh[cat][_u]
                # One horizontal line per row, stacked evenly inside the block.
                for l, label in enumerate(_df[altcat]):
                    ypos=h-_c-space+_c*l/_ac
                    if hatch:
                        ax.plot([x,x+blockwidth],[ypos,ypos],
                                color=altcat_dict[label],
                                linestyle=hatchmap[label])
                    else:
                        ax.plot([x,x+blockwidth],[ypos,ypos], color=altcat_dict[label])
        if hatch:
            handles=[Line2D([0], [0], color=altcat_dict[label],linewidth=2,
                            linestyle=hatchmap[label]) for label in altcat_dict]
        else:
            handles=[Line2D([0], [0], color=altcat_dict[label]) for label in altcat_dict]
        ax.legend(handles, list(altcat_dict.keys()),loc=[1.01,0.9], fontsize=fontsize)
        fig.subplots_adjust(right=0.7)
        
    # draw links between categories
    resolution=100 # resolution of lines for the links
    kernel=0.05 * np.ones(20)
    percent_box=dict(boxstyle="round,pad=0.3", fc="white", ec="y", lw=1, alpha=0.8)
    for i, (ul, cl) in enumerate(link_counts):
        scat=category[i]
        tcat=category[i+1]
        sbottom={}
        tbottom={}
        # blocks that already carry a percentage label
        scats=set()
        tcats=set()
        is_last=i==len(link_counts)-1
        for _ul, _cl in zip(ul, cl):
            s, t=_ul.split("-->>")
            sx, sy, sh, sac=xyh[scat][s]
            tx, ty, th, tac=xyh[tcat][t]

            # Links leaving/entering a block are stacked from its bottom.
            if not t in tbottom:
                tbottom[t]=ty-space-th
            if not s in sbottom:
                sbottom[s]=sy-space-sh
            # Link thickness: the block height scaled by the share of its rows.
            _scl=sh*_cl/sac
            _tcl=th*_cl/tac
            
            # morphing links by convolution: a step from the source level to
            # the target level is smoothed twice with a box kernel.
            xconv=np.linspace(sx+blockwidth, tx,resolution-20*2+2)
            byconv=np.array((resolution//2) * [sbottom[s]] + (resolution//2) * [tbottom[t]])
            byconv = np.convolve(byconv, kernel, mode='valid')
            byconv = np.convolve(byconv, kernel, mode='valid')
            tyconv=np.array((resolution//2) * [_scl+sbottom[s]] + (resolution//2) * [tbottom[t]+_tcl])
            tyconv = np.convolve(tyconv, kernel, mode='valid')
            tyconv = np.convolve(tyconv, kernel, mode='valid')
            
            ax.fill_between(
                xconv, byconv, tyconv, alpha=0.65,
                color="b"
            )
            
            if show_percentage==True:
                # Label each block once rather than once per link.
                if s not in scats:
                    scats.add(s)
                    ax.text(sx, sh/2+sy-space-sh,str(np.round(100*sh,1))+"%",ha="right",
                            va="center",
                            rotation=90,
                            bbox=percent_box)
                if is_last and t not in tcats:
                    tcats.add(t)
                    ax.text(tx, th/2+ty-space-th,str(np.round(100*th,1))+"%",ha="right",
                            va="center",
                            rotation=90,
                            bbox=percent_box)
            if show_percentage_target==True:
                ax.text(sx+blockwidth, _scl/2+sbottom[s],
                        str(np.round(100*_scl,1))+"%",
                        ha="left",va="center",
                        rotation=90,
                        bbox=percent_box)
                
            sbottom[s]+=_scl
            tbottom[t]+=_tcl

    if hatch==True:
        # A PatchCollection takes a single hatch, so hatched blocks are added one by one.
        for b, c, h in zip(blocks, facecolors, facehatches):
            pc = PatchCollection([b],facecolor=[c],edgecolor="black",linewidth=2, hatch=h)
            ax.add_collection(pc)
    else:
        pc = PatchCollection(blocks,facecolor=facecolors,edgecolor="black",linewidth=2)
        ax.add_collection(pc)

    ax.set_xlim(-blockwidth*0.5, xinterval*(len(heights)-1)+blockwidth*1.5)
    ax.set_ylim(-0.01, np.amax(hs)*1.1)
    ax.set_xticks([i*xinterval+blockwidth/2 for i in range(len(category))])
    ax.set_xticklabels(category, rotation=90, fontsize=fontsize)
    ax.set_yticks([])
    fig.subplots_adjust(bottom=0.2)
    return ax
def sankey_flow(df)
Expand source code
def sankey_flow(df):
    """
    Placeholder for a flow-style sankey plot; not implemented yet.

    Parameters
    ----------
    df :
        Accepted for interface compatibility; currently ignored.

    Returns
    -------
    None
    """
    return None