How to sharex when using subplot2grid

Question:

I’m a Matlab user who recently switched to Python. I can figure out most of Python on my own, but with plotting I have hit a wall and need some help.

This is what I’m trying to do…

I need to make a figure that consists of 3 subplots with the following properties:

  • subplot layout is 311, 312, 313
  • the height of 312 and 313 is approximately half that of 311
  • all subplots share a common X axis
  • the space between the subplots is 0 (they touch each other along the X axis)

By the way, I know how to do each of these things, just not all in a single figure. That is the problem I’m facing now.

For example, this is my ideal subplot layout:

import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0.0, 2.0, 0.01)

s1 = np.sin(2*np.pi*t)
s2 = np.exp(-t)
s3 = s1*s2

fig = plt.figure()
ax1 = plt.subplot2grid((4,3), (0,0), colspan=3, rowspan=2)
ax2 = plt.subplot2grid((4,3), (2,0), colspan=3)
ax3 = plt.subplot2grid((4,3), (3,0), colspan=3)

ax1.plot(t,s1)
ax2.plot(t[:150],s2[:150])
ax3.plot(t[30:],s3[30:])

plt.tight_layout()

plt.show()

Notice how the x axes of the different subplots are misaligned. I do not know how to align them in this figure, but if I do something like this:

import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0.0, 2.0, 0.01)

s1 = np.sin(2*np.pi*t)
s2 = np.exp(-t)
s3 = s1*s2

fig2, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, sharex=True)

ax1.plot(t,s1)
ax2.plot(t[:150],s2[:150])
ax3.plot(t[30:],s3[30:])

plt.tight_layout()

plt.show()

Now the x axis is aligned between the subplots, but all subplots are the same size, which is not what I want.
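A quick way to see where the misalignment in the first figure comes from: without sharing, each axes autoscales its x-limits to its own data, so the t[:150] and t[30:] slices end up with different ranges. The diagnostic sketch below reuses the same t, s2 and s3 and prints the x-limits with and without sharex=True.

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0.0, 2.0, 0.01)
s2 = np.exp(-t)
s3 = np.sin(2 * np.pi * t) * s2

# Independent axes: each one autoscales to the x-range of its own data.
fig_a, (a1, a2) = plt.subplots(nrows=2, ncols=1)
a1.plot(t[:150], s2[:150])  # x data runs from 0.0 to about 1.49
a2.plot(t[30:], s3[30:])    # x data runs from about 0.3 to 1.99
print("independent:", a1.get_xlim(), a2.get_xlim())

# Shared x axis: both axes report the same limits, covering all the data.
fig_b, (b1, b2) = plt.subplots(nrows=2, ncols=1, sharex=True)
b1.plot(t[:150], s2[:150])
b2.plot(t[30:], s3[30:])
print("shared:     ", b1.get_xlim(), b2.get_xlim())
```

So the layout tool (subplot2grid or subplots) is not what aligns the axes; sharing the x axis is.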

Furthermore, I would like the subplots to touch along the x axis, like this:

import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0.0, 2.0, 0.01)

s1 = np.sin(2*np.pi*t)
s2 = np.exp(-t)
s3 = s1*s2

fig1 = plt.figure()
plt.subplots_adjust(hspace=0)

ax1 = plt.subplot(311)
ax2 = plt.subplot(312, sharex=ax1)
ax3 = plt.subplot(313, sharex=ax1)

ax1.plot(t,s1)
ax2.plot(t[:150],s2[:150])
ax3.plot(t[30:],s3[30:])

xticklabels = ax1.get_xticklabels()+ax2.get_xticklabels()
plt.setp(xticklabels, visible=False)

plt.show()

So to rephrase my question:

I would like to use

plt.subplot2grid(..., colspan=3, rowspan=2)
plt.subplots(..., sharex=True)
plt.subplots_adjust(hspace=0)

and

plt.tight_layout()

together in the same figure. How can I do that?

Asked By: Boris L.


Answers:

Just specify sharex=ax1 when creating your second and third subplots.

import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0.0, 2.0, 0.01)

s1 = np.sin(2*np.pi*t)
s2 = np.exp(-t)
s3 = s1*s2

fig = plt.figure()
ax1 = plt.subplot2grid((4,3), (0,0), colspan=3, rowspan=2)
ax2 = plt.subplot2grid((4,3), (2,0), colspan=3, sharex=ax1)
ax3 = plt.subplot2grid((4,3), (3,0), colspan=3, sharex=ax1)

ax1.plot(t,s1)
ax2.plot(t[:150],s2[:150])
ax3.plot(t[30:],s3[30:])

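fig.subplots_adjust(hspace=0)
for ax in [ax1, ax2]:
    plt.setp(ax.get_xticklabels(), visible=False)
    # The y-ticks will overlap with "hspace=0", so we'll hide the bottom tick.
    # Newer matplotlib versions can also return ticks outside the visible
    # y-range, so only consider the ticks that are actually on screen.
    ymin, ymax = ax.get_ylim()
    visible = [y for y in ax.get_yticks() if ymin <= y <= ymax]
    ax.set_yticks(visible[1:])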

plt.show()
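An alternative sketch that avoids subplot2grid altogether, assuming a matplotlib version whose plt.subplots accepts gridspec_kw: give the underlying GridSpec height ratios of 2:1:1 and hspace=0, and let sharex=True take care of both the common x axis and the hidden x tick labels on the upper rows.

```python
import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0.0, 2.0, 0.01)
s1 = np.sin(2 * np.pi * t)
s2 = np.exp(-t)
s3 = s1 * s2

# Three rows in a 2:1:1 height ratio, no vertical gap, one shared x axis.
fig, (ax1, ax2, ax3) = plt.subplots(
    nrows=3, ncols=1, sharex=True,
    gridspec_kw={"height_ratios": [2, 1, 1], "hspace": 0},
)

ax1.plot(t, s1)
ax2.plot(t[:150], s2[:150])
ax3.plot(t[30:], s3[30:])

# With sharex=True, plt.subplots already hides the x tick labels of the
# upper two rows, so no plt.setp(...) call is needed for them here.
plt.show()
```

Setting the ratio on the GridSpec keeps the heights exactly 2:1:1 for any hspace, whereas a rowspan=2 cell in subplot2grid also contains one inter-row gap and is slightly more than twice as tall unless hspace is 0.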


If you still want to use fig.tight_layout(), you’ll need to call it before fig.subplots_adjust(hspace=0). This is because tight_layout works by automatically calculating parameters for subplots_adjust and then calling it, so if subplots_adjust is called manually first, whatever you set in that first call will be overridden by tight_layout.

E.g.

fig.tight_layout()
fig.subplots_adjust(hspace=0)
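To see the effect of the ordering, here is a small sketch that builds the same subplot2grid layout twice and prints the resulting fig.subplotpars.hspace for each call order. stacked_figure is a hypothetical helper name used only for this example.

```python
import matplotlib.pyplot as plt


def stacked_figure():
    # Same 2:1:1 subplot2grid layout as above, with the x axis shared.
    fig = plt.figure()
    ax1 = plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=2)
    plt.subplot2grid((4, 3), (2, 0), colspan=3, sharex=ax1)
    plt.subplot2grid((4, 3), (3, 0), colspan=3, sharex=ax1)
    return fig


# Order suggested above: tight_layout first, then force hspace back to 0.
good = stacked_figure()
good.tight_layout()
good.subplots_adjust(hspace=0)
print("tight_layout, then subplots_adjust:", good.subplotpars.hspace)

# Reversed order: tight_layout recomputes hspace and replaces the 0.
bad = stacked_figure()
bad.subplots_adjust(hspace=0)
bad.tight_layout()
print("subplots_adjust, then tight_layout:", bad.subplotpars.hspace)
```

The first line should report 0, while the second reports whatever positive spacing tight_layout chose to make room for the tick labels.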
Answered By: Joe Kington

A possible solution is to create the axes manually using the add_axes method, as shown here:

import numpy as np
import matplotlib.pyplot as plt

t = np.arange(0.0, 2.0, 0.01)

s1 = np.sin(2*np.pi*t)
s2 = np.exp(-t)
s3 = s1*s2

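left, width = 0.1, 0.8
# [left, bottom, width, height]; heights 0.4/0.2/0.2 give the 2:1:1 ratio,
# and each bottom edge equals the top edge of the axes below it (no gap)
rect1 = [left, 0.5, width, 0.4]
rect2 = [left, 0.3, width, 0.2]
rect3 = [left, 0.1, width, 0.2]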

fig = plt.figure()
ax1 = fig.add_axes(rect1)  # left, bottom, width, height
ax2 = fig.add_axes(rect2, sharex=ax1)
ax3 = fig.add_axes(rect3, sharex=ax1)

ax1.plot(t,s1)
ax2.plot(t[:150],s2[:150])
ax3.plot(t[30:],s3[30:])

# hide labels
for label1,label2 in zip(ax1.get_xticklabels(),ax2.get_xticklabels()):
    label1.set_visible(False)
    label2.set_visible(False)

plt.show()

But this way you cannot use tight_layout, because you explicitly define the position and size of each axes.
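If you go the add_axes route, the rectangles can be computed from height ratios instead of typed by hand, so that neighboring axes share an edge by construction. stacked_rects below is a hypothetical helper, not a matplotlib function, and its default margins (0.1 to 0.9 of the figure) are arbitrary assumptions.

```python
import matplotlib.pyplot as plt


def stacked_rects(ratios, left=0.1, width=0.8, bottom=0.1, top=0.9):
    """Hypothetical helper: return one [left, bottom, width, height] list per
    axes, ordered top to bottom, with heights proportional to `ratios` and
    each axes' bottom edge equal to the top edge of the one below it."""
    unit = (top - bottom) / sum(ratios)
    rects = []
    y = top  # top edge of the next axes to place
    for r in ratios:
        height = r * unit
        y -= height
        rects.append([left, y, width, height])
    return rects


rects = stacked_rects([2, 1, 1])  # 2:1:1, like 311 over 312 over 313

fig = plt.figure()
ax1 = fig.add_axes(rects[0])
ax2 = fig.add_axes(rects[1], sharex=ax1)
ax3 = fig.add_axes(rects[2], sharex=ax1)

# Only the bottom axes keeps its x tick labels.
for ax in (ax1, ax2):
    ax.tick_params(labelbottom=False)
```

Passing a different list, e.g. stacked_rects([3, 1, 1, 1]), would stack four touching axes without recomputing any numbers by hand; call plt.show() to display the figure.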

Answered By: Jakob