Pytorch、Tensorflow、Keras 框架下实现KNN算法(MNIST数据集)附详解代码_knn pytorch-程序员宅基地

技术标签: tensorflow  python  近邻算法  机器学习算法  pytorch  keras  

Pytorch、Tensorflow、Keras框架下实现KNN算法(MNIST数据集)附详解代码

K最近邻法(KNN)是最常见的监督分类算法,其中根据K值的不同取值,模型会有不一样的效率,但是并不是k值越大或者越小,模型效率越高,而是根据数据集的不同,使用交叉验证,得出最优K值。

Python—KNN分类算法(详解)
欧式距离的快捷计算方法

基于Pytorch实现KNN算法:

#******************************************************************
#从torchvision中引入常用数据集(MNIST),以及常用的预处理操作(transfrom)
from torchvision import datasets, transforms
#引入numpy计算矩阵
import numpy as np
#引入模型评估指标 accuracy_score
from sklearn.metrics import accuracy_score
import torch
#引入进度条设置以及时间设置
from tqdm import tqdm
import time

# 定义KNN函数
def KNN(train_x, train_y, test_x, test_y, k):
    #获取当前时间
    since = time.time()
    #可以将m,n理解为求其数据个数,属于torch.tensor类
    m = test_x.size(0)
    n = train_x.size(0)

    # 计算欧几里得距离矩阵,矩阵维度为m*n;
    print("计算距离矩阵")

    #test,train本身维度是m*1, **2为对每个元素平方,sum(dim=1,对行求和;keepdim =True时保持二维,
    # 而False对应一维,expand是改变维度,使其满足 m * n)
    xx = (test_x ** 2).sum(dim=1, keepdim=True).expand(m, n)
    #最后增添了转置操作
    yy = (train_x ** 2).sum(dim=1, keepdim=True).expand(n, m).transpose(0, 1)
    #计算近邻距离公式
    dist_mat = xx + yy - 2 * test_x.matmul(train_x.transpose(0, 1))
    #对距离进行排序
    mink_idxs = dist_mat.argsort(dim=-1)
    #定义一个空列表
    res = []
    for idxs in mink_idxs:
        # voting
        #代码下方会附上解释np.bincount()函数的博客
        res.append(np.bincount(np.array([train_y[idx] for idx in idxs[:k]])).argmax())

    assert len(res) == len(test_y)
    print("acc", accuracy_score(test_y, res))
    #计算运行时长
    time_elapsed = time.time() - since
    print('KNN mat training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

#欧几里得距离计算公式
def cal_distance(x, y):
    return torch.sum((x - y) ** 2) ** 0.5
# KNN的迭代函数
def KNN_by_iter(train_x, train_y, test_x, test_y, k):
    since = time.time()

    # 计算距离
    res = []
    for x in tqdm(test_x):
        dists = []
        for y in train_x:
            dists.append(cal_distance(x, y).view(1))
        #torch.cat()用来拼接tensor
        idxs = torch.cat(dists).argsort()[:k]
        res.append(np.bincount(np.array([train_y[idx] for idx in idxs])).argmax())

    # print(res[:10])
    print("acc", accuracy_score(test_y, res))

    time_elapsed = time.time() - since
    print('KNN iter training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))


if __name__ == "__main__":
    #加载数据集(下载数据集)
    train_dataset = datasets.MNIST(root="./data", download= True, transform=transforms.ToTensor(), train=True)
    test_dataset = datasets.MNIST(root="./data", download= True, transform=transforms.ToTensor(), train=False)

    # 组织训练,测试数据
    train_x = []
    train_y = []
    for i in range(len(train_dataset)):
        img, target = train_dataset[i]
        train_x.append(img.view(-1))
        train_y.append(target)

        if i > 5000:
            break

    # print(set(train_y))

    test_x = []
    test_y = []
    for i in range(len(test_dataset)):
        img, target = test_dataset[i]
        test_x.append(img.view(-1))
        test_y.append(target)

        if i > 200:
            break

    print("classes:", set(train_y))

    KNN(torch.stack(train_x), train_y, torch.stack(test_x), test_y, 7)
    KNN_by_iter(torch.stack(train_x), train_y, torch.stack(test_x), test_y, 7)

运行结果:

classes: {
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
计算距离矩阵
acc 0.9405940594059405
KNN mat training complete in 0m 0s
100%|██████████| 202/202 [00:26<00:00,  7.61it/s]
acc 0.9405940594059405
KNN iter training complete in 0m 27s

Process finished with exit code 0

参考博客:
numpy.bincount详解
Pytorch中torch.cat与torch.stack有什么区别?

基于Tensorflow实现KNN算法

#__author__ = 'HelloWorld怎么写'
#******************************************************************
#导入相关包,相关API有的只适合TF1
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

#加载MNIST数据集
from tensorflow.examples.tutorials.mnist import input_data
def loadMNIST():
    #获取数据,采用ONE_HOT形式
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    return mnist
#定义KNN算法
def KNN(mnist):
    #训练集取前10000,测试集取500
    train_x, train_y = mnist.train.next_batch(10000)
    test_x, test_y = mnist.train.next_batch(500)
    #计算图输入占位符,[784表示列数],结果返回tensor类型
    xtr = tf.placeholder(tf.float32, [None, 784])
    xte = tf.placeholder(tf.float32, [784])
    #计算欧几里得距离;tf.negative(x)返回一个张量;tf.add()实现列表元素求和;
    # tf.reduce_sum(a,reduction_indices:axis),a为要减少的张量,axis的废弃名称
    distance = tf.sqrt(tf.reduce_sum(tf.pow(tf.add(xtr, tf.negative(xte)), 2), reduction_indices=1))
    #返回纵列的最小值
    pred = tf.argmin(distance, 0)
    #变量初始化
    init = tf.initialize_all_variables()

    sess = tf.Session()
    sess.run(init)
    #求模型准确率
    right = 0
    for i in range(500):
        ansIndex = sess.run(pred, {
    xtr: train_x, xte: test_x[i, :]})
        print('prediction is ', np.argmax(train_y[ansIndex]))
        print('true value is ', np.argmax(test_y[i]))
        if np.argmax(test_y[i]) == np.argmax(train_y[ansIndex]):
            right += 1.0
    accracy = right / 500.0
    print(accracy)


if __name__ == "__main__":
    #实例化函数
    mnist = loadMNIST()
    KNN(mnist)

运行结果:

...
prediction is  7
true value is  7
prediction is  0
true value is  0
0.942
Process finished with exit code 0

参考博客:
Tensorflow 利用最近邻算法实现Mnist的识别
基于TensorFlow的K近邻(KNN)分类器实现——以MNIST为例
tensorflow实现KNN识别MNIST

基于Keras实现KNN算法

#__author__ = 'HelloWorld怎么写'
#******************************************************************
#导入相关包
from keras.models import Sequential
from keras.layers import Dense, Dropout
from keras.utils import np_utils
from keras.datasets import mnist
import os
#使用GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#加载MNIST数据
def load_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    #选取训练集前10000张
    number = 10000
    x_train = x_train[0:number]
    y_train = y_train[0:number]
    #进行预处理
    x_train = x_train.reshape(number, 28 * 28)
    x_test = x_test.reshape(x_test.shape[0], 28 * 28)
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')
    #np_utils.to_categorical()函数将y_train转变成ONE-HOT形式
    y_train = np_utils.to_categorical(y_train, 10)
    y_test = np_utils.to_categorical(y_test, 10)
    #进行标准化,x_train属于0-255,除以255,变成0-1的值
    x_train = x_train / 255
    x_test = x_test / 255

    return (x_train, y_train), (x_test, y_test)


(x_train, y_train), (x_test, y_test) = load_data()

#Keras序贯模型
model = Sequential()
#输入数据,定义数据尺寸,units 为输出空间维度;激活函数
model.add(Dense(input_dim=28 * 28, units=689, activation='relu'))
#dropout层
model.add(Dropout(0.2))
model.add(Dense(units=689, activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(units=689, activation='relu'))
model.add(Dropout(0.2))
#输出层
model.add(Dense(output_dim=10, activation='softmax'))
#配置训练方法,损失函数、优化器、评估指标
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
#训练模型
model.fit(x_train, y_train, batch_size=10000, epochs=20)
#评估指标
res1 = model.evaluate(x_train, y_train, batch_size=10000)
print("\n Train Acc :", res1[1])
res2 = model.evaluate(x_test, y_test, batch_size=10000)
print("\n Test Acc :", res2[1])

运行结果:

...
10000/10000 [==============================] - 0s 17us/step - loss: 0.2658 - accuracy: 0.9210

10000/10000 [==============================] - 0s 12us/step

 Train Acc : 0.940500020980835

10000/10000 [==============================] - 0s 7us/step

 Test Acc : 0.9265000224113464

Process finished with exit code 0

参考博客:
Keras MNIST 手写数字识别数据集
Keras入门级MNIST手写数字识别超级详细教程

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_46236212/article/details/122219292

智能推荐

c# 调用c++ lib静态库_c#调用lib-程序员宅基地

文章浏览阅读2w次,点赞7次,收藏51次。四个步骤1.创建C++ Win32项目动态库dll 2.在Win32项目动态库中添加 外部依赖项 lib头文件和lib库3.导出C接口4.c#调用c++动态库开始你的表演...①创建一个空白的解决方案,在解决方案中添加 Visual C++ , Win32 项目空白解决方案的创建:添加Visual C++ , Win32 项目这......_c#调用lib

deepin/ubuntu安装苹方字体-程序员宅基地

文章浏览阅读4.6k次。苹方字体是苹果系统上的黑体,挺好看的。注重颜值的网站都会使用,例如知乎:font-family: -apple-system, BlinkMacSystemFont, Helvetica Neue, PingFang SC, Microsoft YaHei, Source Han Sans SC, Noto Sans CJK SC, W..._ubuntu pingfang

html表单常见操作汇总_html表单的处理程序有那些-程序员宅基地

文章浏览阅读159次。表单表单概述表单标签表单域按钮控件demo表单标签表单标签基本语法结构<form action="处理数据程序的url地址“ method=”get|post“ name="表单名称”></form><!--action,当提交表单时,向何处发送表单中的数据,地址可以是相对地址也可以是绝对地址--><!--method将表单中的数据传送给服务器处理,get方式直接显示在url地址中,数据可以被缓存,且长度有限制;而post方式数据隐藏传输,_html表单的处理程序有那些

PHP设置谷歌验证器(Google Authenticator)实现操作二步验证_php otp 验证器-程序员宅基地

文章浏览阅读1.2k次。使用说明:开启Google的登陆二步验证(即Google Authenticator服务)后用户登陆时需要输入额外由手机客户端生成的一次性密码。实现Google Authenticator功能需要服务器端和客户端的支持。服务器端负责密钥的生成、验证一次性密码是否正确。客户端记录密钥后生成一次性密码。下载谷歌验证类库文件放到项目合适位置(我这边放在项目Vender下面)https://github.com/PHPGangsta/GoogleAuthenticatorPHP代码示例://引入谷_php otp 验证器

【Python】matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距-程序员宅基地

文章浏览阅读4.3k次,点赞5次,收藏11次。matplotlib.plot画图横坐标混乱及间隔处理_matplotlib更改横轴间距

docker — 容器存储_docker 保存容器-程序员宅基地

文章浏览阅读2.2k次。①Storage driver 处理各镜像层及容器层的处理细节,实现了多层数据的堆叠,为用户 提供了多层数据合并后的统一视图②所有 Storage driver 都使用可堆叠图像层和写时复制(CoW)策略③docker info 命令可查看当系统上的 storage driver主要用于测试目的,不建议用于生成环境。_docker 保存容器

随便推点

网络拓扑结构_网络拓扑csdn-程序员宅基地

文章浏览阅读834次,点赞27次,收藏13次。网络拓扑结构是指计算机网络中各组件(如计算机、服务器、打印机、路由器、交换机等设备)及其连接线路在物理布局或逻辑构型上的排列形式。这种布局不仅描述了设备间的实际物理连接方式,也决定了数据在网络中流动的路径和方式。不同的网络拓扑结构影响着网络的性能、可靠性、可扩展性及管理维护的难易程度。_网络拓扑csdn

JS重写Date函数,兼容IOS系统_date.prototype 将所有 ios-程序员宅基地

文章浏览阅读1.8k次,点赞5次,收藏8次。IOS系统Date的坑要创建一个指定时间的new Date对象时,通常的做法是:new Date("2020-09-21 11:11:00")这行代码在 PC 端和安卓端都是正常的,而在 iOS 端则会提示 Invalid Date 无效日期。在IOS年月日中间的横岗许换成斜杠,也就是new Date("2020/09/21 11:11:00")通常为了兼容IOS的这个坑,需要做一些额外的特殊处理,笔者在开发的时候经常会忘了兼容IOS系统。所以就想试着重写Date函数,一劳永逸,避免每次ne_date.prototype 将所有 ios

如何将EXCEL表导入plsql数据库中-程序员宅基地

文章浏览阅读5.3k次。方法一:用PLSQL Developer工具。 1 在PLSQL Developer的sql window里输入select * from test for update; 2 按F8执行 3 打开锁, 再按一下加号. 鼠标点到第一列的列头,使全列成选中状态,然后粘贴,最后commit提交即可。(前提..._excel导入pl/sql

Git常用命令速查手册-程序员宅基地

文章浏览阅读83次。Git常用命令速查手册1、初始化仓库git init2、将文件添加到仓库git add 文件名 # 将工作区的某个文件添加到暂存区 git add -u # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,不处理untracked的文件git add -A # 添加所有被tracked文件中被修改或删除的文件信息到暂存区,包括untracked的文件...

分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120-程序员宅基地

文章浏览阅读202次。分享119个ASP.NET源码总有一个是你想要的_千博二手车源码v2023 build 1120

【C++缺省函数】 空类默认产生的6个类成员函数_空类默认产生哪些类成员函数-程序员宅基地

文章浏览阅读1.8k次。版权声明:转载请注明出处 http://blog.csdn.net/irean_lau。目录(?)[+]1、缺省构造函数。2、缺省拷贝构造函数。3、 缺省析构函数。4、缺省赋值运算符。5、缺省取址运算符。6、 缺省取址运算符 const。[cpp] view plain copy_空类默认产生哪些类成员函数

推荐文章

热门文章

相关标签