百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

人工智能Keras CNN卷积神经网络的图片识别模型训练

ztj100 2024-12-28 16:50 42 浏览 0 评论

CNN卷积神经网络是人工智能的开端,CNN卷积神经网络让计算机能够认识图片,文字,甚至音频与视频。CNN卷积神经网络的基础知识,可以参考:CNN卷积神经网络

LetNet体系结构是卷积神经网络的“第一个图像分类器”。最初设计用于对手写数字进行分类,上期文章我们分享了如何使用keras来进行手写数字的神经网络搭建:Keras人工智能神经网络 Classifier 分类 神经网络搭建

我们也可以轻松地将其扩展到其他类型的图像上,本期使用小雪人的照片,来让神经网络识别雪人

雪人的图片大家可以到网络上自行下载,当然也可以使用爬虫技术来下载

搭建keras神经网络识别图片

from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dense
from keras import backend as K

首先导入需要的模块,建立一个神经网络以便后期使用,在一个单独的文件中,命名此神经网络类(lenet.py)

class LeNet:
	@staticmethod
	def build(width, height, depth, classes):
		# 使用Sequential()初始化model
		model = Sequential()
		inputShape = (height, width, depth) 
		#tensorflow默认设置
    #宽度 :输入图像的宽度
    #高度 :输入图像的高度
    #深度 :输入图像中的频道数(1个 对于灰度单通道图像, 3 标准RGB图像)
    # 若是其他的(Theano),则使用((depth, height, width)
		if K.image_data_format() == "channels_first":
			inputShape = (depth, height, width)
		#建立卷积神经网络  =>然后是 RELU => 然后是max pooling(跟前期分享的tensorflow教程类似)
		model.add(Conv2D(20, (5, 5), padding="same",input_shape=inputShape))
		model.add(Activation("relu"))
		model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
		# 建立卷积神经网络  =>然后是 RELU => 然后是max pooling(第二层)
		model.add(Conv2D(50, (5, 5), padding="same"))
		model.add(Activation("relu"))
		model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
		# 增加全连接层
		model.add(Flatten())
		model.add(Dense(500))
		model.add(Activation("relu"))
		# softmax classifier 来进行神经网络的分类
		model.add(Dense(classes))
		model.add(Activation("softmax"))
		# return the model
		return model

训练keras神经网络

以上建立了keras 的神经网络模型,我们就使用预先下载好的图片来训练神经模型

建立一个train.py文件,插入如下代码,来训练神经网络模型(图片数据里面分成如下2类)

  1. snowman #我们训练的图片
  2. notsnowman 增加非雪人图片的训练
import matplotlib
matplotlib.use("Agg")
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import img_to_array
from keras.utils import to_categorical
from lenet import LeNet
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import random
import cv2
import os

初始化参数

EPOCHS = 25 #学习的步数
INIT_LR = 1e-3# 学习效率
BS = 32# 每步学习的个数
data = []# 存放图片数据
labels = []# 存放图片标签
imagePaths = sorted(list(paths.list_images("dataset\\")))# 遍历所有的图片
random.seed(42)
random.shuffle(imagePaths) # 打乱图片顺序

初始化参数完成后,需要把所有的图片加载,进行图片数据的整理

for imagePath in imagePaths:
    # 加载图片
    image = cv2.imread(imagePath)
    image = cv2.resize(image, (28, 28)) # resize 到28*28 LeNet所需的空间尺寸
    image = img_to_array(image) # 图片转换成array
    data.append(image) # 保存图片数据
    label = imagePath.split(os.path.sep)[-2] #获取图片标签
    label = 1 if label == "snowman" else 0
    labels.append(label) # 获取图片标签

预先处理图片

# 把图片数据变成【0.1】
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)
# 设置测试数据与训练数据
#使用75%的数据将数据划分为训练和测试
#用于训练的数据,其余25%用于测试
(trainX, testX, trainY, testY) = train_test_split(data, labels, test_size=0.25, random_state=42)
# 标签转换成向量
trainY = to_categorical(trainY, num_classes=2)
testY = to_categorical(testY, num_classes=2)
# 创建一个图像生成器对象,该对象在图像数据集上执行随机旋转,平移,翻转,修剪和剪切。
#这使我们可以使用较小的数据集,但仍然可以获得较高的结果
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
                         height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
                         horizontal_flip=True, fill_mode="nearest")

建立神经网络,进行神经网络训练

#建立model
model = LeNet.build(width=28, height=28, depth=3, classes=2)
opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])
# 训练神经网络
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
                        validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
                        epochs=EPOCHS, verbose=1)

神经网络训练完成后,对神经网络训练的结果进行保存,以便后期使用预训练模型进行图片识别

保存模型,显示训练结果

model.save("lenet.model") # 保存模型
# 显示结果
plt.style.use("ggplot")
plt.figure()
N = EPOCHS
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on snowman/Notsnowman")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig("plot1.JPG")

从训练结果可以看出,loss越来越小,精度越来越高,表明我们的神经网络模型是完全ok。

若想得到更好的训练数据,当然是使用大量的数据进行训练

以上便是我们训练的神经网络模型,下期我们使用预训练模型,对图片进行识别


相关推荐

sharding-jdbc实现`分库分表`与`读写分离`

一、前言本文将基于以下环境整合...

三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么

在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...

MySQL8行级锁_mysql如何加行级锁

MySQL8行级锁版本:8.0.34基本概念...

mysql使用小技巧_mysql使用入门

1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...

MySQL/MariaDB中如何支持全部的Unicode?

永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...

聊聊 MySQL Server 可执行注释,你懂了吗?

前言MySQLServer当前支持如下3种注释风格:...

MySQL系列-源码编译安装(v5.7.34)

一、系统环境要求...

MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了

对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...

MySQL字符问题_mysql中字符串的位置

中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...

深圳尚学堂:mysql基本sql语句大全(三)

数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...

MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?

大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...

一文讲清怎么利用Python Django实现Excel数据表的导入导出功能

摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...

用DataX实现两个MySQL实例间的数据同步

DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...

MySQL数据库知识_mysql数据库基础知识

MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...

如何为MySQL中的JSON字段设置索引

背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...

取消回复欢迎 发表评论: