人工智能Keras CNN卷积神经网络的图片识别模型训练
ztj100 2024-12-28 16:50 37 浏览 0 评论
CNN卷积神经网络是人工智能的开端,CNN卷积神经网络让计算机能够认识图片,文字,甚至音频与视频。CNN卷积神经网络的基础知识,可以参考:CNN卷积神经网络
LetNet体系结构是卷积神经网络的“第一个图像分类器”。最初设计用于对手写数字进行分类,上期文章我们分享了如何使用keras来进行手写数字的神经网络搭建:Keras人工智能神经网络 Classifier 分类 神经网络搭建
我们也可以轻松地将其扩展到其他类型的图像上,本期使用小雪人的照片,来让神经网络识别雪人
雪人的图片大家可以到网络上自行下载,当然也可以使用爬虫技术来下载
搭建keras神经网络识别图片
from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dense
from keras import backend as K
首先导入需要的模块,建立一个神经网络以便后期使用,在一个单独的文件中,命名此神经网络类(lenet.py)
class LeNet:
@staticmethod
def build(width, height, depth, classes):
# 使用Sequential()初始化model
model = Sequential()
inputShape = (height, width, depth)
#tensorflow默认设置
#宽度 :输入图像的宽度
#高度 :输入图像的高度
#深度 :输入图像中的频道数(1个 对于灰度单通道图像, 3 标准RGB图像)
# 若是其他的(Theano),则使用((depth, height, width)
if K.image_data_format() == "channels_first":
inputShape = (depth, height, width)
#建立卷积神经网络 =>然后是 RELU => 然后是max pooling(跟前期分享的tensorflow教程类似)
model.add(Conv2D(20, (5, 5), padding="same",input_shape=inputShape))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
# 建立卷积神经网络 =>然后是 RELU => 然后是max pooling(第二层)
model.add(Conv2D(50, (5, 5), padding="same"))
model.add(Activation("relu"))
model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
# 增加全连接层
model.add(Flatten())
model.add(Dense(500))
model.add(Activation("relu"))
# softmax classifier 来进行神经网络的分类
model.add(Dense(classes))
model.add(Activation("softmax"))
# return the model
return model
训练keras神经网络
以上建立了keras 的神经网络模型,我们就使用预先下载好的图片来训练神经模型
建立一个train.py文件,插入如下代码,来训练神经网络模型(图片数据里面分成如下2类)
- snowman #我们训练的图片
- notsnowman 增加非雪人图片的训练
import matplotlib
matplotlib.use("Agg")
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import img_to_array
from keras.utils import to_categorical
from lenet import LeNet
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import random
import cv2
import os
初始化参数
EPOCHS = 25 #学习的步数
INIT_LR = 1e-3# 学习效率
BS = 32# 每步学习的个数
data = []# 存放图片数据
labels = []# 存放图片标签
imagePaths = sorted(list(paths.list_images("dataset\\")))# 遍历所有的图片
random.seed(42)
random.shuffle(imagePaths) # 打乱图片顺序
初始化参数完成后,需要把所有的图片加载,进行图片数据的整理
for imagePath in imagePaths:
# 加载图片
image = cv2.imread(imagePath)
image = cv2.resize(image, (28, 28)) # resize 到28*28 LeNet所需的空间尺寸
image = img_to_array(image) # 图片转换成array
data.append(image) # 保存图片数据
label = imagePath.split(os.path.sep)[-2] #获取图片标签
label = 1 if label == "snowman" else 0
labels.append(label) # 获取图片标签
预先处理图片
# 把图片数据变成【0.1】
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)
# 设置测试数据与训练数据
#使用75%的数据将数据划分为训练和测试
#用于训练的数据,其余25%用于测试
(trainX, testX, trainY, testY) = train_test_split(data, labels, test_size=0.25, random_state=42)
# 标签转换成向量
trainY = to_categorical(trainY, num_classes=2)
testY = to_categorical(testY, num_classes=2)
# 创建一个图像生成器对象,该对象在图像数据集上执行随机旋转,平移,翻转,修剪和剪切。
#这使我们可以使用较小的数据集,但仍然可以获得较高的结果
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
horizontal_flip=True, fill_mode="nearest")
建立神经网络,进行神经网络训练
#建立model
model = LeNet.build(width=28, height=28, depth=3, classes=2)
opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])
# 训练神经网络
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
epochs=EPOCHS, verbose=1)
神经网络训练完成后,对神经网络训练的结果进行保存,以便后期使用预训练模型进行图片识别
保存模型,显示训练结果
model.save("lenet.model") # 保存模型
# 显示结果
plt.style.use("ggplot")
plt.figure()
N = EPOCHS
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on snowman/Notsnowman")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig("plot1.JPG")
从训练结果可以看出,loss越来越小,精度越来越高,表明我们的神经网络模型是完全ok。
若想得到更好的训练数据,当然是使用大量的数据进行训练
以上便是我们训练的神经网络模型,下期我们使用预训练模型,对图片进行识别
相关推荐
- Sublime Text 4 稳定版 Build 4113 发布
-
IT之家7月18日消息知名编辑器SublimeText4近日发布了Build4113版本,是SublimeText4的第二个稳定版。IT之家了解到,SublimeTe...
- 【小白课程】openKylin便签贴的设计与实现
-
openKylin便签贴作为侧边栏的一个小插件,提供便捷的文本记录和灵活的页面展示。openKylin便签贴分为两个部分:便签列表...
- 壹啦罐罐 Android 手机里的 Xposed 都装了啥
-
这是少数派推出的系列专题,叫做「我的手机里都装了啥」。这个系列将邀请到不同的玩家,从他们各自的角度介绍手机中最爱的或是日常使用最频繁的App。文章将以「每周一篇」的频率更新,内容范围会包括iOS、...
- 电气自动化专业词汇中英文对照表(电气自动化专业英语单词)
-
专业词汇中英文对照表...
- Python界面设计Tkinter模块的核心组件
-
我们使用一个模块,我们要熟悉这个模块的主要元件。如我们设计一个窗口,我们可以用Tk()来完成创建;一些交互元素,按钮、标签、编辑框用到控件;怎么去布局你的界面,我们可以用到pack()、grid()...
- 以色列发现“死海古卷”新残片(死海古卷是真的吗)
-
编译|陈家琦据艺术新闻网(artnews.com)报道,3月16日,以色列考古学家发现了死海古卷(DeadSeaScrolls)新残片。新出土的羊皮纸残片中包括以希腊文书写的《十二先知书》段落,这...
- 鸿蒙Next仓颉语言开发实战教程:订单列表
-
大家上午好,最近不断有友友反馈仓颉语言和ArkTs很像,所以要注意不要混淆。今天要分享的是仓颉语言开发商城应用的订单列表页。首先来分析一下这个页面,它分为三大部分,分别是导航栏、订单类型和订单列表部分...
- 哪些模块可以用在 Xposed for Lollipop 上?Xposed 模块兼容性解答
-
虽然已经有了XposedforLollipop的安装教程,但由于其还处在alpha阶段,一些Xposed模块能不能依赖其正常工作还未可知。为了解决大家对于模块兼容性的疑惑,笔者尽可能多...
- 利用 Fluid 自制 Mac 版 Overcast 应用
-
我喜爱收听播客,健身、上/下班途中,工作中,甚至是忙着做家务时。大多数情况下我会用MarcoArment开发的Overcast(Freemium)在iPhone上收听,这是我目前最喜爱的Po...
- 浅色Al云食堂APP代码(三)(手机云食堂)
-
以下是进一步优化完善后的浅色AI云食堂APP完整代码,新增了数据可视化、用户反馈、智能推荐等功能,并优化了代码结构和性能。项目结构...
- 实战PyQt5: 121-使用QImage实现一个看图应用
-
QImage简介QImage类提供了独立于硬件的图像表示形式,该图像表示形式可以直接访问像素数据,并且可以用作绘制设备。QImage是QPaintDevice子类,因此可以使用QPainter直接在图...
- 滚动条隐藏及美化(滚动条隐藏但是可以滚动)
-
1、滚动条隐藏背景/场景:在移动端,滑动的时候,会显示默认滚动条,如图1://隐藏代码:/*隐藏滚轮*/.ul-scrool-box::-webkit-scrollbar,.ul-scrool...
- 浅色AI云食堂APP完整代码(二)(ai 食堂)
-
以下是整合后的浅色AI云食堂APP完整代码,包含后端核心功能、前端界面以及优化增强功能。项目采用Django框架开发,支持库存管理、订单处理、财务管理等核心功能,并包含库存预警、数据导出、权限管理等增...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)