百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

人工智能Keras CNN卷积神经网络的图片识别模型训练

ztj100 2024-12-28 16:50 27 浏览 0 评论

CNN卷积神经网络是人工智能的开端,CNN卷积神经网络让计算机能够认识图片,文字,甚至音频与视频。CNN卷积神经网络的基础知识,可以参考:CNN卷积神经网络

LetNet体系结构是卷积神经网络的“第一个图像分类器”。最初设计用于对手写数字进行分类,上期文章我们分享了如何使用keras来进行手写数字的神经网络搭建:Keras人工智能神经网络 Classifier 分类 神经网络搭建

我们也可以轻松地将其扩展到其他类型的图像上,本期使用小雪人的照片,来让神经网络识别雪人

雪人的图片大家可以到网络上自行下载,当然也可以使用爬虫技术来下载

搭建keras神经网络识别图片

from keras.models import Sequential
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dense
from keras import backend as K

首先导入需要的模块,建立一个神经网络以便后期使用,在一个单独的文件中,命名此神经网络类(lenet.py)

class LeNet:
	@staticmethod
	def build(width, height, depth, classes):
		# 使用Sequential()初始化model
		model = Sequential()
		inputShape = (height, width, depth) 
		#tensorflow默认设置
    #宽度 :输入图像的宽度
    #高度 :输入图像的高度
    #深度 :输入图像中的频道数(1个 对于灰度单通道图像, 3 标准RGB图像)
    # 若是其他的(Theano),则使用((depth, height, width)
		if K.image_data_format() == "channels_first":
			inputShape = (depth, height, width)
		#建立卷积神经网络  =>然后是 RELU => 然后是max pooling(跟前期分享的tensorflow教程类似)
		model.add(Conv2D(20, (5, 5), padding="same",input_shape=inputShape))
		model.add(Activation("relu"))
		model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
		# 建立卷积神经网络  =>然后是 RELU => 然后是max pooling(第二层)
		model.add(Conv2D(50, (5, 5), padding="same"))
		model.add(Activation("relu"))
		model.add(MaxPooling2D(pool_size=(2, 2), strides=(2, 2)))
		# 增加全连接层
		model.add(Flatten())
		model.add(Dense(500))
		model.add(Activation("relu"))
		# softmax classifier 来进行神经网络的分类
		model.add(Dense(classes))
		model.add(Activation("softmax"))
		# return the model
		return model

训练keras神经网络

以上建立了keras 的神经网络模型,我们就使用预先下载好的图片来训练神经模型

建立一个train.py文件,插入如下代码,来训练神经网络模型(图片数据里面分成如下2类)

  1. snowman #我们训练的图片
  2. notsnowman 增加非雪人图片的训练
import matplotlib
matplotlib.use("Agg")
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from keras.preprocessing.image import img_to_array
from keras.utils import to_categorical
from lenet import LeNet
from imutils import paths
import matplotlib.pyplot as plt
import numpy as np
import random
import cv2
import os

初始化参数

EPOCHS = 25 #学习的步数
INIT_LR = 1e-3# 学习效率
BS = 32# 每步学习的个数
data = []# 存放图片数据
labels = []# 存放图片标签
imagePaths = sorted(list(paths.list_images("dataset\\")))# 遍历所有的图片
random.seed(42)
random.shuffle(imagePaths) # 打乱图片顺序

初始化参数完成后,需要把所有的图片加载,进行图片数据的整理

for imagePath in imagePaths:
    # 加载图片
    image = cv2.imread(imagePath)
    image = cv2.resize(image, (28, 28)) # resize 到28*28 LeNet所需的空间尺寸
    image = img_to_array(image) # 图片转换成array
    data.append(image) # 保存图片数据
    label = imagePath.split(os.path.sep)[-2] #获取图片标签
    label = 1 if label == "snowman" else 0
    labels.append(label) # 获取图片标签

预先处理图片

# 把图片数据变成【0.1】
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels)
# 设置测试数据与训练数据
#使用75%的数据将数据划分为训练和测试
#用于训练的数据,其余25%用于测试
(trainX, testX, trainY, testY) = train_test_split(data, labels, test_size=0.25, random_state=42)
# 标签转换成向量
trainY = to_categorical(trainY, num_classes=2)
testY = to_categorical(testY, num_classes=2)
# 创建一个图像生成器对象,该对象在图像数据集上执行随机旋转,平移,翻转,修剪和剪切。
#这使我们可以使用较小的数据集,但仍然可以获得较高的结果
aug = ImageDataGenerator(rotation_range=30, width_shift_range=0.1,
                         height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
                         horizontal_flip=True, fill_mode="nearest")

建立神经网络,进行神经网络训练

