keras 人工智能之VGGNet神经网络模型训练
ztj100 2024-12-28 16:50 24 浏览 0 评论
上期文章我们分享了如何使用LetNet体系结构来搭建一个图片识别的神经网络:
人工智能Keras的第一个图像分类器(CNN卷积神经网络的图片识别)
本期我们基于VGGNet神经网络来进行图片的识别,且增加图片的识别种类,当然你也可以增加更多的种类,本期代码跟往期代码有很大的相识处,可以参考
VGGNet基础
- 输入是大小为224*224的RGB图像,预处理(preprocession)时计算出三个通道的平均值,在每个像素上减去平均值
- 图像经过一系列卷积层处理,在卷积层中使用了非常小的3*3卷积核,在有些卷积层里则使用了1*1的卷积核。
- 卷积层步长(stride)设置为1个像素,3*3卷积层的填充(padding)设置为1个像素。池化层采用max pooling,共有5层,在一部分卷积层后,max-pooling的窗口是2*2,步长设置为2。
- 卷积层之后是三个全连接层(fully-connected layers,FC)。前两个全连接层均有4096个通道,第三个全连接层有1000个通道,用来分类。所有网络的全连接层配置相同。
- 全连接层后是Softmax,用来分类。
- 所有隐藏层(每个conv层中间)都使用ReLU作为激活函数。VGGNet不使用局部响应标准化(LRN),这种标准化并不能在ILSVRC数据集上提升性能,却导致更多的内存消耗和计算时间(LRN:Local Response Normalization,局部响应归一化,用于增强网络的泛化能力)。
VGGNet keras 神经网络搭建
使用VGGNet基础知识,我们使用keras来搭建一个小型的神经网络,新建一个smallervggnet.py文件
from keras.models import Sequential
from keras.layers.normalization import BatchNormalization
from keras.layers.convolutional import Conv2D
from keras.layers.convolutional import MaxPooling2D
from keras.layers.core import Activation
from keras.layers.core import Flatten
from keras.layers.core import Dropout
from keras.layers.core import Dense
from keras import backend as K
class SmallerVGGNet:
@staticmethod
def build(width, height, depth, classes):
# initialize the model along with the input shape to be
# "channels last" and the channels dimension itself
model = Sequential()
inputShape = (height, width, depth)
chanDim = -1
# if we are using "channels first", update the input shape
# and channels dimension
if K.image_data_format() == "channels_first":
inputShape = (depth, height, width)
chanDim = 1
# CONV => RELU => POOL
model.add(Conv2D(32, (3, 3), padding="same",
input_shape=inputShape))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(3, 3)))
model.add(Dropout(0.25))
# (CONV => RELU) * 2 => POOL
model.add(Conv2D(64, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(64, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
# (CONV => RELU) * 2 => POOL
model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(Conv2D(128, (3, 3), padding="same"))
model.add(Activation("relu"))
model.add(BatchNormalization(axis=chanDim))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
# first (and only) set of FC => RELU layers
model.add(Flatten())
model.add(Dense(1024))
model.add(Activation("relu"))
model.add(BatchNormalization())
model.add(Dropout(0.5))
# softmax classifier
model.add(Dense(classes))
model.add(Activation("softmax"))
# return the constructed network architecture
return model
搭建图片识别训练模型
导入第三方库
import matplotlib
matplotlib.use("Agg")
from keras.preprocessing.image import ImageDataGenerator
from keras.optimizers import Adam
from keras.preprocessing.image import img_to_array
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from smallervggnet import SmallerVGGNet
from keras.utils import to_categorical
import matplotlib.pyplot as plt
from imutils import paths
import numpy as np
import random
import pickle
import cv2
import os
初始化数据
EPOCHS = 100 #学习的步数
INIT_LR = 1e-3 #学习效率
BS = 32# 每步学习个数
IMAGE_DIMS = (96, 96, 3) # 图片尺寸
data = [] # 保存图片数据
labels = [] # 保存图片label
# 加载所有图片
imagePaths = sorted(list(paths.list_images("dataset\\")))
random.seed(42)
random.shuffle(imagePaths)
遍历图片搜集图片信息
for imagePath in imagePaths:
# 加载所有图片
image = cv2.imread(imagePath)
image = cv2.resize(image, (IMAGE_DIMS[1], IMAGE_DIMS[0]))
image = img_to_array(image)
data.append(image)
# 搜集图片data 与label
label = imagePath.split(os.path.sep)[-2]
print(label)
labels.append(label)
处理图片
# 处理数据到0-1
data = np.array(data, dtype="float") / 255.0
labels = np.array(labels
# 标签二值化
lb = LabelBinarizer()
labels = lb.fit_transform(labels)
#labels = to_categorical(labels) #多类删除这个,当然本期代码完全可以使用在介绍lenet网络上
搭建神经网络模型
(trainX, testX, trainY, testY) = train_test_split(data,
labels, test_size=0.2, random_state=42)
#分开测试数据
#创建一个图像生成器对象,该对象在图像数据集上执行随机旋转,平移,翻转,修剪和剪切。
#这使我们可以使用较小的数据集,但仍然可以获得较高的结果
aug = ImageDataGenerator(rotation_range=25, width_shift_range=0.1,
height_shift_range=0.1, shear_range=0.2, zoom_range=0.2,
horizontal_flip=True, fill_mode="nearest")
# 初始化模型
model = SmallerVGGNet.build(width=IMAGE_DIMS[1], height=IMAGE_DIMS[0],
depth=IMAGE_DIMS[2], classes=len(lb.classes_))
opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
model.compile(loss="categorical_crossentropy", optimizer=opt,
metrics=["accuracy"])
训练神经网络
H = model.fit_generator(
aug.flow(trainX, trainY, batch_size=BS),
validation_data=(testX, testY),
steps_per_epoch=len(trainX) // BS,
epochs=EPOCHS, verbose=1)
保存训练模型
model.save("VGGNet.model")
f = open("labelbin.pickle", "wb")
f.write(pickle.dumps(lb))
f.close()
显示训练结果
plt.style.use("ggplot")
plt.figure()
N = EPOCHS
plt.plot(np.arange(0, N), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, N), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, N), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, N), H.history["val_acc"], label="val_acc")
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend(loc="upper left")
plt.savefig("plot1.JPG")
下期我们将使用预训练好的模型对图片进行识别
相关推荐
-
- SpringBoot如何实现优雅的参数校验
-
平常业务中肯定少不了校验,如果我们把大量的校验代码夹杂到业务中,肯定是不优雅的,对于一些简单的校验,我们可以使用java为我们提供的api进行处理,同时对于一些...
-
2025-05-11 19:46 ztj100
- Java中的空指针怎么处理?
-
#暑期创作大赛#Java程序员工作中遇到最多的错误就是空指针异常,无论你多么细心,一不留神就从代码的某个地方冒出NullPointerException,令人头疼。...
- 一坨一坨 if/else 参数校验,被 SpringBoot 参数校验组件整干净了
-
来源:https://mp.weixin.qq.com/s/ZVOiT-_C3f-g7aj3760Q-g...
- 用了这两款插件,同事再也不说我代码写的烂了
-
同事:你的代码写的不行啊,不够规范啊。我:我写的代码怎么可能不规范,不要胡说。于是同事打开我的IDEA,安装了一个插件,然后执行了一下,规范不规范,看报告吧。这可怎么是好,这玩意竟然给我挑出来这么...
- SpringBoot中6种拦截器使用场景
-
SpringBoot中6种拦截器使用场景,下面是思维导图详细总结一、拦截器基础...
- 用注解进行参数校验,spring validation介绍、使用、实现原理分析
-
springvalidation是什么在平时的需求开发中,经常会有参数校验的需求,比如一个接收用户注册请求的接口,要校验用户传入的用户名不能为空、用户名长度不超过20个字符、传入的手机号是合法的手机...
- 快速上手:SpringBoot自定义请求参数校验
-
作者:UncleChen来源:http://unclechen.github.io/最近在工作中遇到写一些API,这些API的请求参数非常多,嵌套也非常复杂,如果参数的校验代码全部都手动去实现,写起来...
- 分布式微服务架构组件
-
1、服务发现-Nacos服务发现、配置管理、服务治理及管理,同类产品还有ZooKeeper、Eureka、Consulhttps://nacos.io/zh-cn/docs/what-is-nacos...
- 优雅的参数校验,告别冗余if-else
-
一、参数校验简介...
- Spring Boot断言深度指南:用断言机制为代码构筑健壮防线
-
在SpringBoot开发中,断言(Assert)如同代码的"体检医生",能在上线前精准捕捉业务逻辑漏洞。本文将结合企业级实践,解析如何通过断言机制实现代码自检、异常预警与性能优化三...
- 如何在项目中优雅的校验参数
-
本文看点前言验证数据是贯穿所有应用程序层(从表示层到持久层)的常见任务。通常在每一层实现相同的验证逻辑,这既费时又容易出错。为了避免重复这些验证,开发人员经常将验证逻辑直接捆绑到域模型中,将域类与验证...
- SpingBoot项目使用@Validated和@Valid参数校验
-
一、什么是参数校验?我们在后端开发中,经常遇到的一个问题就是入参校验。简单来说就是对一个方法入参的参数进行校验,看是否符合我们的要求。比如入参要求是一个金额,你前端没做限制,用户随便过来一个负数,或者...
- 28个验证注解,通过业务案例让你精通Java数据校验(收藏篇)
-
在现代软件开发中,数据验证是确保应用程序健壮性和可靠性的关键环节。JavaBeanValidation(JSR380)作为一个功能强大的规范,为我们提供了一套全面的注解工具集,这些注解能够帮...
- Springboot @NotBlank参数校验失效汇总
-
有时候明明一个微服务里的@Validated和@NotBlank用的好好的,但就是另一个里不能用,这时候问题是最不好排查的,下面列举了各种失效情况的汇总,供各位参考:1、版本问题springbo...
- 这可能是最全面的Spring面试八股文了
-
Spring是什么?Spring是一个轻量级的控制反转(IoC)和面向切面(AOP)的容器框架。...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)