使用KNN方法对手写MNIST数据集分类

python sklearn

Posted by MetaNetworks on September 8, 2020
本页面总访问量

sklearn中使用KNN方法对手写MNIST数据集分类

sklearn模块图

KNN为相近邻接点算法,时间、空间复杂度较高

此处为练习

1. 导入相关库

1
2
3
4
5
from sklearn.datasets import fetch_mldata
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

2. 加载数据集

此处使用MNIST original数据集。数据集中,data为(70000,784)的二值化数据,target为(70000,1)的向量。

数据集特征有:

  • 有70000张数据
  • 图片格式为28*28的数据
\[28 * 28 = 784\]
  • 目标为0-9的整数
1
2
3
4
# 导入手写数据集
data = fetch_mldata("MNIST original",data_home="data")
print(data["DESCR"])
x,y = data['data'],data['target']

可以使用matplotlib可视化

1
2
3
4
index = 12000
single_img = x[index].reshape(28,28)
plt.imshow(single_img,cmap=matplotlib.cm.binary) # cmap使用二值化
print("label is:",y[index])

可视化

3. 分割训练/测试集

MNIST original本身就分割了训练集,测试集。前60000为训练集,后10000为测试集

1
x_train,x_test,y_train,y_test = x[:60000],x[60000:],y[:60000],y[60000:]

4.数据集增强

为了保证训练效果,此处进行顺序打乱

1
2
3
shuffle_index = np.random.permutation(60000)
print(shuffle_index)
x_train,y_train = x_train[shuffle_index],y_train[shuffle_index]

5.训练KNN模型

1
2
3
4
# 邻居 k =5
k = 5
k_classifier = KNeighborsClassifier(n_neighbors=k)
k_classifier.fit(X,y)

6.测试结果

1
2
3
# KNN预测
k_classifier.predict([x_test[1000]])
# 预测结果:array([1.]),表明预测为1