How to graph multiple lines using sns.scatterplot

Question:

I have written a program like so:

# Author: Evan Gertis
# Date  : 11/09
# program: Linear Regression
# Resource: https://seaborn.pydata.org/generated/seaborn.scatterplot.html       
import seaborn as sns
import pandas as pd
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Step 1: load the data
grades = pd.read_csv("grades.csv") 
logging.info(grades.head())

# Step 2: plot the data
plot = sns.scatterplot(data=grades, x="Hours", y="GPA")
fig = plot.get_figure()
fig.savefig("out.png")

Using the data set:

Hours,GPA,Hours,GPA,Hours,GPA
11,2.84,9,2.85,25,1.85
5,3.20,5,3.35,6,3.14
22,2.18,14,2.60,9,2.96
23,2.12,18,2.35,20,2.30
20,2.55,6,3.14,14,2.66
20,2.24,9,3.05,19,2.36
10,2.90,24,2.06,21,2.24
19,2.36,25,2.00,7,3.08
15,2.60,12,2.78,11,2.84
18,2.42,6,2.90,20,2.45

I would like to plot all of the relationships, but at this time I only get one plot:

[image: the single scatter plot currently produced]

Expected:
all relationships plotted

Actual:

[image: actual output]

I wrote a basic program and I was expecting all of the relationships to be plotted.

Asked By: Evan Gertis


Answers:

The origin of the problem is that the column names in your file are duplicated, so when pandas reads the file it appends a numeric suffix (.1, .2) to the repeated names in the loaded DataFrame:

import seaborn as sns
import pandas as pd
import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

grades = pd.read_csv("grades.csv") 
print(grades.columns)
# Index(['Hours', 'GPA', 'Hours.1', 'GPA.1', 'Hours.2', 'GPA.2'], dtype='object')
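To see this renaming without the file, here is a small self-contained sketch that reads the first two data rows of the question from an in-memory string (io.StringIO stands in for grades.csv). pandas keeps the first occurrence of each header and suffixes the repeats with .1 and .2:

```python
import io

import pandas as pd

# First two data rows from the question; io.StringIO stands in for grades.csv
csv_text = """Hours,GPA,Hours,GPA,Hours,GPA
11,2.84,9,2.85,25,1.85
5,3.20,5,3.35,6,3.14
"""

grades = pd.read_csv(io.StringIO(csv_text))

# Duplicate header names are made unique by appending .1, .2, ...
print(list(grades.columns))
# ['Hours', 'GPA', 'Hours.1', 'GPA.1', 'Hours.2', 'GPA.2']
```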

Therefore, when you create the scatter plots, you need to pass the column names that pandas assigned:

# in case you want all scatter plots in the same figure
# plot is the matplotlib Axes returned by the first call; reuse it via ax=
plot = sns.scatterplot(data=grades, x='Hours', y='GPA', label='GPA')
sns.scatterplot(data=grades, x='Hours.1', y='GPA.1', ax=plot, label='GPA.1')
sns.scatterplot(data=grades, x='Hours.2', y='GPA.2', ax=plot, label='GPA.2')
fig = plot.get_figure()
fig.savefig("out.png")

[image: the three column pairs plotted on the same axes]
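If you would rather choose the names yourself than rely on the automatic suffixes, read_csv accepts names= together with header=0, which discards the file's own header row and uses your list instead. A minimal sketch, where the labels such as Hours_1 are hypothetical and the in-memory string stands in for grades.csv:

```python
import io

import pandas as pd

csv_text = """Hours,GPA,Hours,GPA,Hours,GPA
11,2.84,9,2.85,25,1.85
5,3.20,5,3.35,6,3.14
"""

# Hypothetical unique column names; header=0 tells pandas to replace the file's header row
names = ["Hours_1", "GPA_1", "Hours_2", "GPA_2", "Hours_3", "GPA_3"]
grades = pd.read_csv(io.StringIO(csv_text), header=0, names=names)

print(grades)
```

Each pair can then be plotted under its own name, e.g. sns.scatterplot(data=grades, x="Hours_2", y="GPA_2", ax=plot, label="GPA_2").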

Answered By: Lucas M. Uriarte
  • There are better options than manually creating a plot for each group of columns.
  • Because the columns in the file have duplicate names, pandas automatically renames them.

Imports and DataFrame

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

# read the data from the file
df = pd.read_csv('d:/data/gpa.csv')

# display(df)
   Hours   GPA  Hours.1  GPA.1  Hours.2  GPA.2
0     11  2.84        9   2.85       25   1.85
1      5  3.20        5   3.35        6   3.14
2     22  2.18       14   2.60        9   2.96
3     23  2.12       18   2.35       20   2.30
4     20  2.55        6   3.14       14   2.66
5     20  2.24        9   3.05       19   2.36
6     10  2.90       24   2.06       21   2.24
7     19  2.36       25   2.00        7   3.08
8     15  2.60       12   2.78       11   2.84
9     18  2.42        6   2.90       20   2.45

Option 1: Chunk the column names

  • This option plots the data in a loop instead of creating each plot manually.
  • Using the chunker function from an answer to How to iterate over a list in chunks, create a list of column-name groups:
    • [Index(['Hours', 'GPA'], dtype='object'), Index(['Hours.1', 'GPA.1'], dtype='object'), Index(['Hours.2', 'GPA.2'], dtype='object')]
# create groups of column names to be plotted together
def chunker(seq, size):
    return [seq[pos:pos + size] for pos in range(0, len(seq), size)]


# split the column names into [x, y] pairs
col_list = chunker(df.columns, 2)

# iterate through each group of column names to plot
for x, y in col_list:
    sns.scatterplot(data=df, x=x, y=y, label=y)
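A sketch of an alternative to chunker, assuming the columns always alternate Hours, GPA: pair the even-position and odd-position column names with zip. The small DataFrame below is a stand-in for df with the same column layout:

```python
import pandas as pd

# Small stand-in for df with the same column layout as the loaded file
df = pd.DataFrame(
    [[11, 2.84, 9, 2.85, 25, 1.85],
     [5, 3.20, 5, 3.35, 6, 3.14]],
    columns=["Hours", "GPA", "Hours.1", "GPA.1", "Hours.2", "GPA.2"],
)

# Even positions hold the x columns, odd positions the matching y columns
pairs = list(zip(df.columns[::2], df.columns[1::2]))
print(pairs)
# [('Hours', 'GPA'), ('Hours.1', 'GPA.1'), ('Hours.2', 'GPA.2')]

# Each pair can be passed to seaborn exactly as in the loop above:
# for x, y in pairs:
#     sns.scatterplot(data=df, x=x, y=y, label=y)
```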

Option 2: Fix the data by reshaping it into long form

# filter each group of columns, melt the result into long form, and keep only the value column
h = df.filter(like='Hours').melt().value
g = df.filter(like='GPA').melt().value

# get the gpa column names
gpa_cols = df.columns[1::2]

# use numpy to create a list of labels with the appropriate length
labels = np.repeat(gpa_cols, len(df))

# otherwise use a list comprehension to create the labels
# labels = [v for x in gpa_cols for v in [x]*len(df)]

# create a new dataframe
dfl = pd.DataFrame({'hours': h, 'gpa': g, 'label': labels})

# save dfl if desired
dfl.to_csv('gpa_long.csv', index=False)

# plot
sns.scatterplot(data=dfl, x='hours', y='gpa', hue='label')

Plot Result

[image: scatter plot of hours vs. gpa, colored by label]
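The same long-form DataFrame can also be built without melt and np.repeat by renaming each column pair to common names and stacking the pairs with pd.concat. This is an alternative sketch, not the approach used above, and it again assumes the Hours/GPA alternation; the small DataFrame stands in for df:

```python
import pandas as pd

# Small stand-in for df with the same column layout as the loaded file
df = pd.DataFrame(
    [[11, 2.84, 9, 2.85, 25, 1.85],
     [5, 3.20, 5, 3.35, 6, 3.14]],
    columns=["Hours", "GPA", "Hours.1", "GPA.1", "Hours.2", "GPA.2"],
)

frames = []
for h, g in zip(df.columns[::2], df.columns[1::2]):
    # Rename the pair to shared names and tag every row with its GPA column name
    part = df[[h, g]].rename(columns={h: "hours", g: "gpa"}).assign(label=g)
    frames.append(part)

# Stack the pairs vertically; ignore_index gives a fresh 0..n-1 index
dfl = pd.concat(frames, ignore_index=True)
print(dfl)
```

dfl has the same hours, gpa, and label columns as above, so sns.scatterplot(data=dfl, x='hours', y='gpa', hue='label') works on it unchanged.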

Answered By: Trenton McKinney