Creating folds properly

Reason is the light and the light of life.

Jerry Su Feb 27, 2022 5 mins
!pip install -q iterative-stratification
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
import pandas as pd
from sklearn.model_selection import KFold
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold

# Load the training table, one-hot encode `discourse_type`, and keep the
# first five rows of the encoded frame for a quick look.
df = pd.read_csv('/root/.cache/data/train.csv')
one_hot = pd.get_dummies(df, columns=["discourse_type"])
dfx = one_hot.head(5)
dfx
id discourse_id discourse_start discourse_end discourse_text discourse_type_num predictionstring discourse_type_Claim discourse_type_Concluding Statement discourse_type_Counterclaim discourse_type_Evidence discourse_type_Lead discourse_type_Position discourse_type_Rebuttal
0 423A1CA112E2 1.622628e+12 8.0 229.0 Modern humans today are always on their phone…. Lead 1 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 1… 0 0 0 0 1 0 0
1 423A1CA112E2 1.622628e+12 230.0 312.0 They are some really bad consequences when stu… Position 1 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 0 0 0 0 0 1 0
2 423A1CA112E2 1.622628e+12 313.0 401.0 Some certain areas in the United States ban ph… Evidence 1 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 0 0 0 1 0 0 0
3 423A1CA112E2 1.622628e+12 402.0 758.0 When people have phones, they know about certa… Evidence 2 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 9… 0 0 0 1 0 0 0
4 423A1CA112E2 1.622628e+12 759.0 886.0 Driving is one of the way how to get around. P… Claim 1 139 140 141 142 143 144 145 146 147 148 149 15… 1 0 0 0 0 0 0
# Collapse to one row per `id`: one-hot encode `discourse_type`, then sum
# within each id so every dummy column becomes a per-id count of that type.
# numeric_only=True keeps text columns out of the sum; pandas >= 2.0 no
# longer drops them silently and would concatenate the strings instead.
dfx = (
    pd.get_dummies(df, columns=["discourse_type"])
    .groupby(["id"], as_index=False)
    .sum(numeric_only=True)
)
dfx
id discourse_id discourse_start discourse_end discourse_type_Claim discourse_type_Concluding Statement discourse_type_Counterclaim discourse_type_Evidence discourse_type_Lead discourse_type_Position discourse_type_Rebuttal
0 0000D23A521A 1.294188e+13 4166.0 5506.0 1 1 1 3 0 1 1
1 00066EA9880D 1.458994e+13 12618.0 16058.0 3 1 0 3 1 1 0
2 000E6DE9E817 1.940756e+13 8760.0 10092.0 5 1 1 3 0 1 1
3 001552828BD0 1.622844e+13 12881.0 15580.0 4 0 0 4 1 1 0
4 0016926B079C 1.783190e+13 5102.0 6414.0 7 0 0 3 0 1 0
15589 FFF1442D6698 1.618644e+13 14374.0 17948.0 2 1 1 3 1 1 1
15590 FFF1ED4F8544 1.454313e+13 6944.0 9435.0 5 0 0 2 1 1 0
15591 FFF868E06176 1.456920e+13 8210.0 10507.0 3 1 0 3 1 1 0
15592 FFFD0AF13501 1.295859e+13 4408.0 5395.0 4 1 0 2 0 1 0
15593 FFFF80B8CC2F 1.617042e+12 0.0 990.0 0 0 0 1 0 0 0

15594 rows × 11 columns

# Inspect the column names of the grouped frame.
dfx.columns
Index(['id', 'discourse_id', 'discourse_start', 'discourse_end',
       'discourse_type_Claim', 'discourse_type_Concluding Statement',
       'discourse_type_Counterclaim', 'discourse_type_Evidence',
       'discourse_type_Lead', 'discourse_type_Position',
       'discourse_type_Rebuttal'],
      dtype='object')
# Keep the `id` key plus the one-hot `discourse_type_*` count columns.
# The parentheses matter: `and` binds tighter than `or`, so without them the
# `discourse_type_num` exclusion would only ever be tested against "id" and
# could never filter anything out.
cols = [
    c
    for c in dfx.columns
    if (c.startswith("discourse_type") or c == "id") and c != "discourse_type_num"
]
cols
['id',
 'discourse_type_Claim',
 'discourse_type_Concluding Statement',
 'discourse_type_Counterclaim',
 'discourse_type_Evidence',
 'discourse_type_Lead',
 'discourse_type_Position',
 'discourse_type_Rebuttal']
# Take an explicit copy of the selected columns so that adding a column to
# `dfx` afterwards modifies an independent frame rather than a slice of the
# grouped frame (mutating the slice raises pandas' SettingWithCopyWarning).
dfx = dfx[cols].copy()
dfx
id discourse_type_Claim discourse_type_Concluding Statement discourse_type_Counterclaim discourse_type_Evidence discourse_type_Lead discourse_type_Position discourse_type_Rebuttal
0 0000D23A521A 1 1 1 3 0 1 1
1 00066EA9880D 3 1 0 3 1 1 0
2 000E6DE9E817 5 1 1 3 0 1 1
3 001552828BD0 4 0 0 4 1 1 0
4 0016926B079C 7 0 0 3 0 1 0
15589 FFF1442D6698 2 1 1 3 1 1 1
15590 FFF1ED4F8544 5 0 0 2 1 1 0
15591 FFF868E06176 3 1 0 3 1 1 0
15592 FFFD0AF13501 4 1 0 2 0 1 0
15593 FFFF80B8CC2F 0 0 0 1 0 0 0

