paddle几行代码轻松实现CNN卷积神经网络在MNIST上的图像分类
ztj100 2024-11-08 15:07 65 浏览 0 评论
前面几期的文章,我们介绍了paddle飞桨AI框架
PaddlePaddle飞桨深度学习实现手写数字识别任务——模型识别篇
本期我们重点介绍一下如何使用paddle来构建CNN卷积神经网络,并在MNIST数据集上面进行相关的数据加载与训练
----1----
MNIST的数据加载
前面已经提到,每一份MINIST数据都由图片以及标签组成。我们将图片命名为“x”,将标记数字的标签命名为“y”。训练数据集和测试数据集都是同样的结构,例如:训练的图片名为 mnist.train.images 而训练的标签名为 mnist.train.labels。
每一个图片均为28×28像素,我们可以将其理解为一个二维数组的结构:
在实际应用中,保存到本地的数据存储格式多种多样,如MNIST数据集以json格式存储在本地,其数据存储结构如下图 所示。
data包含三个元素的列表:train_set、val_set、 test_set,包括50 000条训练样本、10 000条验证样本、10 000条测试样本。每个样本包含手写的数字图片和对应的标签。
- train_set(训练集):用于确定模型参数。
- val_set(验证集):用于调节模型超参数(如多个网络结构、正则化权重的最优选择)。
- test_set(测试集):用于估计应用效果(没有在模型中应用过的数据,更贴近模型在真实场景应用的效果)。
train_set包含两个元素的列表:train_images、train_labels。
- train_images:[50 000, 784]的二维列表,包含50 000张图片。每张图片用一个长度为784的向量表示,内容是28*28尺寸的像素灰度值(黑白图片)。
- train_labels:[50 000, ]的列表,表示这些图片对应的分类标签,即0~9之间的一个数字。
import os
import random
import paddle
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import gzip
import json
from paddle.nn import Conv2D, MaxPool2D, Linear
import paddle.nn.functional as F
# 定义数据集读取器
def load_data(mode='train'):
# 加载数据
datafile = 'dataset/mnist.json.gz'
print('loading mnist dataset from {} ......'.format(datafile))
data = json.load(gzip.open(datafile))
print('mnist dataset load done')
# 读取到的数据区分训练集,验证集,测试集
train_set, val_set, eval_set = data
# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLS
IMG_ROWS = 28
IMG_COLS = 28
if mode == 'train':
# 获得训练数据集
imgs, labels = train_set[0], train_set[1]
elif mode == 'valid':
# 获得验证数据集
imgs, labels = val_set[0], val_set[1]
elif mode == 'eval':
# 获得测试数据集
imgs, labels = eval_set[0], eval_set[1]
else:
raise Exception("mode can only be one of ['train', 'valid', 'eval']")
#校验数据
imgs_length = len(imgs)
assert len(imgs) == len(labels), \
"length of train_imgs({}) should be the same as train_labels({})".format(
len(imgs), len(labels))
# 定义数据集每个数据的序号, 根据序号读取数据
index_list = list(range(imgs_length))
# 读入数据时用到的batchsize
BATCHSIZE = 100
# 定义数据生成器
def data_generator():
if mode == 'train':
random.shuffle(index_list)
imgs_list = []
labels_list = []
for i in index_list:
#神经网络重点修改代码
img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')
label = np.reshape(labels[i], [1]).astype('int64')
imgs_list.append(img)
labels_list.append(label)
if len(imgs_list) == BATCHSIZE:
yield np.array(imgs_list), np.array(labels_list)
imgs_list = []
labels_list = []
# 如果剩余数据的数目小于BATCHSIZE,
# 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch
if len(imgs_list) > 0:
yield np.array(imgs_list), np.array(labels_list)
return data_generator
与往期代码处理数据不同的是如下2行代码
img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32')
label = np.reshape(labels[i], [1]).astype('int64')
----2----
CNN卷积神经网络的搭建
# 卷积神经网络实现
class MNIST(paddle.nn.Layer):
def __init__(self):
super(MNIST, self).__init__()
# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2
self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
# 定义池化层,池化核的大小kernel_size为2,池化步长为2
self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2
self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)
# 定义池化层,池化核的大小kernel_size为2,池化步长为2
self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
# 定义一层全连接层,输出维度是10
self.fc = Linear(in_features=980, out_features=10)
# 定义网络前向计算过程,卷积后紧接着使用池化层,最后使用全连接层计算最终输出
# 卷积层激活函数使用Relu
def forward(self, inputs):
x = self.conv1(inputs)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.reshape(x, [x.shape[0], 980])
x = self.fc(x)
return x
神经网络的搭建如上面代码所示,最终CNN输出10个MNIST数据集的数字分类,后期我们讲根据此神经网络进行数字识别
----3----
CNN卷积神经网络的训练与模型保存
def train(model):
model.train()
#调用加载数据的函数,获得MNIST训练数据集
train_loader = load_data('train')
# 使用SGD优化器,learning_rate设置为0.01
opt = paddle.optimizer.SGD(learning_rate=0.01, parameters=model.parameters())
#opt = paddle.optimizer.Adam(learning_rate=0.01, parameters=model.parameters())
# 训练10轮
EPOCH_NUM = 10
# MNIST图像高和宽
IMG_ROWS, IMG_COLS = 28, 28
for epoch_id in range(EPOCH_NUM):
for batch_id, data in enumerate(train_loader()):
#准备数据
images, labels = data
images = paddle.to_tensor(images)
labels = paddle.to_tensor(labels)
#前向计算的过程
predicts = model(images)
#计算损失,使用交叉熵损失函数,取一个批次样本损失的平均值
loss = F.cross_entropy(predicts, labels)
avg_loss = paddle.mean(loss)
#每训练200批次的数据,打印下当前Loss的情况
if batch_id % 200 == 0:
print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))
#后向传播,更新参数的过程
avg_loss.backward()
# 最小化loss,更新参数
opt.step()
# 清除梯度
opt.clear_grad()
#保存模型参数
paddle.save(model.state_dict(), 'model/mnist_cnn_3.pdparams')
model = MNIST()
train(model)
这里跟前期代码重点关注的是如下代码,loss我们是使用交叉熵损失函数
#计算损失,使用交叉熵损失函数,取一个批次样本损失的平均值
loss = F.cross_entropy(predicts, labels)
loading mnist dataset from dataset/mnist.json.gz ......
mnist dataset load done
epoch: 0, batch: 0, loss is: [2.5198565]
epoch: 0, batch: 200, loss is: [0.28517285]
epoch: 0, batch: 400, loss is: [0.2283621]
epoch: 1, batch: 0, loss is: [0.22373457]
epoch: 1, batch: 200, loss is: [0.12276303]
epoch: 1, batch: 400, loss is: [0.16722086]
epoch: 2, batch: 0, loss is: [0.16171335]
epoch: 2, batch: 200, loss is: [0.07980514]
epoch: 2, batch: 400, loss is: [0.22929193]
epoch: 3, batch: 0, loss is: [0.11811857]
epoch: 3, batch: 200, loss is: [0.08179446]
epoch: 3, batch: 400, loss is: [0.18990536]
epoch: 4, batch: 0, loss is: [0.09782474]
epoch: 4, batch: 200, loss is: [0.1028099]
epoch: 4, batch: 400, loss is: [0.1387132]
epoch: 5, batch: 0, loss is: [0.23205265]
epoch: 5, batch: 200, loss is: [0.16299754]
epoch: 5, batch: 400, loss is: [0.14402886]
epoch: 6, batch: 0, loss is: [0.08868036]
epoch: 6, batch: 200, loss is: [0.0643991]
epoch: 6, batch: 400, loss is: [0.06269972]
epoch: 7, batch: 0, loss is: [0.09487689]
epoch: 7, batch: 200, loss is: [0.04431003]
epoch: 7, batch: 400, loss is: [0.07480995]
epoch: 8, batch: 0, loss is: [0.15115222]
epoch: 8, batch: 200, loss is: [0.07555249]
epoch: 8, batch: 400, loss is: [0.24242447]
epoch: 9, batch: 0, loss is: [0.03684139]
epoch: 9, batch: 200, loss is: [0.04600935]
epoch: 9, batch: 400, loss is: [0.22289713]
神经网络训练完成后,我们把训练的模型进行保存,以便进行数字识别,可以看到使用CNN卷积神经网络,loss已经降到了0.05以下,若多训练几次,此loss会更小
ok,有了此模型,我们便可以利用训练好的模型进行数字识别了
----4----
CNN卷积神经网络手写数字识别
首先,我们准备好需要识别的手写数字与上面训练好的模型,然后我们进行数字识别
第一步 搭建CNN卷积神经网络模型
# 导入图像读取第三方库
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import paddle
from paddle.nn import Linear
import paddle.nn.functional as F
from paddle.nn import Conv2D, MaxPool2D, Linear
# 定义mnist数据识别网络结构
class MNIST(paddle.nn.Layer):
def __init__(self):
super(MNIST, self).__init__()
# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2
self.conv1 = Conv2D(in_channels=1, out_channels=20, kernel_size=5, stride=1, padding=2)
# 定义池化层,池化核的大小kernel_size为2,池化步长为2
self.max_pool1 = MaxPool2D(kernel_size=2, stride=2)
# 定义卷积层,输出特征通道out_channels设置为20,卷积核的大小kernel_size为5,卷积步长stride=1,padding=2
self.conv2 = Conv2D(in_channels=20, out_channels=20, kernel_size=5, stride=1, padding=2)
# 定义池化层,池化核的大小kernel_size为2,池化步长为2
self.max_pool2 = MaxPool2D(kernel_size=2, stride=2)
# 定义一层全连接层,输出维度是1
self.fc = Linear(in_features=980, out_features=10)
# 定义网络前向计算过程,卷积后紧接着使用池化层,最后使用全连接层计算最终输出
# 卷积层激活函数使用Relu
def forward(self, inputs):
x = self.conv1(inputs)
x = F.relu(x)
x = self.max_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.max_pool2(x)
x = paddle.reshape(x, [x.shape[0], 980])
x = self.fc(x)
return x
搭建的神经网络模型跟训练的代码完全一样,直接复制上面的代码即可
第二步 加载图片
# 读取一张本地的样例图片,转变成模型输入的格式
def load_image(img_path):
# 从img_path中读取图像,并转为灰度图
im = Image.open(img_path).convert('L')
#plt.imshow(im,cmap='gray')
# print(np.array(im))
im = im.resize((28, 28), Image.ANTIALIAS)
im = np.array(im).reshape(1, 1, 28, 28).astype(np.float32)
# 图像归一化,保持和数据集的数据范围一致
im = 1 - im / 255
return im
加载图片后,我们需要对图片进行相关的预处理操作,包括转换到灰度空间,缩放到28*28尺寸,并进行数据的归一化
第三步 加载模型进行预测
# 定义预测过程
model = MNIST()
params_file_path = 'model/mnist_cnn_3.pdparams'
img_path = 'image/example_6.jpg'
# 加载模型参数
param_dict = paddle.load(params_file_path)
model.load_dict(param_dict)
# 灌入数据
model.eval()
tensor_img = load_image(img_path)
#模型反馈10个分类标签的对应概率
result = model(paddle.to_tensor(tensor_img))
print('result',result)
#取概率最大的标签作为预测输出
lab = np.argsort(result.numpy())
print(lab)
print("本次预测的数字是: ", lab[0][-1])
我们加载上面代码训练好的模型,并使用paddle.load进行模型的加载,然后使用model(paddle.to_tensor(tensor_img))函数进行数字的预测,最后打印输出预测的数字
result Tensor(shape=[1, 10], dtype=float32, place=CPUPlace, stop_gradient=False,
[[ 0.68616289, -2.16074777, -0.68914777, -4.40827608, -0.68840426,
1.91024637, 8.90141392, -4.07809448, 1.31746018, -2.95059443]])
[[3 7 9 1 2 4 0 8 5 6]]
本次预测的数字是: 6
结论
相比前期的模型训练与预测,CNN卷积神经网络不仅loss会加速减少,且利用训练好的模型可以精确的识别出训练的数字,本教程并没有完全使用paddle的高级API进行CNN的搭建与预测,其代码相对多一些,有关paddle的高级API,我们后期进行相关技术的分享
相关推荐
- Jquery 详细用法
-
1、jQuery介绍(1)jQuery是什么?是一个js框架,其主要思想是利用jQuery提供的选择器查找要操作的节点,然后将找到的节点封装成一个jQuery对象。封装成jQuery对象的目的有...
- 前端开发79条知识点汇总
-
1.css禁用鼠标事件2.get/post的理解和他们之间的区别http超文本传输协议(HTTP)的设计目的是保证客户机与服务器之间的通信。HTTP的工作方式是客户机与服务器之间的请求-应答协议。...
- js基础面试题92-130道题目
-
92.说说你对作用域链的理解参考答案:作用域链的作用是保证执行环境里有权访问的变量和函数是有序的,作用域链的变量只能向上访问,变量访问到window对象即被终止,作用域链向下访问变量是不被允许的。...
- Web前端必备基础知识点,百万网友:牛逼
-
1、Web中的常见攻击方式1.SQL注入------常见的安全性问题。解决方案:前端页面需要校验用户的输入数据(限制用户输入的类型、范围、格式、长度),不能只靠后端去校验用户数据。一来可以提高后端处理...
- 事件——《JS高级程序设计》
-
一、事件流1.事件流描述的是从页面中接收事件的顺序2.事件冒泡(eventbubble):事件从开始时由最具体的元素(就是嵌套最深的那个节点)开始,逐级向上传播到较为不具体的节点(就是Docu...
- 前端开发中79条不可忽视的知识点汇总
-
过往一些不足的地方,通过博客,好好总结一下。1.css禁用鼠标事件...
- Chrome 开发工具之Network
-
经常会听到比如"为什么我的js代码没执行啊?","我明明发送了请求,为什么反应?","我这个网站怎么加载的这么慢?"这类的问题,那么问题既然存在,就需要去解决它,需要解决它,首先我们得找对导致问题的原...
- 轻量级 React.js 虚拟美化滚动条组件RScroll
-
前几天有给大家分享一个Vue自定义滚动条组件VScroll。今天再分享一个最新开发的ReactPC端模拟滚动条组件RScroll。...
- 一文解读JavaScript事件对象和表单对象
-
前言相信做网站对JavaScript再熟悉不过了,它是一门脚本语言,不同于Python的是,它是一门浏览器脚本语言,而Python则是服务器脚本语言,我们不光要会Python,还要会JavaScrip...
- Python函数参数黑科技:*args与**kwargs深度解析
-
90%的Python程序员不知道,可变参数设计竟能决定函数的灵活性和扩展性!掌握这些技巧,让你的函数适应任何场景!一、函数参数设计的三大进阶技巧...
- 深入理解Python3密码学:详解PyCrypto库加密、解密与数字签名
-
在现代计算领域,信息安全逐渐成为焦点话题。密码学,作为信息保护的关键技术之一,允许我们加密(保密)和解密(解密)数据。...
- 阿里Nacos惊爆安全漏洞,火速升级!(附修复建议)
-
前言好,我是threedr3am,我发现nacos最新版本1.4.1对于User-Agent绕过安全漏洞的serverIdentitykey-value修复机制,依然存在绕过问题,在nacos开启了...
- Python模块:zoneinfo时区支持详解
-
一、知识导图二、知识讲解(一)zoneinfo模块概述...
- Golang开发的一些注意事项(一)
-
1.channel关闭后读的问题当channel关闭之后再去读取它,虽然不会引发panic,但会直接得到零值,而且ok的值为false。packagemainimport"...
- Python鼠标与键盘自动化指南:从入门到进阶——键盘篇
-
`pynput`是一个用于控制和监控鼠标和键盘的Python库...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)