Custom legend function for a Matplotlib chart
Question:
Right now, when I have a dual y-axis plot and want a single legend box for both of my labels, I do the following:
fig, ax = plt.subplots(1, 2)

# Left panel: solid lines, second series on a twin Axes with its own y-axis.
ax0 = ax[0].twinx()
line1 = ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
line2 = ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
# plot() returns a list of Line2D objects; merge both lists into one.
lines = [*line1, *line2]
labels = [ln.get_label() for ln in lines]
ax[0].legend(lines, labels)

# Right panel: same layout with dashed lines.
ax1 = ax[1].twinx()
line1 = ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
line2 = ax1.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
lines = [*line1, *line2]
labels = [ln.get_label() for ln in lines]
ax[1].legend(lines, labels)
I want to write a custom legend function that cuts down on the repetitive work of collecting each line and then getting its label. The final code should look like this:
# Desired usage: plot on each Axes and on its twin, then call legend() with
# no arguments to get one combined legend per panel.
fig, ax = plt.subplots(1,2)
# Left panel: ax[0] plus a twin created with twinx().
ax0=ax[0].twinx()
ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
# No Axes is passed here, so legend() has to work out by itself which Axes
# (and which twin, whatever variable it was bound to) to collect lines from.
legend()
# Right panel: same pattern with dashed lines.
ax1=ax[1].twinx()
ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
ax1.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
legend()
I couldn’t get my legend function to pick up the labels from both lines. Does anyone have advice on how to achieve that? This is what I have so far:
def legend(ax=None, **kwargs):
    """Draw one frameless legend covering *ax* and every Axes twinned with it.

    Lines are gathered from *ax* and from the Axes overlaid on it with
    ``twinx()``/``twiny()``, so one call covers both y-axes without the
    caller having to know the twin's variable name.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes the legend is attached to; defaults to ``plt.gca()``.
    **kwargs
        Forwarded to ``Axes.legend``. ``frameon`` defaults to ``False`` but
        the caller may override it.

    Returns
    -------
    matplotlib.legend.Legend
        The legend that was created.
    """
    if ax is None:
        ax = plt.gca()
    # A twin's lines live on a separate Axes, so ax.get_lines() alone misses
    # them; walk the whole twin group instead.
    handles = [
        line
        for member in _twinned_group(ax)
        for line in member.get_lines()
        # "_"-prefixed labels mark unlabeled artists that legends skip; drop
        # them here rather than have Axes.legend warn about each one.
        if not str(line.get_label()).startswith("_")
    ]
    labels = [line.get_label() for line in handles]
    # setdefault instead of a literal frameon=False argument: passing both
    # that and a caller-supplied frameon raised TypeError.
    kwargs.setdefault("frameon", False)
    return ax.legend(handles, labels, **kwargs)


def _twinned_group(ax):
    """Return *ax* plus the Axes twinned with it, in ``ax.figure.axes`` order.

    Two Axes are treated as twins when they share the x- or y-axis AND sit
    at the same figure position, which is what ``twinx()``/``twiny()`` set
    up. The position test keeps ordinary ``sharex=True`` neighbours in other
    subplot cells out of the group.
    """
    shared_x = ax.get_shared_x_axes()
    shared_y = ax.get_shared_y_axes()
    ref = ax.get_position(original=True).bounds

    def overlays(other):
        # Tolerant float comparison of (x0, y0, width, height).
        pos = other.get_position(original=True).bounds
        return all(abs(a - b) < 1e-9 for a, b in zip(ref, pos))

    group = [
        other
        for other in ax.figure.axes
        if other is ax
        or ((shared_x.joined(ax, other) or shared_y.joined(ax, other))
            and overlays(other))
    ]
    return group or [ax]
I also don’t always use the same variable name when I call `twinx()`: sometimes I write `ax0=ax[0].twinx()`, other times `temp=ax[0].twinx()`, and so on. So ideally the legend function should work regardless of the name I use.
Thank you.
Answers:
You can write a function that takes a list of Axes, collects the lines from all of them, and draws a single legend on the first Axes in that list.
import matplotlib.pyplot as plt
# Close every open figure so the example starts from a clean state.
plt.close(fig="all")
def legend(axes=None, **kwargs):
    """Draw one legend on the first Axes listing the lines of all given Axes.

    Parameters
    ----------
    axes : iterable of matplotlib.axes.Axes, optional
        Axes whose lines are combined, e.g. ``[ax[0], ax0]``. The legend is
        attached to the first one. When omitted, the current Axes
        (``plt.gca()``) and every Axes twinned with it are used, so a bare
        ``legend()`` works without knowing the twin's variable name.
    **kwargs
        Forwarded to ``Axes.legend`` (``loc``, ``frameon``, ...).

    Returns
    -------
    matplotlib.legend.Legend
        The legend that was created.
    """
    if axes is None:
        import matplotlib.pyplot as plt
        axes = _twinned_group(plt.gca())
    axes = list(axes)  # accept any iterable, not only lists
    handles = [
        line
        for ax in axes
        for line in ax.get_lines()
        # "_"-prefixed labels mark unlabeled artists that legends skip; drop
        # them here rather than have Axes.legend warn about each one.
        if not str(line.get_label()).startswith("_")
    ]
    labels = [line.get_label() for line in handles]
    return axes[0].legend(handles, labels, **kwargs)