15594 rows × 8 columns

# 5-way multilabel-stratified splitter, shuffled with a fixed random_state.
mskf = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=42)
# Every column except the `id` key serves as a stratification target.
labels = dfx.columns[dfx.columns != "id"].tolist()
dfx_labels = dfx.loc[:, labels]
dfx_labels
discourse_type_Claim discourse_type_Concluding Statement discourse_type_Counterclaim discourse_type_Evidence discourse_type_Lead discourse_type_Position discourse_type_Rebuttal
0 1 1 1 3 0 1 1
1 3 1 0 3 1 1 0
2 5 1 1 3 0 1 1
3 4 0 0 4 1 1 0
4 7 0 0 3 0 1 0
15589 2 1 1 3 1 1 1
15590 5 0 0 2 1 1 0
15591 3 1 0 3 1 1 0
15592 4 1 0 2 0 1 0
15593 0 0 0 1 0 0 0

15594 rows × 7 columns

# Add a `kfold` column initialised to -1 for every row.
dfx["kfold"] = -1
/tmp/ipykernel_27806/3539168384.py:1: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfx["kfold"] = -1
df = pd.read_csv('/root/.cache/data/train.csv')

# One row per `id`, with each one-hot `discourse_type_*` column summed into a
# per-id count. numeric_only=True keeps text columns out of the sum; pandas
# >= 2.0 no longer drops them silently and would concatenate the strings.
dfx = (
    pd.get_dummies(df, columns=["discourse_type"])
    .groupby(["id"], as_index=False)
    .sum(numeric_only=True)
)
# Parenthesised on purpose: `and` binds tighter than `or`, so without the
# parentheses the `discourse_type_num` exclusion is only tested against "id".
cols = [
    c
    for c in dfx.columns
    if (c.startswith("discourse_type_") or c == "id") and c != "discourse_type_num"
]
# Explicit copy: `kfold` is added below, and assigning into a plain column
# slice raises SettingWithCopyWarning.
dfx = dfx[cols].copy()

mskf = MultilabelStratifiedKFold(n_splits=10, shuffle=True, random_state=42)
# Stratification targets: every column except the `id` key. Built before
# `kfold` is added so the placeholder column is not part of the targets.
labels = [c for c in dfx.columns if c != "id"]
dfx_labels = dfx[labels]
# -1 is a placeholder fold id for every row at this point.
dfx["kfold"] = -1
dfx
id discourse_type_Claim discourse_type_Concluding Statement discourse_type_Counterclaim discourse_type_Evidence discourse_type_Lead discourse_type_Position discourse_type_Rebuttal kfold
0 0000D23A521A 1 1 1 3 0 1 1 -1
1 00066EA9880D 3 1 0 3 1 1 0 -1
2 000E6DE9E817 5 1 1 3 0 1 1 -1
3 001552828BD0 4 0 0 4 1 1 0 -1
4 0016926B079C 7 0 0 3 0 1 0 -1
15589 FFF1442D6698 2 1 1 3 1 1 1 -1
15590 FFF1ED4F8544 5 0 0 2 1 1 0 -1
15591 FFF868E06176 3 1 0 3 1 1 0 -1
15592 FFFD0AF13501 4 1 0 2 0 1 0 -1
15593 FFFF80B8CC2F 0 0 0 1 0 0 0 -1

15594 rows × 9 columns

# Walk the splits and stamp each validation row of `dfx` with its fold number.
splits = mskf.split(dfx, dfx_labels)
for fold, (train_idx, valid_idx) in enumerate(splits):
    print(len(train_idx), len(valid_idx))
    dfx.loc[valid_idx, "kfold"] = fold

# Carry the per-id fold back onto every row of the original table.
fold_by_id = dfx[["id", "kfold"]]
df = df.merge(fold_by_id, on="id", how="left")
print(df.kfold.value_counts())
# df.to_csv("train_folds.csv", index=False)
14036 1558
14036 1558
14033 1561
14035 1559
14031 1563
14035 1559
14034 1560
14036 1558
14036 1558
14034 1560
6    14633
9    14532
7    14529
8    14509
5    14466
3    14431
4    14365
1    14358
2    14271
0    14199
Name: kfold, dtype: int64
# Count the non-null entries of every column within each fold.
df.groupby(by="kfold").count()
id discourse_id discourse_start discourse_end discourse_text discourse_type discourse_type_num predictionstring
kfold
0 14199 14199 14199 14199 14199 14199 14199 14199
1 14358 14358 14358 14358 14358 14358 14358 14358
2 14271 14271 14271 14271 14271 14271 14271 14271
3 14431 14431 14431 14431 14431 14431 14431 14431
4 14365 14365 14365 14365 14365 14365 14365 14365
5 14466 14466 14466 14466 14466 14466 14466 14466
6 14633 14633 14633 14633 14633 14633 14633 14633
7 14529 14529 14529 14529 14529 14529 14529 14529
8 14509 14509 14509 14509 14509 14509 14509 14509
9 14532 14532 14532 14532 14532 14532 14532 14532


Read more:

Related posts: