How to randomly split a dataset into train and test sets with sklearn, by class and ID?

Question:

I would like to split a dataset into training and test sets in a 50:50 ratio within each class of 'fruit', with the constraint that all rows sharing the same ID go into either the training set or the test set, never both.

Here is some example data:

import pandas as pd
import random
from sklearn.model_selection import GroupShuffleSplit
    
df = pd.DataFrame({
    'fruit': ['watermelon', 'watermelon', 'watermelon', 'watermelon', 'watermelon',
              'apple', 'apple', 'apple', 'apple', 'apple', 'apple', 'apple',
              'lemon', 'lemon'],
    'ID': [1, 1, 1, 2, 2, 3, 4, 4, 5, 6, 6, 6, 7, 8],
    'value1': random.sample(range(10, 100), 14),
    'value2': random.sample(range(10, 100), 14),
})

I tried:

X = df[['value1', 'value2']]
y = df['fruit']
groups = df['ID']
gss = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
train_idx, test_idx = next(gss.split(X, y, groups))
X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

For the watermelon class, for example, I expect three rows (ID = 1) to go into the training set and two rows (ID = 2) into the test set, and likewise for apple and lemon. However, the split comes out badly: for example, both lemon rows can end up in the training set or both in the test set, when there should be one row in each.
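
One way to inspect how the attempted split distributes each class is to cross-tabulate fruit against the side each row landed on. This is only a diagnostic sketch: the fixed `random.seed(0)`, the `side` series and the `counts` table are illustrative names and choices that are not part of the original post.

```python
import random

import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

random.seed(0)  # assumption: fixed seed, only so the example values repeat
df = pd.DataFrame({
    'fruit': ['watermelon'] * 5 + ['apple'] * 7 + ['lemon'] * 2,
    'ID': [1, 1, 1, 2, 2, 3, 4, 4, 5, 6, 6, 6, 7, 8],
    'value1': random.sample(range(10, 100), 14),
    'value2': random.sample(range(10, 100), 14),
})

# Same split as in the attempt: grouped by ID, with no notion of class
gss = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
train_idx, test_idx = next(gss.split(df[['value1', 'value2']], df['fruit'], df['ID']))

# Label every row with the side it landed on, then count rows per fruit and side
side = pd.Series('train', index=df.index, name='side')
side.iloc[test_idx] = 'test'
counts = pd.crosstab(df['fruit'], side)
print(counts)
```

A fruit whose row shows a zero (or a missing column) on one side is a class that the ID-only split failed to spread across both sets.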

Asked By: Nicolas

Answers:

I don't think you can stratify by class and respect groups at the same time with GroupShuffleSplit.
One solution is to split each fruit category separately and then combine the resulting indices. This works well when there are not too many fruit categories, but the overall split might not be exactly 50/50, because a category with an odd number of IDs (or with IDs of different sizes) cannot be halved evenly.
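
To illustrate the uneven case, here is a minimal sketch with a hypothetical category that has three IDs of two rows each. The array names and the placeholder features are assumptions made for the example. Three groups cannot be halved, so one side receives two IDs and the other receives one.

```python
import numpy as np
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical category with three IDs, two rows each
groups = np.array([1, 1, 2, 2, 3, 3])
X = np.zeros((len(groups), 1))  # placeholder features; the split only looks at groups

gss = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
train_idx, test_idx = next(gss.split(X, groups=groups))

train_ids = set(groups[train_idx])
test_ids = set(groups[test_idx])
print(len(train_ids), "IDs in train,", len(test_ids), "IDs in test")
```

When several categories are uneven in this way, the imbalance can add up across categories, which is why the combined split may drift from 50/50.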

You can try:

X = df[['value1', 'value2']]
y = df['fruit']
groups = df['ID']
gss = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=0)
train_idx, test_idx = [], []

for fruit in df['fruit'].unique():
    mask = y == fruit  # rows of the current fruit only
    train_idx_fruit, test_idx_fruit = next(gss.split(X[mask], y[mask], groups[mask]))
    # split() returns positions within the subset; map them back to index labels of df
    train_idx += list(X[mask].index[train_idx_fruit])
    test_idx += list(X[mask].index[test_idx_fruit])

# train_idx and test_idx hold index labels, so select with .loc rather than .iloc
X_train, X_test = X.loc[train_idx], X.loc[test_idx]
y_train, y_test = y.loc[train_idx], y.loc[test_idx]

There is probably a cleaner way using groupby instead.
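
A possible sketch of the groupby variant mentioned above, under the same assumptions as the loop version. The helper `split_one_class` and the fixed `random.seed(0)` are hypothetical additions for the example and are not part of sklearn or pandas.

```python
import random

import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

random.seed(0)  # assumption: fixed seed, only so the example values repeat
df = pd.DataFrame({
    'fruit': ['watermelon'] * 5 + ['apple'] * 7 + ['lemon'] * 2,
    'ID': [1, 1, 1, 2, 2, 3, 4, 4, 5, 6, 6, 6, 7, 8],
    'value1': random.sample(range(10, 100), 14),
    'value2': random.sample(range(10, 100), 14),
})


def split_one_class(part, test_size=0.5, seed=0):
    """Split the rows of one class by ID; return (train_labels, test_labels)."""
    gss = GroupShuffleSplit(n_splits=1, test_size=test_size, random_state=seed)
    train_pos, test_pos = next(gss.split(part, groups=part['ID']))
    # split() yields positions within `part`; map them back to index labels
    return part.index[train_pos], part.index[test_pos]


train_labels, test_labels = [], []
for _, part in df.groupby('fruit'):
    tr, te = split_one_class(part)
    train_labels.extend(tr)
    test_labels.extend(te)

# Index labels, so select with .loc
train, test = df.loc[train_labels], df.loc[test_labels]

# Sanity checks: no ID on both sides, every fruit on both sides
print(set(train['ID']).isdisjoint(test['ID']))
print(set(train['fruit']) == set(test['fruit']))
```

Every fruit in this data has at least two IDs, so each per-class split leaves at least one ID on each side. A fruit with a single ID cannot be split this way and would need separate handling.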

Answered By: Mattravel