def _twinned_group(ax):
    """Return *ax* plus the Axes twinned with it, in ``ax.figure.axes`` order.

    Two Axes are treated as twins when they share the x- or y-axis AND sit
    at the same figure position, which is what ``twinx()``/``twiny()`` set
    up. The position test keeps ordinary ``sharex=True`` neighbours in other
    subplot cells out of the group.
    """
    shared_x = ax.get_shared_x_axes()
    shared_y = ax.get_shared_y_axes()
    ref = ax.get_position(original=True).bounds

    def overlays(other):
        # Tolerant float comparison of (x0, y0, width, height).
        pos = other.get_position(original=True).bounds
        return all(abs(a - b) < 1e-9 for a, b in zip(ref, pos))

    group = [
        other
        for other in ax.figure.axes
        if other is ax
        or ((shared_x.joined(ax, other) or shared_y.joined(ax, other))
            and overlays(other))
    ]
    return group or [ax]
# Two side-by-side panels, each with a secondary y-axis and one combined
# legend built by legend().
fig, ax = plt.subplots(1, 2)

ax0 = ax[0].twinx()
ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
legend([ax[0], ax0])

ax1 = ax[1].twinx()
ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
ax1.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
legend([ax[1], ax1])

fig.tight_layout()
# plt.show() runs the GUI event loop when one is needed; Figure.show() does
# not, so in a plain script the window may flash and close immediately.
plt.show()
Right now, when I have a dual y-axis plot and want a single legend box for both of my labels, I do the following:
fig, ax = plt.subplots(1, 2)

# Left panel: solid lines, second series on a twin Axes with its own y-axis.
ax0 = ax[0].twinx()
line1 = ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
line2 = ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
# plot() returns a list of Line2D objects; merge both lists into one.
lines = [*line1, *line2]
labels = [ln.get_label() for ln in lines]
ax[0].legend(lines, labels)

# Right panel: same layout with dashed lines.
ax1 = ax[1].twinx()
line1 = ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
line2 = ax1.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
lines = [*line1, *line2]
labels = [ln.get_label() for ln in lines]
ax[1].legend(lines, labels)
I want to write a custom legend function that cuts down on the repetitive work of collecting each line and then getting its label. The final code should look like this:
# Desired usage: plot on each Axes and on its twin, then call legend() with
# no arguments to get one combined legend per panel.
fig, ax = plt.subplots(1,2)
# Left panel: ax[0] plus a twin created with twinx().
ax0=ax[0].twinx()
ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
# No Axes is passed here, so legend() has to work out by itself which Axes
# (and which twin, whatever variable it was bound to) to collect lines from.
legend()
# Right panel: same pattern with dashed lines.
ax1=ax[1].twinx()
ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
ax1.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
legend()
I couldn’t get my legend function to pick up the labels from both lines. Does anyone have advice on how to achieve that? This is what I have so far:
def legend(ax=None, **kwargs):
    """Draw one frameless legend covering *ax* and every Axes twinned with it.

    Lines are gathered from *ax* and from the Axes overlaid on it with
    ``twinx()``/``twiny()``, so one call covers both y-axes without the
    caller having to know the twin's variable name.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axes the legend is attached to; defaults to ``plt.gca()``.
    **kwargs
        Forwarded to ``Axes.legend``. ``frameon`` defaults to ``False`` but
        the caller may override it.

    Returns
    -------
    matplotlib.legend.Legend
        The legend that was created.
    """
    if ax is None:
        ax = plt.gca()
    # A twin's lines live on a separate Axes, so ax.get_lines() alone misses
    # them; walk the whole twin group instead.
    handles = [
        line
        for member in _twinned_group(ax)
        for line in member.get_lines()
        # "_"-prefixed labels mark unlabeled artists that legends skip; drop
        # them here rather than have Axes.legend warn about each one.
        if not str(line.get_label()).startswith("_")
    ]
    labels = [line.get_label() for line in handles]
    # setdefault instead of a literal frameon=False argument: passing both
    # that and a caller-supplied frameon raised TypeError.
    kwargs.setdefault("frameon", False)
    return ax.legend(handles, labels, **kwargs)


