Stratified K 폴드는 불규현향 분포도를 가진 target 데이터 집합을 위한 k-fold 방식 이다.
만약 공장의 오작동 데이터를 예측한다고 가정해 본다. 공장의 작동 데이터 중 오작동, 작동을 구분을 1(오작동), 0(작동)로 하여 피쳐를 준다고 가정하자. 이 공장은 잘 돌아가고 있는 공장이므로 오작동 데이터는 작동 데이터보다 훨씬 작을 것이다. 만약 작동 데이터가 만건이 있다고 하면 오작동 데이터는 10건정도 있다고 가정해본다. 이렇게 작동 데이터에 비해 적은 비율로 오작동 데이터가 있다면 k-fold로 랜덤하게 데이터를 뽑을 때 0, 1의 비율이 제대로 반영되지 못한다.
이 때 예측을 하기 위해 중요한 것은 작동 데이터가 아니라 오작동 데이터이다. 그러므로 원본데이터와 유사한 오작동 레이블 값의 분포를 학습세트, 테스트세트에 유지하는 것이 중요하다.
와인데이터의 피쳐와 타겟으로 와인데이터프레임을 만든 다음
레이블의 각각의 갯수를 확인한다.
wine_df = pd.DataFrame(wine_features, columns=wine.feature_names)
wine_df['label'] = wine.target
wine_df
wine_df['label'].value_counts()
1 71
0 59
2 48
각 검증시마다 생성되는 학습데이터, 검증데이터 값의 분포도를 확인해본다.
밑의 결과를 확인해보면 학습 레이블과 검증 레이블이 매우 다르다. 첫번째에는 학습레이블이 1은 70개, 2는 48개 있는데 검증 레이블은 0이 59개, 1이 1개 있다. 이렇게 되면 학습레이블에서는 없는 0은 예측할 수 없는 상황이 만들어진다.
kfold = KFold(n_splits=3)
iter_num = 0
for train_index, test_index in kfold.split(wine_df):
iter_num += 1
label_train = wine_df['label'].iloc[train_index]
label_test = wine_df['label'].iloc[test_index]
print('교차 검증 : \n', iter_num)
print('학습 레이블 데이터 분포 : \n', label_train.value_counts())
print('검증 레이블 데이터 분포 : \n', label_test.value_counts())
교차 검증 : 1
학습 레이블 데이터 분포 :
1 70
2 48
Name: label, dtype: int64
검증 레이블 데이터 분포 :
0 59
1 1
Name: label, dtype: int64
교차 검증 : 2
학습 레이블 데이터 분포 :
0 59
2 48
1 12
Name: label, dtype: int64
검증 레이블 데이터 분포 :
1 59
Name: label, dtype: int64
교차 검증 : 3
학습 레이블 데이터 분포 :
1 60
0 59
Name: label, dtype: int64
검증 레이블 데이터 분포 :
2 48
1 11
이 때 StratifiedKFold가 해결책이 될 수 있다. StratifiedKFold를 통해 레이블의 분포도를 반영해보자.
밑에 결과를 보면 일정하게 반영이 된 것을 알 수 있다.
from sklearn.model_selection import StratifiedKFold
skf = StratifiedKFold(n_splits=3)
iter_num = 0
for train_index, test_index in skf.split(wine_df, wine_df['label']):
iter_num += 1
label_train = wine_df['label'].iloc[train_index]
label_test = wine_df['label'].iloc[test_index]
print('교차 검증 : ', iter_num)
print('학습 레이블 데이터 분포 : \n', label_train.value_counts())
print('검증 레이블 데이터 분포 : \n', label_test.value_counts())
교차 검증 : 1
학습 레이블 데이터 분포 :
1 47
0 39
2 32
Name: label, dtype: int64
검증 레이블 데이터 분포 :
1 24
0 20
2 16
Name: label, dtype: int64
교차 검증 : 2
학습 레이블 데이터 분포 :
1 48
0 39
2 32
Name: label, dtype: int64
검증 레이블 데이터 분포 :
1 23
0 20
2 16
Name: label, dtype: int64
교차 검증 : 3
학습 레이블 데이터 분포 :
1 47
0 40
2 32
Name: label, dtype: int64
검증 레이블 데이터 분포 :
1 24
0 19
2 16