본문 바로가기

머신러닝

붓꽃 품종 예측하기

사이킷런을 이용하여 붓꽃(Iris) 데이터 품종 예측하기

# 사이킷런 버전 확인
import sklearn
print(sklearn.__version__)

 

** 붓꽃 예측을 위한 사이킷런 필요 모듈 로딩 **

from sklearn.datasets import load_iris  # iris 데이터 로드
from sklearn.tree import DecisionTreeClassifier  # 의사결정나무 분류기
from sklearn.model_selection import train_test_split  # 학습,테스트 데이터 분리

 

데이터 세트를 로딩

import pandas as pd

# 붓꽃 데이터 세트를 로딩합니다. 
iris = load_iris()
iris
# iris.data는 Iris 데이터 세트에서 피처(feature)만으로 된 데이터를 numpy로 가지고 있습니다. 
iris_data = iris.data
iris_data
# iris.target은 붓꽃 데이터 세트에서 레이블(결정 값) 데이터를 numpy로 가지고 있습니다. 
iris_label = iris.target
iris_label
print('iris target명:', iris.target_names)
# 붓꽃 데이터 세트를 자세히 보기 위해 DataFrame으로 변환합니다. 
iris_df = pd.DataFrame(data=iris_data, columns=iris.feature_names)
iris_df['label'] = iris.target

print(iris_df.shape)
iris_df.head()

 

학습 데이터와 테스트 데이터 세트로 분리

X_train, X_test, y_train, y_test = train_test_split(iris_data, iris_label, 
                                                    test_size=0.2, random_state=11)
# 학습 데이터 세트
print(X_train.shape)
print(y_train.shape)

# 테스트 데이터 세트
print(X_test.shape)
print(y_test.shape)

 

"학습 데이터" 세트로 학습(Train) 수행

# DecisionTreeClassifier 객체 생성 
dt_clf = DecisionTreeClassifier(random_state=11)

# 학습 수행 
dt_clf.fit(X_train, y_train)

 

"테스트 데이터" 세트로 예측(Predict) 수행

# 학습이 완료된 DecisionTreeClassifier 객체에서 테스트 데이터 세트로 예측 수행. 
pred = dt_clf.predict(X_test)
print(len(pred))
pred
iris.target_names

예측 정확도 평가

from sklearn.metrics import accuracy_score
print('예측 정확도: {0:.4f}'.format(accuracy_score(y_test, pred)))

 

 

 


사이킷런 내장 데이터인 iris_data 구조를 확인해보자

from sklearn.datasets import load_iris

iris_data = load_iris()

print(type(iris_data))
iris_data
# 붓꽃 데이터 세트의 키들
iris_data.keys()

 

키는 보통 data, target, target_name, feature_names, DESCR로 구성돼 있습니다. 개별 키가 가리키는 의미는 다음과 같습니다.

  • data는 피처의 데이터 세트를 가리킵니다.
  • target은 분류 시 레이블 값, 회귀일 때는 숫자 결괏값 데이터 세트입니다..
  • target_names는 개별 레이블의 이름을 나타냅니다.
  • feature_names는 피처의 이름을 나타냅니다.
  • DESCR은 데이터 세트에 대한 설명과 각 피처의 설명을 나타냅니다.
# iris_data.feature_names 확인
print(type(iris_data.feature_names))
print(len(iris_data.feature_names))
print(iris_data.feature_names)
# iris_data.target_names 확인
print(type(iris_data.target_names))
print(iris_data.target_names.shape)
print(iris_data.target_names)
# iris_data.data 확인
print(type(iris_data.data))
print(iris_data.data.shape)
print(iris_data['data'])
# iris_data.target 확인
print(type(iris_data.target))
print(iris_data.target.shape)
print(iris_data.target)