Pandas groupby: get max value in a subgroup

Question:

I have a large dataset with the columns col, row, year, potveg, and total. I am trying to get the max value of the ‘total’ column for each year within a (col, row) group, i.e., for the dataset below:

col      row    year    potveg  total

-125.0  42.5    2015    9       697.3
                2015    13      535.2
                2015    15      82.3
                2016    9       907.8
                2016    13      137.6
                2016    15      268.4
                2017    9       961.9
                2017    13      74.2
                2017    15      248.0
                2018    9       937.9
                2018    13      575.6
                2018    15      215.5
-135.0  70.5    2015    8       697.3
                2015    10      535.2
                2015    19      82.3
                2016    8       907.8
                2016    10      137.6
                2016    19      268.4
                2017    8       961.9
                2017    10      74.2
                2017    19      248.0
                2018    8       937.9
                2018    10      575.6
                2018    19      215.5

I would like the output to look like this:

col      row    year    potveg  total

-125.0  42.5    2015    9       697.3
                2016    9       907.8
                2017    9       961.9
                2018    9       937.9
-135.0  70.5    2015    8       697.3
                2016    8       907.8
                2017    8       961.9
                2018    8       937.9

I tried this:

df.groupby(['col', 'row', 'year', 'potveg']).agg({'total': 'max'})

and this:

df.groupby(['col', 'row', 'year', 'potveg'])['total'].max()

but neither works: the output has too many rows.
I think the issue is the ‘potveg’ column, which is a subgroup. I am not sure how to select the rows containing the max value of ‘total’.

Asked By: Jared Kodero

Answers:

One possible solution, using .idxmax() inside groupby.apply:

print(
    df.groupby(["col", "row", "year"], as_index=False, sort=False).apply(
        lambda x: x.loc[x["total"].idxmax()]
    )
)

Prints:

     col   row    year  potveg  total
0 -125.0  42.5  2015.0     9.0  697.3
1 -125.0  42.5  2016.0     9.0  907.8
2 -125.0  42.5  2017.0     9.0  961.9
3 -125.0  42.5  2018.0     9.0  937.9
4 -135.0  70.5  2015.0     8.0  697.3
5 -135.0  70.5  2016.0     8.0  907.8
6 -135.0  70.5  2017.0     8.0  961.9
7 -135.0  70.5  2018.0     8.0  937.9
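
The float values for year and potveg in the output above (2015.0, 9.0) appear to come from each selected row passing through a mixed-dtype Series inside apply. As a hedged alternative, the sketch below uses the same .idxmax() idea without apply and selects the rows with a single .loc call. It assumes the DataFrame has a unique index, and the sample data is just a subset of the question's rows.

```python
import pandas as pd

# Subset of the question's data: one (col, row) pair, two years.
df = pd.DataFrame(
    {
        "col": [-125.0] * 6,
        "row": [42.5] * 6,
        "year": [2015, 2015, 2015, 2016, 2016, 2016],
        "potveg": [9, 13, 15, 9, 13, 15],
        "total": [697.3, 535.2, 82.3, 907.8, 137.6, 268.4],
    }
)

# For each (col, row, year) group, idxmax gives the index label
# of the row holding the largest total.
idx = df.groupby(["col", "row", "year"], sort=False)["total"].idxmax()

# Selecting those labels returns whole rows with their original dtypes.
# Assumption: index labels are unique; otherwise .loc could return extra rows.
result = df.loc[idx]
print(result)
```

Because whole rows are selected from the original frame, year and potveg stay integers here.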

DataFrame used:

       col   row  year potveg  total
0   -125.0  42.5  2015      9  697.3
1   -125.0  42.5  2015     13  535.2
2   -125.0  42.5  2015     15   82.3
3   -125.0  42.5  2016      9  907.8
4   -125.0  42.5  2016     13  137.6
5   -125.0  42.5  2016     15  268.4
6   -125.0  42.5  2017      9  961.9
7   -125.0  42.5  2017     13   74.2
8   -125.0  42.5  2017     15  248.0
9   -125.0  42.5  2018      9  937.9
10  -125.0  42.5  2018     13  575.6
11  -125.0  42.5  2018     15  215.5
12  -135.0  70.5  2015      8  697.3
13  -135.0  70.5  2015     10  535.2
14  -135.0  70.5  2015     19   82.3
15  -135.0  70.5  2016      8  907.8
16  -135.0  70.5  2016     10  137.6
17  -135.0  70.5  2016     19  268.4
18  -135.0  70.5  2017      8  961.9
19  -135.0  70.5  2017     10   74.2
20  -135.0  70.5  2017     19  248.0
21  -135.0  70.5  2018      8  937.9
22  -135.0  70.5  2018     10  575.6
23  -135.0  70.5  2018     19  215.5
Answered By: Andrej Kesely

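Option 1: One way is to do the groupby() and then merge the result with the original df:

import pandas as pd

# Max total per (col, row, year), merged back to recover the matching potveg
df1 = pd.merge(df.groupby(['col','row','year']).agg({'total':'max'}).reset_index(),
               df,
               on=['col', 'row', 'year', 'total'])
print(df1)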

Output:

      col   row  year  total potveg
0  -125.0  42.5  2015  697.3      9
1  -125.0  42.5  2016  907.8      9
2  -125.0  42.5  2017  961.9      9
3  -125.0  42.5  2018  937.9      9
4  -135.0  70.5  2015  697.3      8
5  -135.0  70.5  2016  907.8      8
6  -135.0  70.5  2017  961.9      8
7  -135.0  70.5  2018  937.9      8
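
One thing to keep in mind with the merge approach: if two potveg values share the same maximum total within a group, the merge returns both rows. The sketch below illustrates this with hypothetical data (the tie is invented for the example and does not occur in the question's dataset) and shows one possible way to keep a single row per group.

```python
import pandas as pd

# Hypothetical data: potveg 9 and 13 tie for the maximum total in 2015.
df = pd.DataFrame(
    {
        "col": [-125.0] * 3,
        "row": [42.5] * 3,
        "year": [2015] * 3,
        "potveg": [9, 13, 15],
        "total": [697.3, 697.3, 82.3],
    }
)

keys = ["col", "row", "year"]

# Max total per group, as a flat DataFrame with the key columns.
group_max = df.groupby(keys, as_index=False)["total"].max()

# Merging on the keys plus total keeps every row that matches the max,
# so a tie produces more than one row for the group.
merged = pd.merge(group_max, df, on=keys + ["total"])
print(merged)

# If only one row per group is wanted, drop the extra matches.
deduped = merged.drop_duplicates(keys)
print(deduped)
```

Whether tied rows should be kept or dropped depends on what the result is used for, so treat the drop_duplicates step as optional.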

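Option 2: Or use sort_values() and drop_duplicates(). Sort by total in descending order so that the first row kept for each group is the one with the max:

# Largest total first, keep the first row per (col, row, year),
# then restore the original row order
df1 = (
    df.sort_values('total', ascending=False)
      .drop_duplicates(['col','row','year'], keep='first')
      .sort_index()
)
print(df1)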

Output:

       col   row  year potveg  total
0   -125.0  42.5  2015      9  697.3
3   -125.0  42.5  2016      9  907.8
6   -125.0  42.5  2017      9  961.9
9   -125.0  42.5  2018      9  937.9
12  -135.0  70.5  2015      8  697.3
15  -135.0  70.5  2016      8  907.8
18  -135.0  70.5  2017      8  961.9
21  -135.0  70.5  2018      8  937.9
Answered By: perpetualstudent