도찐개찐

[머신러닝] 12. KNN 본문

PYTHON/데이터분석

[머신러닝] 12. KNN

도개진 2023. 1. 3. 12:32

KNN

  • k 최근접 이웃(k-nearest neighbors) 알고리즘
  • 머신러닝 분류에 자주 사용되는 대표 알고리즘
  • 얼굴인식, 개인영화추천, 질병 유전자 패턴 식별에 활용
  • KNN의 K는 가장 가까운 이웃 '하나'가 아니고 훈련 데이터 중 새로운 데이터와 가장 가까운 k개의 이웃을 찾는다는 의미
  • 즉, 하나의 관측값은 거리가 가까운 k개의 이웃 관측값들과 비슷한 특성을 갖는다고 가정함
    • 거리를 구할때는 유클리드 거리, 맨해튼 거리, 코사인 유사도, 피어슨 상관계수등이 사용됨
  • 따라서, K개 이웃의 목표변수 중 다수결로 가장 많은 범주에 속한 값을 결과로 반환
  • KNN 알고리즘에서는 k를 얼마로 설정하느냐에 따라 결과와 성능이 달라짐
    • k가 작으면 데이터의 범위가 좁아짐 - 과적합 위험
    • k가 크면 데이터의 범위가 넓어짐 - 일반화 위험
    • 일반적으로 k값은 데이터 건수에 제곱근을 씌운 값
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, f1_score
from sklearn.metrics import recall_score, precision_score

from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
fontpath = '/home/bigdata/py39/lib/python3.9/site-packages/matplotlib/mpl-data/fonts/ttf/NanumGothic.ttf'
fname = mpl.font_manager.FontProperties(fname=fontpath).get_name()

mpl.rcParams['font.family'] = 'NanumGothic'
mpl.rcParams['font.size'] = 12
mpl.rcParams['axes.unicode_minus'] = False

과일, 채소 구분하기

  • KNN 알고리즘을 이용해서 당도, 아삭함을 기준으로 과일인지 채소인지 구분
    • 당도가 6, 아삭함이 4 인 토마토는 채소인가 과일인가?
fresh = pd.read_csv('../data/fresh.csv', encoding='cp949')
fresh.head(3)
  이름 단맛 아삭거림 범주
0 포도 8 5 과일
1 생선 2 2 단백질
2 당근 6 10 채소
fresh.columns = ['name', 'sweet', 'crunky', 'class']
# plt.scatter(fresh.단맛, fresh.아삭거림)
# plt.xticks(rotation='vertical')
plt.scatter(fresh.sweet, fresh.crunky)
plt.grid()
plt.show()

sns.scatterplot(data=fresh, x='sweet', y='crunky', hue='class')
plt.plot(6, 4, 'ro')
for i, n in enumerate(fresh.name):
    plt.annotate(n, (fresh.sweet[i], fresh.crunky[i]))
plt.grid()
plt.show()

print(fresh.columns[3:])
for c in fresh.columns[3:]:
    print(c)
    fresh[c] = pd.Categorical(fresh[c])
    fresh[c] = fresh[c].cat.codes
fresh
Index(['class'], dtype='object')
class
  name sweet crunky class
0 포도 8 5 0
1 생선 2 2 1
2 당근 6 10 2
3 오렌지 7 3 0
4 샐러리 3 8 2
5 치즈 1 1 1
6 오이 2 8 2
7 사과 10 9 0
8 베이컨 1 4 1
9 바나나 10 1 0
10 3 7 2
11 양상추 1 9 2
12 견과류 3 6 1
13 10 7 0
14 새우 2 3 1
# 레이블 인코딩2 - map
from sklearn.preprocessing import LabelEncoder
encoders = LabelEncoder()

# fresh['class'].map(lambda x: encoders.fit_transform(fresh[x]))
fresh['class'] = fresh['class'].map({'과일': 0, '단백질': 1, '채소': 2})
fresh
  name sweet crunky class
0 포도 8 5 NaN
1 생선 2 2 NaN
2 당근 6 10 NaN
3 오렌지 7 3 NaN
4 샐러리 3 8 NaN
5 치즈 1 1 NaN
6 오이 2 8 NaN
7 사과 10 9 NaN
8 베이컨 1 4 NaN
9 바나나 10 1 NaN
10 3 7 NaN
11 양상추 1 9 NaN
12 견과류 3 6 NaN
13 10 7 NaN
14 새우 2 3 NaN
data = fresh.iloc[:, 1:3]
target = fresh['class']
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(data.to_numpy(), target)
knn.score(data.to_numpy(), target)
one = [[6, 4]]
knn.predict(one)

적절한 k 값 찾기

from sklearn.model_selection import cross_val_score
scores = []
for n in range(1, 10+1):
    knn = KNeighborsClassifier(n_neighbors=n)
    score = cross_val_score(knn, data, target, cv=5, scoring='accuracy')
    scores.append(np.mean(score))
plt.plot(range(1, 10+1), scores, 'ro--')
plt.show()