def _twinned_group(ax):
    """Return *ax* plus the Axes twinned with it, in ``ax.figure.axes`` order.

    Two Axes are treated as twins when they share the x- or y-axis AND sit
    at the same figure position, which is what ``twinx()``/``twiny()`` set
    up. The position test keeps ordinary ``sharex=True`` neighbours in other
    subplot cells out of the group.
    """
    shared_x = ax.get_shared_x_axes()
    shared_y = ax.get_shared_y_axes()
    ref = ax.get_position(original=True).bounds

    def overlays(other):
        # Tolerant float comparison of (x0, y0, width, height).
        pos = other.get_position(original=True).bounds
        return all(abs(a - b) < 1e-9 for a, b in zip(ref, pos))

    group = [
        other
        for other in ax.figure.axes
        if other is ax
        or ((shared_x.joined(ax, other) or shared_y.joined(ax, other))
            and overlays(other))
    ]
    return group or [ax]
I also don’t always use the same variable name when I call `twinx()`: sometimes I write `ax0=ax[0].twinx()`, other times `temp=ax[0].twinx()`, and so on. So ideally the legend function should work regardless of the name I use.
Thank you.
You can write a function that takes a list of Axes, collects the lines from all of them, and draws a single legend on the first Axes in that list.
import matplotlib.pyplot as plt
# Close every open figure so the example starts from a clean state.
plt.close(fig="all")
def legend(axes=None, **kwargs):
    """Draw one legend on the first Axes listing the lines of all given Axes.

    Parameters
    ----------
    axes : iterable of matplotlib.axes.Axes, optional
        Axes whose lines are combined, e.g. ``[ax[0], ax0]``. The legend is
        attached to the first one. When omitted, the current Axes
        (``plt.gca()``) and every Axes twinned with it are used, so a bare
        ``legend()`` works without knowing the twin's variable name.
    **kwargs
        Forwarded to ``Axes.legend`` (``loc``, ``frameon``, ...).

    Returns
    -------
    matplotlib.legend.Legend
        The legend that was created.
    """
    if axes is None:
        import matplotlib.pyplot as plt
        axes = _twinned_group(plt.gca())
    axes = list(axes)  # accept any iterable, not only lists
    handles = [
        line
        for ax in axes
        for line in ax.get_lines()
        # "_"-prefixed labels mark unlabeled artists that legends skip; drop
        # them here rather than have Axes.legend warn about each one.
        if not str(line.get_label()).startswith("_")
    ]
    labels = [line.get_label() for line in handles]
    return axes[0].legend(handles, labels, **kwargs)


def _twinned_group(ax):
    """Return *ax* plus the Axes twinned with it, in ``ax.figure.axes`` order.

    Two Axes are treated as twins when they share the x- or y-axis AND sit
    at the same figure position, which is what ``twinx()``/``twiny()`` set
    up. The position test keeps ordinary ``sharex=True`` neighbours in other
    subplot cells out of the group.
    """
    shared_x = ax.get_shared_x_axes()
    shared_y = ax.get_shared_y_axes()
    ref = ax.get_position(original=True).bounds

    def overlays(other):
        # Tolerant float comparison of (x0, y0, width, height).
        pos = other.get_position(original=True).bounds
        return all(abs(a - b) < 1e-9 for a, b in zip(ref, pos))

    group = [
        other
        for other in ax.figure.axes
        if other is ax
        or ((shared_x.joined(ax, other) or shared_y.joined(ax, other))
            and overlays(other))
    ]
    return group or [ax]
# Two side-by-side panels, each with a secondary y-axis and one combined
# legend built by legend().
fig, ax = plt.subplots(1, 2)

ax0 = ax[0].twinx()
ax[0].plot([1, 2, 3], [4, 5, 6], 'b', lw=1, label='Line 1')
ax0.plot([1, 2, 3], [6, 5, 4], 'r', lw=1, label='Line 2')
legend([ax[0], ax0])

ax1 = ax[1].twinx()
ax[1].plot([1, 2, 3], [4, 5, 6], 'b--', lw=1, label='Line 3')
ax1.plot([1, 2, 3], [6, 5, 4], 'r--', lw=1, label='Line 4')
legend([ax[1], ax1])

fig.tight_layout()
# plt.show() runs the GUI event loop when one is needed; Figure.show() does
# not, so in a plain script the window may flash and close immediately.
plt.show()