Keep certain columns in a pandas DataFrame, deleting everything else

Question:

Say I have a data table

    1  2  3  4  5  6 ..  n
A   x  x  x  x  x  x ..  x
B   x  x  x  x  x  x ..  x
C   x  x  x  x  x  x ..  x

And I want to slim it down so that I only have, say, columns 3 and 5, deleting all the others and keeping the structure intact. How could I do this with pandas? I think I understand how to delete a single column, but I don’t know how to keep a select few and delete all the others.

Asked By: Matt


Answers:

You could reassign a new value to your DataFrame, df:

df = df.loc[:,[3, 5]]

As long as there are no other references to the original DataFrame, the old DataFrame will get garbage collected.

Note that df.loc selects by label, so the 3 and 5 above are column labels, not ordinal positions. To select columns by ordinal position instead, use df.iloc.
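To make the label-versus-position distinction concrete, here is a minimal sketch. The sample frame mirrors the table in the question; treating its column labels as the integers 1 to 6 is an assumption about how that table was built.

```python
import pandas as pd

# Sample frame shaped like the question's table
# (assumption: the column labels are the integers 1..6).
df = pd.DataFrame("x", index=["A", "B", "C"], columns=[1, 2, 3, 4, 5, 6])

by_label = df.loc[:, [3, 5]]      # columns whose labels are 3 and 5
by_position = df.iloc[:, [2, 4]]  # third and fifth columns (0-based positions 2 and 4)

# The two agree here only because label k happens to sit at position k - 1.
print(by_label.equals(by_position))
```

If the columns were labelled differently (strings, or integers in another order), only one of the two calls would pick the columns you mean, so it is worth checking df.columns first.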

Answered By: unutbu

If you have a list of columns you can just select those:

In [11]: df
Out[11]:
   1  2  3  4  5  6
A  x  x  x  x  x  x
B  x  x  x  x  x  x
C  x  x  x  x  x  x

In [12]: col_list = [3, 5]

In [13]: df = df[col_list]

In [14]: df
Out[14]:
   3  5
A  x  x
B  x  x
C  x  x
Answered By: Andy Hayden

For those who are searching for a method to do this in place:

from pandas import DataFrame
from typing import Set, Any
def remove_others(df: DataFrame, columns: Set[Any]):
    cols_total: Set[Any] = set(df.columns)
    diff: Set[Any] = cols_total - columns
    df.drop(diff, axis=1, inplace=True)

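This computes the set difference between all columns in the DataFrame and the columns to keep; the result is exactly the set of columns that can safely be removed. drop works even on an empty set, and names in columns that do not exist in the DataFrame are simply ignored.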

>>> df = DataFrame({"a":[1,2,3],"b":[2,3,4],"c":[3,4,5]})
>>> df
   a  b  c
0  1  2  3
1  2  3  4
2  3  4  5

>>> remove_others(df, {"a","b","c"})
>>> df
   a  b  c
0  1  2  3
1  2  3  4
2  3  4  5

>>> remove_others(df, {"a"})
>>> df
   a
0  1
1  2
2  3

>>> remove_others(df, {"a","not","existent"})
>>> df
   a
0  1
1  2
2  3
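
The practical difference from reassignment (df = df[cols]) is that an in-place drop is visible through every other reference to the same object. A minimal sketch of that behaviour, where alias and kept are hypothetical second names introduced only for illustration:

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 2, 3], "b": [2, 3, 4], "c": [3, 4, 5]})
alias = df  # second name bound to the same DataFrame object

# In-place drop: the object itself changes, so alias sees the change too.
df.drop(columns=["b", "c"], inplace=True)
print(alias.columns.tolist())

# Reassignment: other is rebound to a new object; kept still holds the old one.
other = pd.DataFrame({"a": [1], "b": [2]})
kept = other
other = other[["a"]]
print(kept.columns.tolist())
```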
Answered By: Fjolnir Dvorak

How do I keep certain columns in a pandas DataFrame, deleting everything else?

The answer to this question is the same as the answer to “How do I delete certain columns in a pandas DataFrame?” Here are some additional options to those mentioned so far, along with timings.

DataFrame.loc

One simple option is selection, as mentioned in other answers:

# Setup.
df
   1  2  3  4  5  6
A  x  x  x  x  x  x
B  x  x  x  x  x  x
C  x  x  x  x  x  x

cols_to_keep = [3,5]

df[cols_to_keep]

   3  5
A  x  x
B  x  x
C  x  x

Or,

df.loc[:, cols_to_keep]

   3  5
A  x  x
B  x  x
C  x  x

DataFrame.reindex with axis=1 or 'columns' (0.21+)

However, there is also reindex; in recent versions, pass axis=1 so that it operates on the columns:

df.reindex(cols_to_keep, axis=1)
# df.reindex(cols_to_keep, axis='columns')

# for versions < 0.21, use
# df.reindex(columns=cols_to_keep)

   3  5
A  x  x
B  x  x
C  x  x

On older versions, you can also use reindex_axis: df.reindex_axis(cols_to_keep, axis=1). Note that reindex_axis was deprecated in 0.21 and removed in pandas 1.0.
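
One behavioural difference is worth knowing before choosing reindex: it does not fail on labels that are absent from the frame, whereas label-based selection does (in pandas 1.0 and later, which is an assumption about the installed version). A small sketch, using a made-up label 99 that the frame does not contain:

```python
import pandas as pd

df = pd.DataFrame("x", index=["A", "B", "C"], columns=[1, 2, 3, 4, 5, 6])

# reindex keeps the requested labels and adds an all-NaN column for the missing one.
reindexed = df.reindex([3, 5, 99], axis=1)
print(reindexed.columns.tolist())
print(reindexed[99].isna().all())

# Label-based selection raises instead of inventing a column.
raised = False
try:
    df.loc[:, [3, 5, 99]]
except KeyError as exc:
    raised = True
    print("KeyError:", exc)
```

So reindex is the more forgiving option, but a typo in cols_to_keep shows up as a column of NaN rather than as an error.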


DataFrame.drop

Another alternative is to use drop, computing the columns to remove with pd.Index.difference:

# df.drop(cols_to_drop, axis=1)
df.drop(df.columns.difference(cols_to_keep), axis=1)

   3  5
A  x  x
B  x  x
C  x  x
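
A runnable sketch of the same idea, which also shows one way drop differs from plain selection: drop leaves the surviving columns in their original order, while selection follows the order of the list you pass. The sample frame and the deliberately reversed cols_to_keep are assumptions made for illustration.

```python
import pandas as pd

df = pd.DataFrame("x", index=["A", "B", "C"], columns=[1, 2, 3, 4, 5, 6])
cols_to_keep = [5, 3]  # deliberately not in frame order

# Everything in df.columns that is not in cols_to_keep.
cols_to_drop = df.columns.difference(cols_to_keep)
print(cols_to_drop.tolist())

dropped = df.drop(cols_to_drop, axis=1)  # survivors keep the frame's order
selected = df[cols_to_keep]              # columns follow the list's order
print(dropped.columns.tolist())
print(selected.columns.tolist())
```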

Performance

[Plot: runtime of each method against N, with a logarithmic Y-axis]

The methods perform roughly the same; reindex is faster for smaller N, while drop is faster for larger N. Because the Y-axis is logarithmic, read the gaps between the curves as ratios rather than as absolute time differences.

Setup and Code

import numpy as np  # needed by make_sample below
import pandas as pd
import perfplot

def make_sample(n):
    np.random.seed(0)
    df = pd.DataFrame(np.full((n, n), 'x'))
    cols_to_keep = np.random.choice(df.columns, max(2, n // 4), replace=False)

    return df, cols_to_keep 

perfplot.show(
    setup=lambda n: make_sample(n),
    kernels=[
        lambda inp: inp[0][inp[1]],
        lambda inp: inp[0].loc[:, inp[1]],
        lambda inp: inp[0].reindex(inp[1], axis=1),
        lambda inp: inp[0].drop(inp[0].columns.difference(inp[1]), axis=1)
    ],
    labels=['__getitem__', 'loc', 'reindex', 'drop'],
    n_range=[2**k for k in range(2, 13)],
    xlabel='N',   
    logy=True,
    equality_check=lambda x, y: (x.reindex_like(y) == y).values.all()
)
Answered By: cs95

Another approach is to use filter:

In [5]: df.filter([3, 5])
Out[5]: 
   3  5
A  x  x
B  x  x
C  x  x
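
One property of filter to be aware of: labels in the list that the frame does not contain are skipped rather than reported. A minimal sketch, where 99 is a hypothetical label assumed to be absent:

```python
import pandas as pd

df = pd.DataFrame("x", index=["A", "B", "C"], columns=[1, 2, 3, 4, 5, 6])

# 99 is not a column of df; filter ignores it instead of raising.
result = df.filter(items=[3, 5, 99])
print(result.columns.tolist())
```

This makes filter convenient when the list of wanted columns may be a superset of what is present, at the cost of hiding typos.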
Answered By: rachwa