# Load the handwritten-digit recognition data.
# Source: sklearn.datasets.load_digits (the result is named `mnist` here,
# although it is scikit-learn's bundled digits set).
from sklearn import datasets
mnist = datasets.load_digits()
# Notebook-style peeks at what the loaded bunch contains.
mnist.keys()
mnist.target
mnist.target.shape
# Images stay in their original layout; labels become a single column.
data, target = mnist.images, mnist.target.reshape(-1, 1)
data.shape, target.shape
%matplotlib inline
index = 3
import matplotlib.pyplot as plt
plt.figure(figsize=(1,1))
plt.imshow(data[index], cmap='gray_r')
plt.axis('off')
plt.show()
# Inspect one sample of X together with its label y.
index = 0
first_image = data[index]
print(first_image)
print('-' * 35)
print(mnist.target[index])
print('-reshape-')
# Same pixels flattened to 1-D; .ravel() or .flatten() are alternatives.
print(first_image.reshape(-1))
# Classification: the digits are split into ten classes [0..9].
from sklearn.model_selection import train_test_split
# One row per image with all of its pixels flattened into a feature vector.
# The sample count comes from the data itself instead of a hard-coded
# 1797 x 64, so the reshape keeps working if the dataset size changes.
X = data.reshape(len(data), -1)
# Collapse the column-shaped labels back into a 1-D vector.
y = target.flatten()
# Hold out 20% of the samples for testing; the fixed seed keeps the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Hyper-parameters are set by hand here; presumably the original note means
# they could be tuned with GridSearchCV instead.
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.ensemble import ExtraTreesClassifier
# Alternatives tried earlier; the trailing numbers look like the scores they reached.
# clf = LogisticRegression(random_state=0, solver='newton-cg', multi_class='multinomial') # 0.972
# clf = MLPClassifier() # 0.975
# Seed the forest so repeated runs give the same model and scores, matching
# the fixed seeds used for the split and the alternative above.
# An unseeded run was noted at 0.9805; the seeded score may differ slightly.
clf = ExtraTreesClassifier(n_estimators=2000, random_state=0)
clf.fit(X_train, y_train)
# Scoring on the training data and on the held-out test data.
print('X_train ', clf.score(X_train , y_train) )
print('X_test ', clf.score(X_test , y_test) )
# Find the test samples that were not predicted correctly.
import numpy as np
import matplotlib.pyplot as plt
y_pred = clf.predict(X_test)
# Boolean mask that is True wherever the prediction differs from the label.
diff = (y_pred != y_test)
print('Ground Truth', y_test[diff])
print('Predictive ', y_pred[diff])
# Draw every misclassified digit with a "true-->predicted" title.
# np.flatnonzero yields the positions of the True entries of the mask.
for i in np.flatnonzero(diff):
    plt.figure(figsize=(1, 1))
    plt.axis('off')
    # Each feature row is reshaped back into an 8x8 image for display.
    plt.imshow(X_test[i].reshape(8, 8), cmap='gray')
    plt.title(f'{y_test[i]}-->{y_pred[i]}')
    plt.show()
X_test.shape