pandas groupby().head(n) where n is a function of group label

Question:

I have a dataframe, and I would like to group by a column and take the head of each group, but I want the depth of the head to be defined by a function of the group label. If the depth were the same for every group, I could simply do df.groupby('label').head(n). I can imagine a solution that iterates through df['label'].unique(), slices the dataframe and builds a new one, but I’m working in a performance-sensitive context, so I’d like to avoid that kind of iteration if possible.

Here’s an example dataframe:

  label   values
0  apple       7
1  apple       5
2  apple       4
3    car       9
4    car       6
5    dog       5
6    dog       3
7    dog       2
8    dog       1

and code for my example setup:

import pandas as pd
df = pd.DataFrame({'label': ['apple', 'apple', 'apple', 'car', 'car', 'dog', 'dog', 'dog', 'dog'],
                   'values': [7, 5, 4, 9, 6, 5, 3, 2, 1]})
def depth(label):
    if label == 'apple': return 1
    elif label == 'car': return 2
    elif label == 'dog': return 3

My desired output is a dataframe in which the number of rows taken from each group is defined by that function:

   label  values
0  apple       7
3    car       9
4    car       6
5    dog       5
6    dog       3
7    dog       2
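For reference, the loop-based approach the question wants to avoid (iterating over df['label'].unique(), slicing, and concatenating) might look like the sketch below. The fallback of 0 rows for labels the function does not know about is an assumption, not part of the question.

```python
import pandas as pd

df = pd.DataFrame({'label': ['apple', 'apple', 'apple', 'car', 'car',
                             'dog', 'dog', 'dog', 'dog'],
                   'values': [7, 5, 4, 9, 6, 5, 3, 2, 1]})

def depth(label):
    # Same mapping as in the question; returning 0 for unknown labels is an
    # assumption added so the function always yields an integer.
    return {'apple': 1, 'car': 2, 'dog': 3}.get(label, 0)

# Loop-based baseline: one boolean mask and one slice per unique label,
# followed by a single concat at the end.
pieces = [df[df['label'] == lab].head(depth(lab))
          for lab in df['label'].unique()]
out_loop = pd.concat(pieces)
print(out_loop)
```

Each iteration compares the whole label column against one label, so the work grows with the number of rows times the number of distinct labels.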
Asked By: Jacob H


Answers:

I would use a dictionary here, and read each group's label from <group>.name inside groupby.apply:

depth = {'apple': 1, 'car': 2, 'dog': 3}

out = (df.groupby('label', group_keys=False)
         .apply(lambda g: g.head(depth.get(g.name, 0)))
       )

NB. If you really need a function, call it inside the lambda instead of the dictionary lookup. Make sure it returns a value for every label.
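A minimal sketch of the function-based variant, reusing the question's depth function with an explicit fallback return added (the 0 default is an assumption). Recent pandas versions may emit a warning about apply operating on the grouping column; the selected rows are the same.

```python
import pandas as pd

df = pd.DataFrame({'label': ['apple', 'apple', 'apple', 'car', 'car',
                             'dog', 'dog', 'dog', 'dog'],
                   'values': [7, 5, 4, 9, 6, 5, 3, 2, 1]})

def depth(label):
    if label == 'apple':
        return 1
    elif label == 'car':
        return 2
    elif label == 'dog':
        return 3
    # Assumed fallback so every label maps to an int row count (0 = drop group).
    return 0

# g.name holds the group's label inside groupby.apply; group_keys=False keeps
# the original row index instead of adding the label as an extra index level.
out = (df.groupby('label', group_keys=False)
         .apply(lambda g: g.head(depth(g.name))))
print(out)
```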

Alternative option with groupby.cumcount and boolean indexing:

out = df[df['label'].map(depth).gt(df.groupby('label').cumcount())]

Output (the same for both options):

   label  values
0  apple       7
3    car       9
4    car       6
5    dog       5
6    dog       3
7    dog       2
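To see why the cumcount option works, the sketch below breaks the one-liner into its intermediate Series (the names pos, limit and mask are just illustrative). It is the same technique as the answer, only spelled out step by step.

```python
import pandas as pd

df = pd.DataFrame({'label': ['apple', 'apple', 'apple', 'car', 'car',
                             'dog', 'dog', 'dog', 'dog'],
                   'values': [7, 5, 4, 9, 6, 5, 3, 2, 1]})
depth = {'apple': 1, 'car': 2, 'dog': 3}

# 0-based position of each row within its label group, in original row order.
pos = df.groupby('label').cumcount()
# Row limit for each row's label; labels missing from the dict become NaN.
limit = df['label'].map(depth)
# Keep a row while its position is below its group's limit. Comparisons with
# NaN evaluate to False, so rows with unknown labels are dropped.
mask = limit.gt(pos)
out = df[mask]

# Show the intermediate columns next to the data, then the filtered result.
print(df.assign(pos=pos, limit=limit, keep=mask))
print(out)
```

Because cumcount numbers rows in their original order within each group, this also works when the labels are interleaved rather than contiguous, and the original row order is preserved. Series.map also accepts a callable, so the question's depth function can be passed directly, at the cost of one Python call per row.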
Answered By: mozway

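Another possible solution, based on GroupBy.get_group, groupby.ngroups and groups.keys. Note that it uses each group's position plus one (x[1]+1) as the depth, so it reproduces the desired output only because depth('apple'), depth('car') and depth('dog') happen to be 1, 2 and 3: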

g = df.groupby('label')
pd.concat([g.get_group(x[0]).head(x[1]+1)
          for x in zip(g.groups.keys(), range(g.ngroups))])

Output:

   label  values
0  apple       7
3    car       9
4    car       6
5    dog       5
6    dog       3
7    dog       2
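A variant of the same get_group approach that looks the depth up by label instead of deriving it from the group's position might look like this; the 0 fallback for unknown labels is an assumption.

```python
import pandas as pd

df = pd.DataFrame({'label': ['apple', 'apple', 'apple', 'car', 'car',
                             'dog', 'dog', 'dog', 'dog'],
                   'values': [7, 5, 4, 9, 6, 5, 3, 2, 1]})

def depth(label):
    # Question's mapping; the 0 fallback for unknown labels is an assumption.
    return {'apple': 1, 'car': 2, 'dog': 3}.get(label, 0)

g = df.groupby('label')
# Iterating over g.groups yields the group labels; the depth is looked up by
# label, so the result no longer depends on each group's position.
out = pd.concat([g.get_group(key).head(depth(key)) for key in g.groups])
print(out)
```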
Answered By: PaulS