#建立model
model = LeNet.build(width=28, height=28, depth=3, classes=2)
opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="binary_crossentropy", optimizer=opt, metrics=["accuracy"])
# 训练神经网络
H = model.fit_generator(aug.flow(trainX, trainY, batch_size=BS),
                        validation_data=(testX, testY), steps_per_epoch=len(trainX) // BS,
                        epochs=EPOCHS, verbose=1)

神经网络训练完成后,对神经网络训练的结果进行保存,以便后期使用预训练模型进行图片识别

保存模型,显示训练结果

model.save("lenet.model") # 保存模型
# 显示结果
plt.style.use("ggplot")
plt.figure()
N = EPOCHS
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy on snowman/Notsnowman")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="lower left")
plt.savefig("plot1.JPG")

从训练结果可以看出,loss越来越小,精度越来越高,表明我们的神经网络模型是完全ok。

若想得到更好的训练数据,当然是使用大量的数据进行训练

以上便是我们训练的神经网络模型,下期我们使用预训练模型,对图片进行识别


相关推荐

SpringBoot如何实现优雅的参数校验
SpringBoot如何实现优雅的参数校验

平常业务中肯定少不了校验,如果我们把大量的校验代码夹杂到业务中,肯定是不优雅的,对于一些简单的校验,我们可以使用java为我们提供的api进行处理,同时对于一些...

2025-05-11 19:46 ztj100

Java中的空指针怎么处理?

#暑期创作大赛#Java程序员工作中遇到最多的错误就是空指针异常,无论你多么细心,一不留神就从代码的某个地方冒出NullPointerException,令人头疼。...

一坨一坨 if/else 参数校验,被 SpringBoot 参数校验组件整干净了

来源:https://mp.weixin.qq.com/s/ZVOiT-_C3f-g7aj3760Q-g...

用了这两款插件,同事再也不说我代码写的烂了

同事:你的代码写的不行啊,不够规范啊。我:我写的代码怎么可能不规范,不要胡说。于是同事打开我的IDEA,安装了一个插件,然后执行了一下,规范不规范,看报告吧。这可怎么是好,这玩意竟然给我挑出来这么...

SpringBoot中6种拦截器使用场景

SpringBoot中6种拦截器使用场景,下面是思维导图详细总结一、拦截器基础...

用注解进行参数校验,spring validation介绍、使用、实现原理分析

springvalidation是什么在平时的需求开发中,经常会有参数校验的需求,比如一个接收用户注册请求的接口,要校验用户传入的用户名不能为空、用户名长度不超过20个字符、传入的手机号是合法的手机...

快速上手:SpringBoot自定义请求参数校验

作者:UncleChen来源:http://unclechen.github.io/最近在工作中遇到写一些API,这些API的请求参数非常多,嵌套也非常复杂,如果参数的校验代码全部都手动去实现,写起来...

分布式微服务架构组件

1、服务发现-Nacos服务发现、配置管理、服务治理及管理,同类产品还有ZooKeeper、Eureka、Consulhttps://nacos.io/zh-cn/docs/what-is-nacos...

优雅的参数校验,告别冗余if-else

一、参数校验简介...

Spring Boot断言深度指南:用断言机制为代码构筑健壮防线

在SpringBoot开发中,断言(Assert)如同代码的"体检医生",能在上线前精准捕捉业务逻辑漏洞。本文将结合企业级实践,解析如何通过断言机制实现代码自检、异常预警与性能优化三...

如何在项目中优雅的校验参数

本文看点前言验证数据是贯穿所有应用程序层(从表示层到持久层)的常见任务。通常在每一层实现相同的验证逻辑,这既费时又容易出错。为了避免重复这些验证,开发人员经常将验证逻辑直接捆绑到域模型中,将域类与验证...

SpingBoot项目使用@Validated和@Valid参数校验

一、什么是参数校验?我们在后端开发中,经常遇到的一个问题就是入参校验。简单来说就是对一个方法入参的参数进行校验,看是否符合我们的要求。比如入参要求是一个金额,你前端没做限制,用户随便过来一个负数,或者...

28个验证注解,通过业务案例让你精通Java数据校验(收藏篇)

在现代软件开发中,数据验证是确保应用程序健壮性和可靠性的关键环节。JavaBeanValidation(JSR380)作为一个功能强大的规范,为我们提供了一套全面的注解工具集,这些注解能够帮...

Springboot @NotBlank参数校验失效汇总

有时候明明一个微服务里的@Validated和@NotBlank用的好好的,但就是另一个里不能用,这时候问题是最不好排查的,下面列举了各种失效情况的汇总,供各位参考:1、版本问题springbo...

这可能是最全面的Spring面试八股文了

Spring是什么?Spring是一个轻量级的控制反转(IoC)和面向切面(AOP)的容器框架。...

取消回复欢迎 发表评论: