How to add labels to the side color bar in clustermap in Seaborn/Python

Question:

I have written a Python script as follows to plot a clustermap.

import sys
import importlib
import matplotlib.pyplot as plt

# import PRCC function
import PRCC as prcc
import QSP_analysis as qa

#%%
import numpy as np
from pyDOE2 import lhs
# Read the sample table and split it into parameter and readout columns.
# num_samples is not referenced again in this script.
num_samples = 20
num_param = 15
num_readout = 9

header, data = qa.read_csv('3.csv')
# Column 0 is skipped; columns 1..num_param are treated as parameters and
# every remaining column as a readout.
param_names = header[1:num_param+1]
read_names = header[num_param+1:]
lhd = data[:,1:num_param+1].astype(float)
readout = data[:,num_param+1:].astype(float)


# partial_corr lives in the project-local PRCC module; its four return values
# are only used below as arrays indexed like sig_txt.
# presumably each is shaped (num_param, num_readout) -- TODO confirm
Rho, Pval, Sig, Pval_correct = prcc.partial_corr(lhd, readout, 1e-14,Type = 'Spearman', MTC='Bonferroni')

# Significance markers: cells start as empty strings and the thresholds are
# applied from loosest to strictest, so a stricter marker overwrites a looser one.
sig_txt = np.zeros((num_param, num_readout), dtype='U8')
sig_txt[Pval_correct<5e-2] = '*'
sig_txt[Pval_correct<1e-6] = '**'
sig_txt[Pval_correct<1e-9] = '***'
# Side color bars: the first 10 parameters / first 2 readouts get one color,
# the remaining ones a second color.
param_group = ["beige"]*10 + ["khaki"]*(num_param-10)
readout_group = ["#B4B4FF"]*2+ ["mediumslateblue"]*(num_readout-2)
# Reload QSP_analysis so edits made to the module since it was first imported
# are picked up before cluster_map is called.
importlib.reload(qa)
# Both arrays are transposed so that readouts label the rows and parameters
# label the columns. Clustering is disabled, which keeps the rows/columns in
# the same order as the color lists above.
# NOTE(review): the heatmap is drawn from Pval, while cluster_map pins the
# color scale to [-1, 1] -- confirm that Rho was not the intended input.
cm = qa.cluster_map(np.transpose(Pval), read_names,param_names, 
                  (10,6), cmap="bwr",
                  annot=np.transpose(sig_txt),
                  row_colors = readout_group,
                  col_colors = param_group,
                  col_cluster=False, row_cluster=False,
                  show_dendrogram = [False, False])
cm.savefig('heat.png',format='png', dpi=600,bbox_inches='tight')

The functions in the code are as below:

def read_csv(filename, header_line = 1, dtype = str):
    """Read a comma-separated file into a header row and a NumPy data array.

    Fields are separated by ',' and may be quoted with '|'.

    Parameters
    ----------
    filename : path of the file to read.
    header_line : number of leading rows to consume as header. Only the last
        of them is returned; with 0 the returned header is the empty string.
    dtype : dtype handed to ``np.asarray`` for the data rows.

    Returns
    -------
    (header, data) : ``header`` is the list of fields of the last header row,
        ``data`` is an array built from all remaining rows.

    Raises
    ------
    StopIteration
        If the file has fewer than ``header_line`` rows.
    """
    # newline='' is what the csv module requires: without it, universal-newline
    # translation rewrites line endings embedded in quoted fields before the
    # reader sees them.
    with open(filename, newline='') as csvfile:
        reader = csv.reader(csvfile, delimiter=',', quotechar='|')
        header = ''
        for _ in range(header_line):
            header = next(reader)
        # Everything after the header rows is data.
        data = np.asarray(list(reader), dtype=dtype)
    return header, data

def cluster_map(data, row_label, col_label, fig_size, annot = None,
                show_dendrogram = (True, True), **kwarg):
    """Draw a seaborn clustermap of ``data`` on a fixed [-1, 1] color scale.

    Parameters
    ----------
    data : 2-D values; wrapped in a DataFrame indexed by ``row_label`` with
        columns ``col_label``.
    row_label, col_label : labels for the rows and columns of ``data``.
    fig_size : (width, height) in inches, applied to the finished figure.
    annot : forwarded to ``sns.clustermap``; cell texts are drawn verbatim
        because ``fmt`` is set to the empty string.
    show_dendrogram : pair of booleans (row, column) controlling whether each
        dendrogram axis is visible. Any indexable pair (list or tuple) works.
    **kwarg : forwarded unchanged to ``sns.clustermap``.

    Returns
    -------
    The grid object returned by ``sns.clustermap``.
    """
    df = pd.DataFrame(data=data, index=row_label, columns=col_label)
    g = sns.clustermap(df, annot=annot, fmt='',
                       vmin=-1, vmax=1, cbar_kws={"ticks": [-1, -.5, 0, .5, 1]}, **kwarg)
    # Re-apply the existing tick labels just to change their orientation.
    g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), rotation=0)
    g.ax_heatmap.set_xticklabels(g.ax_heatmap.get_xticklabels(), rotation=-55, ha='left')
    # Hiding a dendrogram axis only makes it invisible; the layout is unchanged.
    g.ax_row_dendrogram.set_visible(show_dendrogram[0])
    g.ax_col_dendrogram.set_visible(show_dendrogram[1])
    g.fig.set_size_inches(*fig_size)
    return g

