custom legend function for matplotlib chart

Question:

Right now, when I have dual y-axes and want a single legend box for both labels, I do the following:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2)
ax0 = ax[0].twinx()
line1 = ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
line2 = ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
lines = line1 + line2
labels = [l.get_label() for l in lines]
ax[0].legend(lines, labels)

ax1 = ax[1].twinx()
line1 = ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
line2 = ax1.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
lines = line1 + line2
labels = [l.get_label() for l in lines]
ax[1].legend(lines, labels)

I want to write a custom legend function that cuts down on the repetitive work of collecting each line and then getting its label. The final code should look like this:

fig, ax = plt.subplots(1, 2)
ax0 = ax[0].twinx()
ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
legend()

ax1 = ax[1].twinx()
ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
ax1.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
legend()

I couldn't get my legend function to pick up the labels from both lines. Does anyone have advice on how to achieve that? This is what I have so far:

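import matplotlib.pyplot as plt

def legend(ax=None, **kwargs):
    if ax is None:
        # plt.gca() returns a single Axes, so lines on its twin are never collected
        ax = plt.gca()
    lines = ax.get_lines()
    labels = [line.get_label() for line in lines]
    return ax.legend(lines, labels, frameon=False, **kwargs)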
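A quick check of why this attempt only ever shows one label: creating an Axes with twinx() also makes that new Axes the current one, and plotting on ax[0] afterwards does not switch it back. So plt.gca() hands the function only the twin. The Agg backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # optional: non-interactive backend for headless runs
import matplotlib.pyplot as plt

fig, ax = plt.subplots(1, 2)
ax0 = ax[0].twinx()  # creating the twin also makes it the current Axes
ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')

current = plt.gca()
print(current is ax0)  # True: gca() returns the twin, not ax[0]
print([line.get_label() for line in current.get_lines()])  # ['Line 2']
```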

I also don't always use the same variable name for the twinx() axes: sometimes I write ax0 = ax[0].twinx(), other times temp = ax[0].twinx(), and so on. Ideally, the legend function should work regardless of the name I use.
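One way to get the no-argument legend() call without caring about variable names is to find the twins from the figure itself. The sketch below rests on an assumption, not a documented rule: Axes created with twinx() or twiny() on a regular subplot occupy exactly the same figure position as their parent, so grouping fig.axes by position recovers the stack. It would also catch any unrelated Axes placed at an identical position (e.g. a manually overlaid one). It collects entries with the public Axes.get_legend_handles_labels() rather than get_lines().

```python
import matplotlib
matplotlib.use("Agg")  # optional: non-interactive backend for headless runs
import matplotlib.pyplot as plt


def legend(ax=None, **kwargs):
    """Draw one legend for *ax* plus every Axes stacked on it (e.g. twinx()).

    Assumption: twin Axes share the exact figure position of the Axes they
    were created from, so they can be found by comparing positions.
    """
    if ax is None:
        ax = plt.gca()  # after twinx(), this is usually the twin
    bounds = ax.get_position().bounds
    # Every Axes in the figure at the same spot, in creation order.
    group = [a for a in ax.figure.axes if a.get_position().bounds == bounds]
    handles, labels = [], []
    for a in group:
        h, l = a.get_legend_handles_labels()
        handles.extend(h)
        labels.extend(l)
    kwargs.setdefault("frameon", False)  # the asker's preferred default
    # Attach to the last Axes in the group (usually the twin), which is
    # typically drawn last, so its lines do not paint over the legend.
    return group[-1].legend(handles, labels, **kwargs)


fig, ax = plt.subplots(1, 2)
temp = ax[0].twinx()  # any variable name works
ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
temp.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
leg_left = legend()

other = ax[1].twinx()
ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
other.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
leg_right = legend()
```

Passing an Axes explicitly (legend(ax[0])) should give the same result, since the group is found from the position rather than from which member was passed.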

Thank you.

Asked By: user20856754


Answers:

You can write a function that takes a list of Axes, collects the lines from all of them, and draws a single legend on the first Axes in that list.

import matplotlib.pyplot as plt

plt.close("all")

def legend(axes):
    """Draw one legend on axes[0] using the lines from every Axes in *axes*."""
    lines = []
    for ax in axes:
        lines.extend(ax.get_lines())
    labels = [l.get_label() for l in lines]
    axes[0].legend(lines, labels)

fig, ax = plt.subplots(1, 2)
ax0 = ax[0].twinx()
ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
legend([ax[0], ax0])

ax1 = ax[1].twinx()
ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
ax1.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
legend([ax[1], ax1])

fig.tight_layout()
plt.show()  # in a plain script, fig.show() returns immediately and the window may close
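A possible variant of the same idea, sketched under the assumption that you may also mix in bars or scatter plots: get_lines() only returns Line2D objects, whereas Axes.get_legend_handles_labels() also returns containers and collections and leaves out artists that were never given a label. The *axes signature (legend(ax, ax2) instead of legend([ax, ax2])) is just a stylistic choice for this sketch.

```python
import matplotlib
matplotlib.use("Agg")  # optional: non-interactive backend for headless runs
import matplotlib.pyplot as plt


def legend(*axes, **kwargs):
    """Draw a single legend on axes[0] from the labeled artists of all *axes*."""
    handles, labels = [], []
    for ax in axes:
        # Covers lines, bar containers, collections, etc.; unlabeled
        # artists (auto-labels starting with "_") are not returned.
        h, l = ax.get_legend_handles_labels()
        handles.extend(h)
        labels.extend(l)
    return axes[0].legend(handles, labels, **kwargs)


fig, ax = plt.subplots()
ax2 = ax.twinx()
ax.bar([1, 2, 3], [4, 5, 6], color='b', label='Bars')
ax2.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line')
ax2.plot([1, 2, 3], [5, 5, 5], 'k:')  # no label, so it is left out
leg = legend(ax, ax2, loc='upper center')
```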

Answered By: jared
Categories: questions