일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
- SQLD
- 기초통계
- 데이터 분석
- If
- 클러스터링
- 데이터분석
- 최종 프로젝트
- da
- 팀프로젝트
- 서브쿼리
- 프롬프트 엔지니어링
- GA4
- 기초프로젝트
- pandas
- streamlit
- 시각화
- 군집화
- jd
- lambda
- 히트맵
- Chat GPT
- cross join
- 머신러닝
- Python
- 태블로
- 프로젝트
- data analyst
- SQL
- 전처리
- 크롤링
- Today
- Total
세조목
이커머스 머신러닝 강의 복습(Ch.3 - KNN) 본문
KNN
K Nearest Neighbor
최근접 이웃이라는 의미다.
2024.02.02 - [데이터 분석 공부/머신러닝] - 머신러닝 - 의사결정나무, 랜덤포레스트, KNN, 부스팅 알고리즘(24.02.02)
위 포스팅에 KNN 관련해서 내용을 정리해두었는데
간략하게 얘기해서 KNN은 가장 가까이 있는 이웃(데이터)들끼리 묶는 것이다.
KNN의 K가 바로 묶는 데이터의 개수를 의미하는데
몇 개를 묶어야 가장 예측을 잘 할 지는 실험자가 하나 하나 확인해봐야하며
이렇게 실험자가 그 값을 수동으로 넣어줘야하는 것을 하이퍼 파라미터라고한다.
이번 포스팅에서는 금일 학습한 KNN 예제 내용을 정리하려고한다.
예제를 보면 명목형 변수들이 눈에 띈다.
머신러닝을 하기위해서는 컬럼들간의 스케일을 맞춰줘야하기때문에
전처리 단계에서 스케일링을 진행해준다.
1. 더미화 대상 컬럼 뽑아내기
col_list = []
for i in data.columns:
if data[i].dtype == 'O':
col_list.append(i)
for i in col_list:
print(i, data[i].nunique())
# customerID 컬럼은 제외하기위해서 1부터 끝까지만 인덱싱
col_list = col_list[1:]
데이터 타입이 object인 컬럼들만을 col_list에 넣어주고 각 컬럼의 고유값의 개수가 몇 개인지를 확인한다.
그런다음 customerID 컬럼의 경우 머신러닝 時 필요가 없기때문에 해당 컬럼을 제외한 나머지 컬럼들만을 인덱싱해준다.
2. 결측치 채워주기
data['TotalCharges'].mean()
data['TotalCharges'].median()
# 대다수가 좌측(낮은 가격대)에 분포해있기때문에 평균이 아닌 중앙값 활용
sns.displot(data['TotalCharges'])
data['TotalCharges'].fillna(data['TotalCharges'].median(), inplace=True)
'TotalCharges' 컬럼을 확인해보면 결측치가 있는 것을 확인할 수 있는데
평균 또는 중앙값으로 결측치를 채워줄 수 있다.
예제에서의 경우 좌측(낮은 가격대)에 대다수의 데이터들이 분포해있기때문에
평균이 아닌 중앙값을 활용해준다.
3. 스케일링
MinMax Scaler, Standard Scaler, 그리고 Robust Scaler 이렇게 세가지 스케일링을 진행한다.
● MinMax Scaler
from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler
minmax = MinMaxScaler()
data.drop('customerID', axis=1, inplace=True)
minmax.fit(data)
scaled_data = minmax.transform(data) ← 머신러닝때 활용할 것이기때문에 별도의 변수에 지정해줌
scaled_data = pd.DataFrame(scaled_data, columns=data.columns)
● Standard Scaler
standard = StandardScaler()
standard.fit(data.drop('Churn_Yes', axis=1)) # ← 스케일링 時 종속변수는 제거해야함
scaled_st = standard.transform(data.drop('Churn_Yes', axis=1))
pd.DataFrame(scaled_st, columns=data.drop('Churn_Yes', axis=1).columns)
● Robust Scaler
robust = RobustScaler()
robust.fit(data.drop('Churn_Yes', axis=1))
scaled_rob = robust.transform(data.drop('Churn_Yes', axis=1))
pd.DataFrame(scaled_rob, columns=data.drop('Churn_Yes', axis=1).columns)
4. Train, Test dataset split
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
X = scaled_data.drop('Churn_Yes', axis=1)
y = scaled_data['Churn_Yes']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)
knn = KNeighborsClassifier(n_neighbors=10)
knn.fit(X_train, y_train)
pred = knn.predict(X_test)
accuracy_score(y_test, pred)
confusion_matrix(y_test, pred)
f1_score(y_test, pred)
5. 최적의 KNN 찾기
error_list = []
for i in range(1, 101):
knn = KNeighborsClassifier(n_neighbors=i)
knn.fit(X_train, y_train)
pred = knn.predict(X_test)
error_list.append(accuracy_score(y_test, pred))
error_list
# accuracy_score 확인
plt.figure(figsize=(10,5))
sns.lineplot(error_list, marker = 'o', markersize=5, markerfacecolor = 'red')
# 가장 높은 accuracy_score 확인하는 방법 1
error_list.index(max(error_list))
# 가장 높은 accuracy_score 확인하는 방법 2
np.array(error_list).argmax()
KNN의 K값을 하나 하나 넣어보면서 accuracy 값을 측정해볼수도 있지만
너무 많은 시간이 소요되기때문에 함수를 사용한다.
1부터 100까지의 숫자만 넣어봤으며
KNeighborsClassifier() 소괄호 안에 'n_neighbors='에 원하는 숫자를 넣어주면
해당 숫자가 K가 된다.
총 100개의 accuracy 값을 error_list에 넣어주고
어떤 값이 가장 큰 값인지를 확인하기위해 아래 두 가지 방법을 활용할 수 있다.
- error_list.index(max(error_list))
- np.array(error_list).argmax()
'데이터 분석 공부 > 머신러닝' 카테고리의 다른 글
머신러닝 기초 복습(선형 회귀)(24.05.02) (1) | 2024.05.02 |
---|---|
머신러닝 - 클러스터링(계층적 군집화) (0) | 2024.04.02 |
이커머스 머신러닝 강의 복습(Ch.2 - Logistic Regression) (0) | 2024.02.27 |
이커머스 머신러닝 강의 복습(Ch.1 - Linear Regression) (1) | 2024.02.26 |
머신러닝 - 의사결정나무, 랜덤포레스트, KNN, 부스팅 알고리즘(24.02.02) (1) | 2024.02.02 |