농구선수 게임데이터를 이용해서 포지션 예측

  • 2017 NBA 농구선수의 실제 데이터를 참고
    • player 선수이름
    • pos 포지션
    • 3p 한 경기 평균 3점슛 성공횟수
    • 2p 한 경기 평균 2점슛 성공횟수
    • trb 한 경기 평균 리바운드 성공횟수
    • ast 한 경기 평균 어시스트 성공횟수
    • stl 한 경기 평균 스틸성공횟수
    • blk 한 경기 평균 블로킹 성공횟수
  • c 센터 : 골대 근처 블로킹,리바운드,슛찬스를 만듬
  • sg 슈팅가드 : 주로 장거리에서 슛을 쏴서 점수를 얻음
  • 3점슛, 블록킹수에 따라 포지션이 결정된다고 가정하면
    • 만일, 3점슛이 2.5, 블록킹수 2일 경우 포지션은?
  • 또한, 센터와 슈팅가드를 구분하기 위해 도움이 되는 특성feature은?
bbplayers = pd.read_csv('../data/bbplayer.csv')
bbplayers.head(3)
encoders = LabelEncoder()
bbplayers['Pos'] =  encoders.fit_transform(bbplayers['Pos'])
bbplayers.head()
sns.heatmap(bbplayers.corr(), annot=True, fmt='.2f')
plt.show()
sns.pairplot(bbplayers, diag_kind='kde')
sns.scatterplot(data=bbplayers, x='3P', y='BLK', hue='Pos')
data = bbplayers.iloc[:, 2:]
target = bbplayers.Pos
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(data.to_numpy(), target)
knn.score(data.to_numpy(), target)
sns.scatterplot(data=bbplayers, x='3P', y='BLK', hue='Pos')
plt.plot(2.5, 2, 'ro')

plt.grid()
plt.show()
data = bbplayers.loc[:, ['3P', 'BLK']]
target = bbplayers.Pos
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(data.to_numpy(), target)
knn.score(data.to_numpy(), target)
ond = [[2.5, 2]]
knn.predict(one)
from statsmodels.formula.api import ols
 # + '+'.join(bbplayers.columns).replace('+Pos', '')
model = ols('Pos~3P+BLK', data=pd.DataFrame(bbplayers)).fit()
bbp = pd.read_csv('../data/bbplayer.csv')
bbp.head(3)
bbp.Pos.value_counts()

포지션을 결정짓는 요인 파악

# 시각화1 (스틸 / 2점슛)
sns.scatterplot(data=bbp, x='STL', y='2P', hue='Pos')
plt.show()
# 시각화2 (어시스트 / 2점슛)
sns.scatterplot(data=bbp, x='AST', y='2P', hue='Pos')
plt.show()
# 시각화3 (블로킹 / 3점슛)
sns.scatterplot(data=bbp, x='BLK', y='3P', hue='Pos')
plt.show()
# 시각화4 (리바운드 / 3점슛)
sns.scatterplot(data=bbp, x='TRB', y='3P', hue='Pos')
plt.show()
# 시각화5 (스틸 / 블로킹)
sns.scatterplot(data=bbp, x='STL', y='BLK', hue='Pos')
plt.show()
sns.pairplot(bbp, hue='Pos' ,markers=['o', 's'])
  • 센터/슈팅가드를 구분지을 명확한 경계 없음 : 2점슛, 스틸, 어시스트
  • 센터/슈팅가드를 구분지을 명확한 경계 있음 : 3점슛, 리바운드, 블로킹
    • 단, 스틸, 블로킹으로 조합했을때도 명확한 경계가 존재
    • 하지만, 슈팅가드/센터라는 포지션을 구분하기에
      적당한 속성이라 보기 힘듦
data = bbp.loc[:, ['3P', 'BLK']]
target = bbp.Pos.map({'C': 0, 'SG': 1})
X_train, X_test, y_train, y_test = \
    train_test_split(data, target, train_size=0.7,
                     stratify=target, random_state=2211211235)
scores = []
for n in range(1, 10+1):
    knn = KNeighborsClassifier(n_neighbors=n)
    score = cross_val_score(knn, X_train, y_train, cv=5, scoring='accuracy')
    scores.append(np.mean(score))
plt.plot(range(1, 10+1), scores, 'ro--')
plt.show() # => 적절한 k값 : 5
knn = KNeighborsClassifier(n_neighbors=5)
knn.fit(X_train.to_numpy(), y_train)
pred = knn.predict(X_test.to_numpy())

print(knn.score(X_train.to_numpy(), y_train), accuracy_score(y_test, pred))
player = [[2.5, 2], [3, 1]]
knn.predict(player)

 

728x90

'PYTHON > 데이터분석' 카테고리의 다른 글

[머신러닝] 14. 앙상블  (0) 2023.01.03
[머신러닝] 13. SVM  (0) 2023.01.03
[머신러닝] 11. 나이브베이즈 분석  (1) 2023.01.03
[머신러닝] 10. 엔트로피  (1) 2023.01.03
[머신러닝] 09. 의사결정 나무  (0) 2023.01.03
Comments