The PRCC function is a script that calls MATLAB to calculate the PRCC values. The main script reads a CSV file with 20 rows and 24 columns with different headers. The output of the code is a clustermap based on some of the columns (vertical: read_names, horizontal: param_names).

I have added color bars to categorize the variables on the horizontal and vertical axes. The output figure is shown below. How can I add labels to these color bars: "ABM" and "QSP" for the horizontal one, and "endpoint" and "pretreatment" for the vertical one?

Cluster-map from the script

Asked By: Nick

||

Answers:

The following example code assumes you are calling sns.clustermap with row_cluster=False and col_cluster=False, so the rows and columns stay in their original order (if they were reordered, the original groups would be split apart).

groupby from itertools can be used to calculate the run lengths in the lists of colors. Their cumulative sums give the borders between colors, and the midpoint between two consecutive borders is a suitable position for each label.

from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
from itertools import groupby

def _color_run_borders(colors):
    """Return cumulative offsets (starting at 0) of each run of equal consecutive colors."""
    run_lengths = [sum(1 for _ in run) for _, run in groupby(colors)]
    return np.cumsum([0] + run_lengths)


def cluster_map(data, row_label, col_label, fig_size, annot=None, row_color_labels=None, col_color_labels=None,
                show_dendrogram=(True, True), **kwarg):
    """Draw a seaborn clustermap and label the groups of its side color bars.

    The labels are positioned from runs of equal consecutive colors in the
    ``row_colors`` / ``col_colors`` sequences, so they only line up when the
    plotted order matches those sequences (``row_cluster=False`` /
    ``col_cluster=False``).

    Parameters
    ----------
    data : 2-D values; wrapped in a DataFrame indexed by ``row_label`` with
        columns ``col_label``.
    row_label, col_label : labels for the rows and columns of ``data``.
    fig_size : (width, height) in inches, applied to the finished figure.
    annot : forwarded to ``sns.clustermap``; ``fmt`` is the empty string.
    row_color_labels, col_color_labels : one label per run of equal
        consecutive colors, in order. ``None`` draws no labels for that side.
        Extra labels or extra runs are ignored (``zip`` stops at the shorter).
    show_dendrogram : pair of booleans (row, column) controlling whether each
        dendrogram axis is visible.
    **kwarg : forwarded unchanged to ``sns.clustermap``. Must contain
        ``row_colors`` / ``col_colors`` as a flat sequence of colors whenever
        the matching ``*_color_labels`` argument is given.

    Returns
    -------
    The grid object returned by ``sns.clustermap`` (e.g. for ``savefig``).

    Raises
    ------
    KeyError
        If labels are requested for a side whose ``*_colors`` keyword is missing.
    """
    df = pd.DataFrame(data=data, index=row_label, columns=col_label)
    g = sns.clustermap(df, annot=annot, fmt='',
                       vmin=-1, vmax=1, cbar_kws={"ticks": [-1, -.5, 0, .5, 1]}, **kwarg)
    g.ax_heatmap.set_yticklabels(g.ax_heatmap.get_yticklabels(), va='center')
    if row_color_labels is not None:
        borders = _color_run_borders(kwarg['row_colors'])
        for b0, b1, label in zip(borders[:-1], borders[1:], row_color_labels):
            # y is the run midpoint in data coordinates; x is in axes
            # coordinates, slightly left of the color bar.
            g.ax_row_colors.text(-0.06, (b0 + b1) / 2, label, color='black', ha='right', va='center', rotation=90,
                                 transform=g.ax_row_colors.get_yaxis_transform())
    if col_color_labels is not None:
        borders = _color_run_borders(kwarg['col_colors'])
        for b0, b1, label in zip(borders[:-1], borders[1:], col_color_labels):
            # x is the run midpoint in data coordinates; y is in axes
            # coordinates, slightly above the color bar.
            g.ax_col_colors.text((b0 + b1) / 2, 1.06, label, color='black', ha='center', va='bottom',
                                 transform=g.ax_col_colors.get_xaxis_transform())
    # Honor the two arguments the signature accepts, and hand the grid back
    # so the caller can save or further tweak the figure.
    g.ax_row_dendrogram.set_visible(show_dendrogram[0])
    g.ax_col_dendrogram.set_visible(show_dendrogram[1])
    g.fig.set_size_inches(*fig_size)
    return g

# Demo: 7 rows x 12 columns of random values drawn from [-1, 1). The first
# 10 columns and the first 2 rows form one color group each; the rest form
# a second group. Clustering is off so the groups stay contiguous.
demo_rows = ['Alkaid', 'Mizar', 'Alioth', 'Megrez', 'Phecda', 'Merak', 'Dubhe']
demo_cols = list('ABCDEFGHIJKL')
demo_values = np.random.uniform(-1, 1, size=(7, 12))
cluster_map(
    demo_values,
    fig_size=(12, 12),
    col_label=demo_cols,
    row_label=demo_rows,
    col_colors=["beige"] * 10 + ["khaki"] * 2,
    row_colors=["#B4B4FF"] * 2 + ["mediumslateblue"] * 5,
    col_color_labels=["ABM", "QSP"],
    row_color_labels=["endpoint", "pretreatment"],
    row_cluster=False,
    col_cluster=False,
)

sns.clustermap with extra labels

Answered By: JohanC

You can plot very complex heatmaps from data frame using a python package PyComplexHeatmap: https://github.com/DingWB/PyComplexHeatmap
https://github.com/DingWB/PyComplexHeatmap/blob/main/notebooks/examples.ipynb

PyComplexHeatmap

Answered By: Odin
Categories: questions Tags: ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
at the top-right corner.