How do I get the index of a row where the max value is not duplicated?

Question:

Given this df

from io import StringIO
import pandas as pd

data = StringIO('''gene_variant gene    val1    val2    val3
b1  b   1   1   1
b2  b   2   11  1
b3  b   3   11  1
c2  c   1   1   1
t1  t   1   1   1
t2  t   12  2   2
t4  t   12  3   2
t5  t   1   4   3
d2  d   11  1   2
d4  d   11  1   1''')
df = pd.read_csv(data, sep=r'\s+')

How do I get the gene_variant for each gene, where the gene_variant is the row holding the max value of val1 if that max is not duplicated; if it is duplicated, the row holding the max value of val2 if that max is not duplicated; and otherwise just the max of val3? I.e., any tie is decided by the next column, up to the third.

EDIT: The column val2 is only considered if the max value in val1 is duplicated (a tie); the same goes for val3. If the max value in val1 or val2 is tied, the values in that column are no longer considered. Only the values in one column at a time are compared.

I’ve been trying solutions based on:

df.groupby('gene').agg(max)

and:

df.groupby('gene').rank('max')

But I can’t get there without dropping out into iteration…

The correct answer would be:

b3 3
c2 1
t5 4
d2 2
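For checking answers against, here is a minimal vectorised sketch of the rule from the EDIT: a column decides a gene only if its maximum is held by exactly one row of that gene. The names `cols`, `is_max`, `decides`, `pos` and `out` are illustrative, and the question does not say what happens when all three columns tie; the sketch assumes the first row holding the val3 maximum is taken.

```python
from io import StringIO

import numpy as np
import pandas as pd

data = StringIO('''gene_variant gene val1 val2 val3
b1 b 1 1 1
b2 b 2 11 1
b3 b 3 11 1
c2 c 1 1 1
t1 t 1 1 1
t2 t 12 2 2
t4 t 12 3 2
t5 t 1 4 3
d2 d 11 1 2
d4 d 11 1 1''')
df = pd.read_csv(data, sep=r'\s+')

cols = ['val1', 'val2', 'val3']

# True where a row holds its gene's maximum for that column
is_max = df[cols].eq(df.groupby('gene')[cols].transform('max'))

# a column decides a gene when exactly one row of the gene holds that maximum
decides = is_max.groupby(df['gene']).transform('sum').eq(1)

# index of the first deciding column, checked in the order val1, val2, val3;
# assumption: if no column decides, fall back to the last column (val3)
pos = np.where(decides.any(axis=1), decides.to_numpy().argmax(axis=1), len(cols) - 1)

rows = np.arange(len(df))
picked = is_max.to_numpy()[rows, pos]   # row holds the max of its deciding column
value = df[cols].to_numpy()[rows, pos]  # the value taken from that column

# one row per gene (drop_duplicates only matters in the all-tied fallback)
out = df.assign(value=value)[picked].drop_duplicates('gene')[['gene_variant', 'value']]
print(out)
# expected per the question: b3 3, c2 1, t5 4, d2 2
```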

Thanks in advance!

Asked By: Liam McIntyre

||

Answers:

If you need the maximum only from columns that have no duplicated values within each group, you can use:

#per group count the number of unique values in each column
df1 = df.groupby('gene').transform('nunique')

#gene_variant is unique per row, so its count equals the group size
s = df1.pop('gene_variant')

#keep values only in columns where nunique equals the group size, else NaN
masked = df[['val1','val2','val3']].where(df1.eq(s, axis=0))

#option 1: maximum is taken over all non-duplicated columns
#max1 = masked.max(axis=1)

#option 2: maximum is taken by order - first val1, then val2, then val3
#back fill missing values and select the first column
max1 = masked.bfill(axis=1).iloc[:, 0]

#assign column with the maximum
df = df.assign(max1 = max1)

#get rows from original with maximum max1 per group
df = df.loc[df.groupby('gene', sort=False)['max1'].idxmax(), ['gene_variant','max1']]
print (df)
  gene_variant  max1
2           b3   3.0
3           c2   1.0
7           t5   4.0
8           d2   2.0
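Note that this answer tests whether a column has any duplicated values in the group, which is stricter than the EDIT's condition that only the maximum must not be tied. The two agree on the sample data, but a small hypothetical group (not from the question) shows where they can diverge:

```python
import pandas as pd

# hypothetical group: val1 contains a duplicate (1, 1), yet its maximum 5 is unique
g = pd.DataFrame({'gene_variant': ['x1', 'x2', 'x3'],
                  'val1': [5, 1, 1],
                  'val2': [1, 2, 3]})
cols = ['val1', 'val2']

# test used above: nunique equals the group size (no duplicates at all)
no_duplicates = g[cols].nunique() == len(g)

# test from the question's EDIT: exactly one row holds the column maximum
unique_max = g[cols].eq(g[cols].max()).sum() == 1

print(pd.DataFrame({'no_duplicates': no_duplicates, 'unique_max': unique_max}))
```

With the nunique test, val1 is discarded and val2 picks x3; under the EDIT's rule, val1 already decides and picks x1.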

How it works:

df1 = df.groupby('gene').transform('nunique')

s = df1.pop('gene_variant')

print (df.where(df1.eq(s, axis=0)))
  gene_variant gene  val1  val2  val3
0          NaN  NaN   1.0   NaN   NaN
1          NaN  NaN   2.0   NaN   NaN
2          NaN  NaN   3.0   NaN   NaN
3          NaN  NaN   1.0   1.0   1.0
4          NaN  NaN   NaN   1.0   NaN
5          NaN  NaN   NaN   2.0   NaN
6          NaN  NaN   NaN   3.0   NaN
7          NaN  NaN   NaN   4.0   NaN
8          NaN  NaN   NaN   NaN   2.0
9          NaN  NaN   NaN   NaN   1.0
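The same mask can be summarised with one row per gene instead of one per original row, which may make it easier to see which column each gene ends up using. A small sketch, assuming `df` is the frame read in the question (rebuilt here so the snippet runs on its own); `summary` is an illustrative name:

```python
from io import StringIO

import pandas as pd

data = StringIO('''gene_variant gene val1 val2 val3
b1 b 1 1 1
b2 b 2 11 1
b3 b 3 11 1
c2 c 1 1 1
t1 t 1 1 1
t2 t 12 2 2
t4 t 12 3 2
t5 t 1 4 3
d2 d 11 1 2
d4 d 11 1 1''')
df = pd.read_csv(data, sep=r'\s+')

vals = ['val1', 'val2', 'val3']
grouped = df.groupby('gene', sort=False)

# True where the column has no duplicated values inside that gene
summary = grouped[vals].nunique().eq(grouped.size(), axis=0)
print(summary)
```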

#max of all columns
print (df.where(df1.eq(s, axis=0)).max(axis=1))
0    1.0
1    2.0
2    3.0
3    1.0
4    1.0
5    2.0
6    3.0
7    4.0
8    2.0
9    1.0
dtype: float64
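`max(axis=1)` ignores column order, so it only matches the question when at most one val column survives the mask per row (true here, except for c2, where all three survive with the same value 1). A hypothetical masked frame, not taken from the question, shows how it can differ from the ordered back-fill variant:

```python
import numpy as np
import pandas as pd

# hypothetical masked values; NaN marks a column dropped for that group
masked = pd.DataFrame({'val1': [2.0, np.nan],
                       'val2': [np.nan, 7.0],
                       'val3': [9.0, 4.0]})

largest = masked.max(axis=1)                  # largest surviving value, any column
first_kept = masked.bfill(axis=1).iloc[:, 0]  # first surviving value from the left
print(pd.DataFrame({'largest': largest, 'first_kept': first_kept}))
```

In the first row, `max(axis=1)` returns 9.0 from val3, while the back-fill variant returns 2.0 from val1, which is what "first val1, then val2" asks for.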

#back fill NaNs
print (df.where(df1.eq(s, axis=0)).bfill(axis=1))
   gene_variant  gene  val1  val2  val3
0           1.0   1.0   1.0   NaN   NaN
1           2.0   2.0   2.0   NaN   NaN
2           3.0   3.0   3.0   NaN   NaN
3           1.0   1.0   1.0   1.0   1.0
4           1.0   1.0   1.0   1.0   NaN
5           2.0   2.0   2.0   2.0   NaN
6           3.0   3.0   3.0   3.0   NaN
7           4.0   4.0   4.0   4.0   NaN
8           2.0   2.0   2.0   2.0   2.0
9           1.0   1.0   1.0   1.0   1.0

#selected first column
print (df.where(df1.eq(s, axis=0)).bfill(axis=1).iloc[:, 0])
0    1.0
1    2.0
2    3.0
3    1.0
4    1.0
5    2.0
6    3.0
7    4.0
8    2.0
9    1.0
Name: gene_variant, dtype: float64
Answered By: jezrael

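You could use .sort_values() to get the maximum values. If you pass it multiple columns, it breaks ties lexicographically: a later column is only compared between rows that tie on all the earlier ones. (The outputs below do not match the current sample data, e.g. b2 is shown with val2 = 1 where the data has 11, so they appear to come from an earlier version of the question.)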

In [9]: df.sort_values(["val1", "val2", "val3"])
Out[9]: 
  gene_variant gene  val1  val2  val3
0           b1    b     1     1     1
3           c2    c     1     1     1
4           t1    t     1     1     1
9           d4    d     1     1     1
8           d2    d     1     1     2
7           t5    t     1     4     3
1           b2    b     2     1     1
5           t2    t     2     2     2
6           t4    t     2     3     2
2           b3    b     3     1     1
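A tiny hypothetical frame (not from the question) makes the lexicographic behaviour of a multi-column `sort_values` visible:

```python
import pandas as pd

# hypothetical rows: a and b tie on val1, so val2 decides between them;
# c has the largest val2 but already loses on val1, so its val2 never matters
small = pd.DataFrame({'name': ['a', 'b', 'c'],
                      'val1': [5, 5, 1],
                      'val2': [1, 2, 9]})

order = small.sort_values(['val1', 'val2'], ascending=False)['name'].tolist()
print(order)
```

This is where it differs from the question's EDIT: there, once val1 is tied, val2 is compared across the whole group, so c (val2 = 9) would win.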

Now, in order to do this for each gene you can groupby('gene') and apply a custom function.

In [11]: df.groupby("gene").apply(
    ...:     lambda _df: _df.sort_values(["val1", "val2", "val3"], ascending=False)
    ...:     .head(1)
    ...:     .squeeze()
    ...: )
Out[11]: 
     gene_variant gene  val1  val2  val3
gene                                    
b              b3    b     3     1     1
c              c2    c     1     1     1
d              d2    d     1     1     2
t              t4    t     2     3     2

However, this does not tell you which val column won the tiebreaker.
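The same lexicographic pick can be sketched without `apply` by sorting once and keeping the first row per gene with `drop_duplicates`. On the current sample data it chooses t4 for gene t instead of the expected t5, because val2 is only compared between the rows tied on val1 (t2 and t4), so it follows a different rule from the question's EDIT. `best` is an illustrative name:

```python
from io import StringIO

import pandas as pd

data = StringIO('''gene_variant gene val1 val2 val3
b1 b 1 1 1
b2 b 2 11 1
b3 b 3 11 1
c2 c 1 1 1
t1 t 1 1 1
t2 t 12 2 2
t4 t 12 3 2
t5 t 1 4 3
d2 d 11 1 2
d4 d 11 1 1''')
df = pd.read_csv(data, sep=r'\s+')

# sort once (largest first, ties broken by the next column),
# then keep the first row seen for each gene
best = (df.sort_values(['val1', 'val2', 'val3'], ascending=False)
          .drop_duplicates('gene'))
print(best)
```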

Answered By: maow