Adding row colors to a heatmap

Question:

I’m trying to visualize different groups within my data on the axis of a heatmap (using sns.heatmap). I want the ticks to be categorized according to a specific dictionary that maps group names to colors, and eventually to have those groups shown in a legend.

I know this can be obtained by using sns.clustermap but this function also clusters the values in the heatmap, which I don’t want to happen.

Any idea how I can make such a visualization in a heatmap?

Example

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import random 
import seaborn as sns


data = pd.DataFrame(np.random.randn(10, 3), columns=(list('ABC')))
data['group'] = pd.Series(random.choices(['group_1', 'group_2', 'group_3'], weights=[1,1,1], k=len(data)))

data
          A         B         C    group
0  0.366822  0.583965  1.629740  group_3
1  0.557286  0.450663  0.255852  group_3
2 -0.265515 -0.153028  0.670448  group_3
3  0.132278 -0.226668  1.365583  group_3
4  0.595304 -0.577290  0.395477  group_2
5 -0.805420  0.168376  0.748649  group_1
6  0.105664 -0.568047 -0.281488  group_2
7 -0.046202  0.173409 -0.250321  group_1
8 -0.132696 -0.877354  0.086954  group_3
9 -0.843666  0.655146 -1.629453  group_2

lut = {'group_1': 'red', 'group_2': 'blue', 'group_3': 'green'}

row_colors = data['group'].map(lut)

data.drop(['group'], axis=1, inplace=True)

ax = sns.heatmap(data)  # sns.heatmap returns a matplotlib Axes, not a Figure

I want to use row_colors to mark, in the resulting heatmap, which rows belong to which group.
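If the number of groups grows, typing the lookup table by hand gets tedious. A minimal sketch, assuming any seaborn qualitative palette is acceptable ('Set2' is an arbitrary choice here, not something the question asks for): it builds `lut` from the distinct values of the `group` column, and the resulting `row_colors` holds one RGB tuple per row, aligned with the DataFrame's index.

```python
import numpy as np
import pandas as pd
import seaborn as sns

group_names = ['group_1', 'group_2', 'group_3']
data = pd.DataFrame(np.random.randn(10, 3), columns=list('ABC'))
data['group'] = np.random.choice(group_names, size=len(data))

# One color per distinct group; 'Set2' is an arbitrary qualitative palette
groups = sorted(data['group'].unique())
palette = sns.color_palette('Set2', n_colors=len(groups))
lut = dict(zip(groups, palette))

# Same lookup as in the question: one (r, g, b) tuple per row, same index as data
row_colors = data['group'].map(lut)
print(row_colors.head())
```

Because the keys come from the data itself, a group that is missing from a hand-written dictionary cannot silently turn into a NaN color.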

Any help with this would be highly appreciated, whether directly through seaborn or some other way 🙂

Asked By: Shuman_tov


Answers:

You might try a clustermap without clusters:

g = sns.clustermap(data=data.drop(['group'], axis=1), row_colors=row_colors, row_cluster=False, col_cluster=False, dendrogram_ratio=0.05, cbar_pos=None)
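The clustermap route can also cover the legend the question mentions. A minimal sketch, assuming a reasonably recent seaborn (one that accepts `dendrogram_ratio` and `cbar_pos`) and scipy installed, since clustermap depends on it. The legend is built by hand from `lut` with `matplotlib.patches.Patch` handles; placing it above the heatmap is an arbitrary layout choice.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns

group_names = ['group_1', 'group_2', 'group_3']
data = pd.DataFrame(np.random.randn(10, 3), columns=list('ABC'))
data['group'] = np.random.choice(group_names, size=len(data))

lut = {'group_1': 'red', 'group_2': 'blue', 'group_3': 'green'}
row_colors = data['group'].map(lut)

# Clustering is switched off in both directions, so rows and columns keep their order
g = sns.clustermap(data.drop(columns='group'), row_colors=row_colors,
                   row_cluster=False, col_cluster=False,
                   dendrogram_ratio=0.05, cbar_pos=None)

# One legend entry per group, using the same colors as the row-color strip
handles = [Patch(facecolor=color, label=name) for name, color in lut.items()]
legend = g.ax_heatmap.legend(handles=handles, loc='lower center',
                             bbox_to_anchor=(0.5, 1.02), ncol=len(handles),
                             frameon=False)
# plt.show()  # uncomment to display interactively
```

Since the legend handles and the row colors both come from `lut`, the two stay in sync if a group or color is changed.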

Or you could explicitly draw small rectangles next to the heatmap.
In the code below, the drop isn’t done in place, so the group column is still available afterwards.

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

group_names = ['group_1', 'group_2', 'group_3']
data = pd.DataFrame(np.random.randn(10, 3), columns=(list('ABC')))
data['group'] = pd.Series(np.random.choice(group_names, p=[1/3, 1/3, 1/3], size=len(data)))

lut = {'group_1': 'red', 'group_2': 'blue', 'group_3': 'green'}
row_colors = data['group'].map(lut)

ax = sns.heatmap(data.drop(['group'], axis=1))
ax.tick_params(axis='y', which='major', pad=20, length=0) # extra padding to leave room for the row colors
ax.set_yticklabels(data['group'], rotation=0) # optionally use the groups as the tick labels
for i, color in enumerate(row_colors):
    ax.add_patch(plt.Rectangle(xy=(-0.05, i), width=0.05, height=1, color=color, lw=0,
                               transform=ax.get_yaxis_transform(), clip_on=False))
plt.tight_layout()
plt.show()

[Figure: sns.heatmap with row colors]
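For the plain `sns.heatmap` route, a legend can be added from the same `lut`. The sketch below redraws the rectangle strip with comments on the coordinate system: `ax.get_yaxis_transform()` is a blended transform whose x is in axes fraction (0 is the left edge of the heatmap) and whose y is in data units, where row i of a heatmap spans y = i to i + 1. The `pad=20`, the 0.05 strip width, the figure size and the legend position above the heatmap are arbitrary choices that may need tuning for other figures.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.patches import Patch, Rectangle
import seaborn as sns

group_names = ['group_1', 'group_2', 'group_3']
data = pd.DataFrame(np.random.randn(10, 3), columns=list('ABC'))
data['group'] = np.random.choice(group_names, size=len(data))

lut = {'group_1': 'red', 'group_2': 'blue', 'group_3': 'green'}
row_colors = data['group'].map(lut)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(data.drop(columns='group'), ax=ax)
ax.tick_params(axis='y', which='major', pad=20, length=0)  # room for the color strip

# Blended transform: x in axes fraction, y in data units (row i covers y = i .. i + 1)
trans = ax.get_yaxis_transform()
strip = []
for i, color in enumerate(row_colors):
    rect = Rectangle((-0.05, i), width=0.05, height=1, color=color, lw=0,
                     transform=trans, clip_on=False)  # drawn outside the axes area
    ax.add_patch(rect)
    strip.append(rect)

# Legend entries built from the same lookup table as the strip, placed above the heatmap
handles = [Patch(facecolor=color, label=name) for name, color in lut.items()]
legend = ax.legend(handles=handles, loc='lower center', bbox_to_anchor=(0.5, 1.0),
                   ncol=len(handles), frameon=False)
fig.tight_layout()
# plt.show()  # uncomment to display interactively
```

Placing the legend above the heatmap keeps it clear of the colorbar that `sns.heatmap` adds on the right.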

Answered By: JohanC