Get train/valid/test index by sequence from a pandas DataFrame

Question:

I want to add a train/valid/test mode column that assigns each user's last sequence to test and the sequence before it to valid. Every other row is train.

+----+-------+----+------+
|user|cnt_seq|item| mode |
+----+-------+----+------+
|   1|      1|   4| train|
|   1|      1|   7| train|
|   1|      2|   2| train|
|   1|      2|   9| train|
|   1|      3|   8| valid|
|   1|      4|   3|  test|
|   1|      4|  10|  test|
|   2|      1|   6| train|
|   2|      2|   7| valid|
|   2|      3|   1|  test|
+----+-------+----+------+

Each user has a different number of sequences (cnt_seq), and each sequence can contain a different number of items.
My current code is:

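# df holds the user, cnt_seq and item columns shown above
test_users = [1, 2]
# last cnt_seq for each user
mdict = df.groupby('user')['cnt_seq'].max().to_dict()
# (user, last sequence) pairs go to test, (user, previous sequence) pairs go to valid
test_idx = [(k, v) for k, v in mdict.items() if k in test_users]
valid_idx = [(k, v - 1) for k, v in mdict.items() if k in test_users]

# every row starts as train
df['mode'] = 'train'

for i, j in valid_idx:
    df.loc[(df.user == i) & (df.cnt_seq == j), 'mode'] = 'valid'
for i, j in test_idx:
    df.loc[(df.user == i) & (df.cnt_seq == j), 'mode'] = 'test'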

But I don't think this is a good approach, since it needs two for loops (one for valid, one for test).
Is there simpler code for this?

Asked By: Dang


Answers:

Try this:

import pandas as pd
import numpy as np

# Your original df
data = [{'user': 1, 'cnt_seq': 1, 'item': 4},
        {'user': 1, 'cnt_seq': 1, 'item': 7},
        {'user': 1, 'cnt_seq': 2, 'item': 2},
        {'user': 1, 'cnt_seq': 2, 'item': 9},
        {'user': 1, 'cnt_seq': 3, 'item': 8},
        {'user': 1, 'cnt_seq': 4, 'item': 3},
        {'user': 1, 'cnt_seq': 4, 'item': 10},
        {'user': 2, 'cnt_seq': 1, 'item': 6},
        {'user': 2, 'cnt_seq': 2, 'item': 7},
        {'user': 2, 'cnt_seq': 3, 'item': 1}]
df = pd.DataFrame(data)

# Calculate the maximum sequence number for each user
group_max = df.groupby(['user'])['cnt_seq'].transform('max')

# Assign modes to each sequence based on the maximum sequence number
df['mode'] = np.select(
    [
        df['cnt_seq'] == group_max,          # test set corresponds to the last sequence
        df['cnt_seq'] == group_max-1         # validation set corresponds to the previous sequence
    ],
    ['test', 'valid'],                       # corresponding modes
    'train'                                   # all other sequences are assigned train set
)

# Print the results
print(df)

Answered By: ziying35