본문 바로가기

머신러닝

# cross-validation

 원본 데이터를 train settest set으로 분리하여 train set으로만 학습시키고 test set으로만 평가할 경우 overffiting이 생길 가능성이 매우 증가합니다.

 다음 그림과 같이 원본 데이터를 여러 개의 폴드로 나누고, 한 폴드를 validatoin_set, 나머지 폴드들을 train_set으로 설정하여 폴드 수만큼 훈련을 시키면 원본 데이터에 대한 모델의 성능을 보다 정확하게 평가할 수 있습니다.

출쳐: https://scikit-learn.org/stable/modules/cross_validation.html

sklearn에서는 다음과 같이 cross validation을 수행하는 함수를 지원합니다.

KFold는 n_splits(분할하는 Fold의 수), shuffle(데이터의 뒤섞는 유무), random_state를 갖는 객체입니다. 이 객체를 cross_val_score의 cv에 넣어주면 해당 정보에 맞게 데이터를 분할하여 cross_validation을 수행합니다.