pyplot combine multiple line labels in legend
Question:
I have data that results in multiple lines being plotted, and I want to give these lines a single label in my legend. This is easier to demonstrate with the example below:
import numpy as np
import matplotlib.pyplot as plt

a = np.array([[ 3.57, 1.76, 7.42, 6.52],
              [ 1.57, 1.2 , 3.02, 6.88],
              [ 2.23, 4.86, 5.12, 2.81],
              [ 4.48, 1.38, 2.14, 0.86],
              [ 6.68, 1.72, 8.56, 3.23]])
plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')
plt.legend(loc='best')
The call produces 5 distinct lines, and because every one of them carries the label 'data_a', the legend lists 'data_a' five times.
Is there any way to tell the plot method to avoid multiple labels? I would rather not build a custom legend (where you specify the label and the line shape all at once) if I can avoid it.
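A minimal sketch that confirms what is happening, assuming a recent Matplotlib and the non-interactive Agg backend: `plot` returns one `Line2D` per column of the 2×5 input arrays, and `get_legend_handles_labels` reports the same label once per line.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

fig, ax = plt.subplots()
# x and y are both 2x5, so plot draws one 2-point line per column (5 lines),
# and the single string label is applied to every one of them.
lines = ax.plot(a[:, ::2].T, a[:, 1::2].T, 'r', label='data_a')
handles, labels = ax.get_legend_handles_labels()
print(len(lines), labels)
```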
Answers:
Using will’s suggestion and another question here, this is my remedy:
handles, labels = plt.gca().get_legend_handles_labels()
# Walk the lists, deleting any label (and its handle) already seen earlier
i = 1
while i < len(labels):
    if labels[i] in labels[:i]:
        del labels[i]
        del handles[i]
    else:
        i += 1
plt.legend(handles, labels)
Personally, I’d write a small helper function if I planned on doing this often:
from matplotlib import pyplot
import numpy

a = numpy.array([[ 3.57, 1.76, 7.42, 6.52],
                 [ 1.57, 1.2 , 3.02, 6.88],
                 [ 2.23, 4.86, 5.12, 2.81],
                 [ 4.48, 1.38, 2.14, 0.86],
                 [ 6.68, 1.72, 8.56, 3.23]])

def plotCollection(ax, xs, ys, *args, **kwargs):
    ax.plot(xs, ys, *args, **kwargs)
    if "label" in kwargs:
        # Remove duplicate legend entries, keeping the first handle for each label
        handles, labels = ax.get_legend_handles_labels()
        newLabels, newHandles = [], []
        for handle, label in zip(handles, labels):
            if label not in newLabels:
                newLabels.append(label)
                newHandles.append(handle)
        ax.legend(newHandles, newLabels)

ax = pyplot.subplot(1, 1, 1)
plotCollection(ax, a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')
plotCollection(ax, a[:,1::2].T, a[:, ::2].T, 'b', label='data_b')
pyplot.show()
An easier (and IMO clearer) way than what you have to remove duplicate handles and labels from the legend is this:
from matplotlib import pyplot

handles, labels = pyplot.gca().get_legend_handles_labels()
newLabels, newHandles = [], []
for handle, label in zip(handles, labels):
    if label not in newLabels:  # keep only the first handle for each label
        newLabels.append(label)
        newHandles.append(handle)
pyplot.legend(newHandles, newLabels)
Matplotlib gives you a nice interface to collections of lines, LineCollection, which takes a single label for the whole collection. The code is straightforward:
import numpy
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
a = numpy.array([[ 3.57, 1.76, 7.42, 6.52],
[ 1.57, 1.2 , 3.02, 6.88],
[ 2.23, 4.86, 5.12, 2.81],
[ 4.48, 1.38, 2.14, 0.86],
[ 6.68, 1.72, 8.56, 3.23]])
xs = a[:,::2]
ys = a[:, 1::2]
lines = LineCollection([list(zip(x,y)) for x,y in zip(xs, ys)], label='data_a')
f, ax = plt.subplots(1, 1)
ax.add_collection(lines)
ax.legend()
ax.set_xlim([xs.min(), xs.max()]) # have to set manually
ax.set_ylim([ys.min(), ys.max()])
plt.show()
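Instead of setting the limits by hand, one option is to ask the Axes to rescale its view after adding the collection with `ax.autoscale_view()`. This is a sketch assuming a recent Matplotlib and the Agg backend; the red `color` is added only to match the question.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])
xs = a[:, ::2]
ys = a[:, 1::2]

# One segment per row of a: [(x0, y0), (x1, y1)]
segments = [list(zip(x, y)) for x, y in zip(xs, ys)]
lines = LineCollection(segments, color='r', label='data_a')

fig, ax = plt.subplots()
ax.add_collection(lines)
# Recompute the view limits from the data limits instead of hardcoding them
ax.autoscale_view()
leg = ax.legend()
texts = [t.get_text() for t in leg.get_texts()]
xlim, ylim = ax.get_xlim(), ax.get_ylim()
```

The whole collection is a single artist, so the legend gets exactly one entry.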
A NumPy solution based on will’s answer above:
import numpy as np
import matplotlib.pyplot as plt

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])
plt.plot(a[:,::2].T, a[:, 1::2].T, 'r', label='data_a')
handles, labels = plt.gca().get_legend_handles_labels()
# Assuming that equal labels have equal handles, get the unique labels and the
# index of each one's first occurrence, which is also its handle index.
# Note that np.unique returns the labels sorted alphabetically.
labels, ids = np.unique(labels, return_index=True)
handles = [handles[i] for i in ids]
plt.legend(handles, labels, loc='best')
plt.show()
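Because `np.unique` sorts its output, the legend entries come out in alphabetical order rather than plotting order. If the original order matters, one option is to sort the first-occurrence indices before using them. The helper name `unique_in_order` is hypothetical, and plain strings stand in for `Line2D` handles so the sketch runs without a figure.

```python
import numpy as np

def unique_in_order(handles, labels):
    """Hypothetical helper: first handle per label, in order of first appearance."""
    # return_index gives the first occurrence of each unique label, but the
    # results are ordered by label, so re-sort the indices by position.
    _, first_idx = np.unique(labels, return_index=True)
    first_idx = np.sort(first_idx)
    return [handles[i] for i in first_idx], [labels[i] for i in first_idx]

h, l = unique_in_order(['h0', 'h1', 'h2', 'h3'], ['zeta', 'alpha', 'zeta', 'alpha'])
print(h, l)  # ['h0', 'h1'] ['zeta', 'alpha']
```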
A low-tech solution is to make two plot calls: one that plots your data, and a second one that plots nothing but carries the label:
import numpy as np
import matplotlib.pyplot as plt

a = np.array([[ 3.57, 1.76, 7.42, 6.52],
              [ 1.57, 1.2 , 3.02, 6.88],
              [ 2.23, 4.86, 5.12, 2.81],
              [ 4.48, 1.38, 2.14, 0.86],
              [ 6.68, 1.72, 8.56, 3.23]])
plt.plot(a[:,::2].T, a[:, 1::2].T, 'r')    # the data, unlabeled
plt.plot([], [], 'r', label='data_a')      # empty line that only carries the label
plt.legend(loc='best')
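A closely related variant, sketched here under the assumption of a recent Matplotlib and the Agg backend, is a proxy artist: a `Line2D` that is never added to the Axes and is passed to `legend` only as a handle. It avoids putting an empty line on the plot, at the cost of passing the handle explicitly.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

fig, ax = plt.subplots()
ax.plot(a[:, ::2].T, a[:, 1::2].T, 'r')  # no label, so nothing reaches the legend
# Proxy artist: styled like the data but never drawn on the Axes
proxy = Line2D([], [], color='r', label='data_a')
leg = ax.legend(handles=[proxy], loc='best')
texts = [t.get_text() for t in leg.get_texts()]
```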
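I would use this trick and label only the first line:
# Plot row by row (a, np and plt as in the question); only the first line is labeled
for i in range(len(a)):
    plt.plot(a[i, ::2], a[i, 1::2], 'r', label='data_a' if i == 0 else None)
plt.legend(loc='best')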
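The same idea also works with a single `plot` call, sketched here assuming a recent Matplotlib and the Agg backend: `plot` returns the list of `Line2D` objects, so one can set a label on the first line only. The other lines keep their default labels, which start with an underscore, and the legend skips labels that start with an underscore.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

fig, ax = plt.subplots()
lines = ax.plot(a[:, ::2].T, a[:, 1::2].T, 'r')  # 5 unlabeled Line2D objects
lines[0].set_label('data_a')  # only the first one gets a visible label
leg = ax.legend(loc='best')
texts = [t.get_text() for t in leg.get_texts()]
```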
The easiest and most pythonic way to remove duplicates is to use a dict, whose keys are guaranteed to be unique. This also means each (handle, label) pair is visited only once.
import matplotlib.pyplot as plt

handles, labels = plt.gca().get_legend_handles_labels()
# Labels become the dict keys, handles the values
temp = {k: v for k, v in zip(labels, handles)}
plt.legend(list(temp.values()), list(temp.keys()), loc='best')
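One detail worth knowing: dicts keep insertion order (Python 3.7+), so labels stay in order of first appearance, but the comprehension keeps the *last* handle seen for each label. That makes no difference when all duplicates look alike, as in the question; if they differ, `dict.setdefault` keeps the first one instead. The helper name below is hypothetical, and plain strings stand in for handles.

```python
def first_handle_per_label(handles, labels):
    """Hypothetical helper: map each label to the first handle seen with it."""
    unique = {}
    for handle, label in zip(handles, labels):
        unique.setdefault(label, handle)  # later duplicates are ignored
    return unique

handles = ['h0', 'h1', 'h2']
labels = ['data_a', 'data_b', 'data_a']
last_wins = {k: v for k, v in zip(labels, handles)}
first_wins = first_handle_per_label(handles, labels)
print(last_wins)   # {'data_a': 'h2', 'data_b': 'h1'}
print(first_wins)  # {'data_a': 'h0', 'data_b': 'h1'}
```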
I found a short way to solve this:
import numpy as np
import matplotlib.pyplot as plt

a = np.array([[ 3.57, 1.76, 7.42, 6.52],
              [ 1.57, 1.2 , 3.02, 6.88],
              [ 2.23, 4.86, 5.12, 2.81],
              [ 4.48, 1.38, 2.14, 0.86],
              [ 6.68, 1.72, 8.56, 3.23]])
p1 = plt.plot(a[:,::2].T, a[:, 1::2].T, color='r')  # list of 5 Line2D objects
plt.legend([p1[0]], ['data_a'], loc='best')         # one representative handle
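The same approach extends to several groups: keep the list returned by each `plot` call and pass one representative handle per group. This sketch assumes a recent Matplotlib and the Agg backend, and reuses the swapped x/y 'data_b' series from the helper-function answer above purely as a second group.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

a = np.array([[3.57, 1.76, 7.42, 6.52],
              [1.57, 1.20, 3.02, 6.88],
              [2.23, 4.86, 5.12, 2.81],
              [4.48, 1.38, 2.14, 0.86],
              [6.68, 1.72, 8.56, 3.23]])

fig, ax = plt.subplots()
red = ax.plot(a[:, ::2].T, a[:, 1::2].T, color='r')   # 5 Line2D objects
blue = ax.plot(a[:, 1::2].T, a[:, ::2].T, color='b')  # 5 more
# One handle per group, paired with one label per group
leg = ax.legend([red[0], blue[0]], ['data_a', 'data_b'], loc='best')
texts = [t.get_text() for t in leg.get_texts()]
```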