sklearn中使用KNN方法对手写MNIST数据集分类
KNN为相近邻接点算法,时间、空间复杂度较高
此处为练习
1. 导入相关库
1
2
3
4
5
from sklearn.datasets import fetch_mldata
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
2. 加载数据集
此处使用MNIST original数据集。数据集中,data为(70000,784)的二值化数据,target为(70000,1)的向量。
数据集特征有:
- 有70000张数据
- 图片格式为28*28的数据
- 目标为0-9的整数
1
2
3
4
# 导入手写数据集
data = fetch_mldata("MNIST original",data_home="data")
print(data["DESCR"])
x,y = data['data'],data['target']
可以使用
matplotlib
可视化
1 2 3 4 index = 12000 single_img = x[index].reshape(28,28) plt.imshow(single_img,cmap=matplotlib.cm.binary) # cmap使用二值化 print("label is:",y[index])
3. 分割训练/测试集
MNIST original本身就分割了训练集,测试集。前60000为训练集,后10000为测试集
1
x_train,x_test,y_train,y_test = x[:60000],x[60000:],y[:60000],y[60000:]
4.数据集增强
为了保证训练效果,此处进行顺序打乱
1
2
3
shuffle_index = np.random.permutation(60000)
print(shuffle_index)
x_train,y_train = x_train[shuffle_index],y_train[shuffle_index]
5.训练KNN模型
1
2
3
4
# 邻居 k =5
k = 5
k_classifier = KNeighborsClassifier(n_neighbors=k)
k_classifier.fit(X,y)
6.测试结果
1
2
3
# KNN预测
k_classifier.predict([x_test[1000]])
# 预测结果:array([1.]),表明预测为1