matplotlib legend not showing correctly

Question:

I am trying to plot some data from a CSV file. I load the file with pandas and draw the lines with sns.lineplot(), but one of the legend entries is always wrong: it is shown as a square instead of a line.

plt.figure(dpi=150)
lin1 = sns.lineplot(x = "Training time", y = "Relative L2 error", data=df[df["Activation"]=="tanh"])
lin2 = sns.lineplot(x = "Training time", y = "Relative L2 error", data=df[df["Activation"]=="silu"])
lin3 = sns.lineplot(x = "Training time", y = "Relative L2 error", data=df[df["Activation"]=="swish"])
plt.xlabel("Training time  (sec)")
plt.legend(("tanh", "silu", "swish"))
plt.yscale('log',base=10)

I used three separate sns.lineplot() calls because the data contains more activations than these three. This is the resulting plot.

[Image: the resulting plot]

The plot itself looks correct, but the legend does not. Here are the versions of the plotting tools that I am using.

Python 3.9.12
matplotlib                    3.6.1
matplotlib-inline             0.1.6
seaborn                       0.12.1

I could not find the same issue on the internet, and restarting the kernel does not help. Please let me know if more information is needed.

Asked By: Prakhar Sharma


Answers:

Try adding the label argument to the individual lineplots and then just call legend without any arguments:

plt.figure(dpi=150)
lin1 = sns.lineplot(x = "Training time", y = "Relative L2 error", data=df[df["Activation"]=="tanh"], label = 'tanh')
lin2 = sns.lineplot(x = "Training time", y = "Relative L2 error", data=df[df["Activation"]=="silu"], label = 'silu')
lin3 = sns.lineplot(x = "Training time", y = "Relative L2 error", data=df[df["Activation"]=="swish"], label = 'swish')
plt.xlabel("Training time  (sec)")
plt.legend()
plt.yscale('log',base=10)
Answered By: Michael Cao

You can also plot all your lines with a single command by using hue:

sns.lineplot(x="Training time", y="Relative L2 error", data=df[df["Activation"].isin(["tanh", "silu", "swish"])], hue="Activation")

Edit: as @JohanC cleverly suggested, you could use hue_order to get a slightly more compact expression:

sns.lineplot(x="Training time", y="Relative L2 error", data=df, hue="Activation", hue_order=["tanh", "silu", "swish"])
Answered By: Tranbi