TensorFlow2.0以上版本的图像分类
ztj100 2024-11-27 23:33 14 浏览 0 评论
摘要
本篇文章采用CNN实现图像的分类,图像选取了猫狗大战数据集的1万张图像(猫狗各5千)。模型采用自定义的CNN网络,版本是TensorFlow 2.0以上的版本。通过本篇文章,你可以学到图像分类常用的手段,包括:
1、图像增强
2、训练集和验证集切分
3、使用ModelCheckpoint保存最优模型
4、使用ReduceLROnPlateau调整学习率。
5、打印loss结果生成jpg图片。
网络详解
训练部分
1、导入依赖
import os
import numpy as np
from tensorflow import keras
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout,BatchNormalization,Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D,GlobalAveragePooling2D
import cv2
from tensorflow.keras.preprocessing.image import img_to_array
from sklearn.model_selection import train_test_split
from tensorflow.python.keras import Input
from tensorflow.python.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.python.keras.layers import PReLU, Activation
from tensorflow.python.keras.models import Model
2、设置全局参数
norm_size=100#输入到网络的图像尺寸,单位是像素。
datapath='train'#图片的根目录
EPOCHS =100#训练的epoch个数
INIT_LR = 1e-3#初始学习率
labelList=[]#标签
dicClass={'cat':0,'dog':1}#类别
labelnum=2#类别个数
batch_size = 4
3、加载数据
def loadImageData():
imageList = []
listImage=os.listdir(datapath)#获取所有的图像
for img in listImage:#遍历图像
labelName=dicClass[img.split('.')[0]]#获取label对应的数字
print(labelName)
labelList.append(labelName)
dataImgPath=os.path.join(datapath,img)
print(dataImgPath)
image = cv2.imdecode(np.fromfile(dataImgPath, dtype=np.uint8), -1)
# load the image, pre-process it, and store it in the data list
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imageList.append(image)
imageList = np.array(imageList, dtype="int") / 255.0#归一化图像
return imageList
print("开始加载数据")
imageArr=loadImageData()
labelList = np.array(labelList)
print("加载数据完成")
print(labelList)
4、定义模型
def bn_prelu(x):
x = BatchNormalization(epsilon=1e-5)(x)
x = PReLU()(x)
return x
def build_model(out_dims, input_shape=(norm_size, norm_size, 3)):
inputs_dim = Input(input_shape)
x = Conv2D(32, (3, 3), strides=(2, 2), padding='same')(inputs_dim)
x = bn_prelu(x)
x = Conv2D(32, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = GlobalAveragePooling2D()(x)
dp_1 = Dropout(0.5)(x)
fc2 = Dense(out_dims)(dp_1)
fc2 = Activation('softmax')(fc2) #此处注意,为sigmoid函数
model = Model(inputs=inputs_dim, outputs=fc2)
return model
model=build_model(labelnum)#生成模型
optimizer = Adam(lr=INIT_LR)#加入优化器,设置优化器的学习率。
model.compile(optimizer =optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
5、切割训练集和验证集
trainX,valX,trainY,valY = train_test_split(imageArr,labelList, test_size=0.3, random_state=42)
6、数据增强
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True)
val_datagen = ImageDataGenerator() #验证集不做图片增强
train_generator = train_datagen.flow(trainX,trainY,batch_size=batch_size,shuffle=True)
val_generator = val_datagen.flow(valX,valY,batch_size=batch_size,shuffle=True)
7、设置callback函数
checkpointer = ModelCheckpoint(filepath='weights_best_simple_model.hdf5',
monitor='val_accuracy',verbose=1, save_best_only=True, mode='max')
reduce = ReduceLROnPlateau(monitor='val_accuracy',patience=10,
verbose=1,
factor=0.5,
min_lr=1e-6)
8、训练并保存模型
history = model.fit_generator(train_generator,
steps_per_epoch=trainX.shape[0]/batch_size,
validation_data = val_generator,
epochs=EPOCHS,
validation_steps=valX.shape[0]/batch_size,
callbacks=[checkpointer,reduce],
verbose=1,shuffle=True)
model.save('my_model_.h5')
9、保存训练历史数据
import os
loss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
import matplotlib.pyplot as plt
print("Now,we start drawing the loss and acc trends graph...")
# summarize history for accuracy
fig = plt.figure(1)
plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.title("Model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for loss
fig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")
# #windows系统设置10关机
os.system("shutdown -s -t 10")
img
img
完整代码: import os
import numpy as np
from tensorflow import keras
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout,BatchNormalization,Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D,GlobalAveragePooling2D
import cv2
from tensorflow.keras.preprocessing.image import img_to_array
from sklearn.model_selection import train_test_split
from tensorflow.python.keras import Input
from tensorflow.python.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from tensorflow.python.keras.layers import PReLU, Activation
from tensorflow.python.keras.models import Model
norm_size=100
datapath='train'
EPOCHS =100
INIT_LR = 1e-3
labelList=[]
dicClass={'cat':0,'dog':1}
labelnum=2
batch_size = 4
def loadImageData():
imageList = []
listImage=os.listdir(datapath)
for img in listImage:
labelName=dicClass[img.split('.')[0]]
print(labelName)
labelList.append(labelName)
dataImgPath=os.path.join(datapath,img)
print(dataImgPath)
image = cv2.imdecode(np.fromfile(dataImgPath, dtype=np.uint8), -1)
# load the image, pre-process it, and store it in the data list
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imageList.append(image)
imageList = np.array(imageList, dtype="int") / 255.0
return imageList
print("开始加载数据")
imageArr=loadImageData()
labelList = np.array(labelList)
print("加载数据完成")
print(labelList)
def bn_prelu(x):
x = BatchNormalization(epsilon=1e-5)(x)
x = PReLU()(x)
return x
def build_model(out_dims, input_shape=(norm_size, norm_size, 3)):
inputs_dim = Input(input_shape)
x = Conv2D(32, (3, 3), strides=(2, 2), padding='same')(inputs_dim)
x = bn_prelu(x)
x = Conv2D(32, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = Conv2D(64, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = Conv2D(128, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = Conv2D(256, (3, 3), strides=(1, 1), padding='same')(x)
x = bn_prelu(x)
x = GlobalAveragePooling2D()(x)
dp_1 = Dropout(0.5)(x)
fc2 = Dense(out_dims)(dp_1)
fc2 = Activation('softmax')(fc2) #此处注意,为sigmoid函数
model = Model(inputs=inputs_dim, outputs=fc2)
return model
model=build_model(labelnum)
optimizer = Adam(lr=INIT_LR)
model.compile(optimizer =optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
trainX,valX,trainY,valY = train_test_split(imageArr,labelList, test_size=0.3, random_state=42)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(featurewise_center=True,
featurewise_std_normalization=True,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True)
val_datagen = ImageDataGenerator() #验证集不做图片增强
train_generator = train_datagen.flow(trainX,trainY,batch_size=batch_size,shuffle=True)
val_generator = val_datagen.flow(valX,valY,batch_size=batch_size,shuffle=True)
checkpointer = ModelCheckpoint(filepath='weights_best_simple_model.hdf5',
monitor='val_accuracy',verbose=1, save_best_only=True, mode='max')
reduce = ReduceLROnPlateau(monitor='val_accuracy',patience=10,
verbose=1,
factor=0.5,
min_lr=1e-6)
history = model.fit_generator(train_generator,
steps_per_epoch=trainX.shape[0]/batch_size,
validation_data = val_generator,
epochs=EPOCHS,
validation_steps=valX.shape[0]/batch_size,
callbacks=[checkpointer,reduce],
verbose=1,shuffle=True)
model.save('my_model_.h5')
print(history)
import os
loss_trend_graph_path = r"WW_loss.jpg"
acc_trend_graph_path = r"WW_acc.jpg"
import matplotlib.pyplot as plt
print("Now,we start drawing the loss and acc trends graph...")
# summarize history for accuracy
fig = plt.figure(1)
plt.plot(history.history["accuracy"])
plt.plot(history.history["val_accuracy"])
plt.title("Model accuracy")
plt.ylabel("accuracy")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(acc_trend_graph_path)
plt.close(1)
# summarize history for loss
fig = plt.figure(2)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model loss")
plt.ylabel("loss")
plt.xlabel("epoch")
plt.legend(["train", "test"], loc="upper left")
plt.savefig(loss_trend_graph_path)
plt.close(2)
print("We are done, everything seems OK...")
# #windows系统设置10关机
os.system("shutdown -s -t 10")
测试部分 1、导入依赖 import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
import time
2、设置全局参数 norm_size=100
imagelist=[]
emotion_labels = {
0: 'cat',
1: 'dog'
}
3、加载模型 emotion_classifier=load_model(**"my_model_.h5"**)
t1=time.time()
4、处理图片 image = cv2.imdecode(np.fromfile(**'test/8.jpg'**, dtype=np.uint8), -1)
\# load the image, pre-process it, and store it in the data list
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype=**"float"**) / 255.0
5、预测类别 pre=np.argmax(emotion_classifier.predict(imageList))
emotion = emotion_labels[pre]
t2=time.time()
print(emotion)
t3=t2-t1
print(t3)
完整代码 import cv2
import numpy as np
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.models import load_model
import time
norm_size=100
imagelist=[]
emotion_labels = {
0: 'cat',
1: 'dog'
}
emotion_classifier=load_model("my_model_.h5")
t1=time.time()
image = cv2.imdecode(np.fromfile('test/8.jpg', dtype=np.uint8), -1)
# load the image, pre-process it, and store it in the data list
image = cv2.resize(image, (norm_size, norm_size), interpolation=cv2.INTER_LANCZOS4)
image = img_to_array(image)
imagelist.append(image)
imageList = np.array(imagelist, dtype="float") / 255.0
pre=np.argmax(emotion_classifier.predict(imageList))
emotion = emotion_labels[pre]
t2=time.time()
print(emotion)
t3=t2-t1
print(t3)
相关推荐
- 电脑装系统用GHOST好,还是原装版本好?老司机都是这么装的
-
Hello大家好,我是兼容机之家的咖啡。安装Windows系统是原版ISO好还是ghost好呢?针对这个的问题,我们先来科普一下什么是ghost系统,和原版ISO镜像两者之间有哪些优缺点。如果是很了解...
- 苹果 iOS 14.5.1/iPadOS 14.5.1 正式版发布
-
IT之家5月4日消息今日凌晨,苹果发布了iOS14.5.1与iPadOS14.5.1正式版更新。这一更新距iOS14.5正式版发布过去了一周时间。IT之家了解到,苹果表示,...
- iOS 13.1.3 正式版发布 包含错误修复和改进
-
苹果今天发布了iOS13.1.3和iPadOS13.1.3,这是iOS13发布之后第四个升级补丁。iOS13.1.2两周前发布。iOS13.1.3主要包括针对iPad和...
- 还不理解 Error 和 Exception 吗,看这篇就够了
-
在Java中的基本理念是结构不佳的代码不能运行,发现错误的理想时期是在编译期间,因为你不用运行程序,只是凭借着对Java基本理念的理解就能发现问题。但是编译期并不能找出所有的问题,有一些N...
- Linux 开发人员发现了导致 MacBook“无法启动”的 macOS 错误
-
“多个严重”错误影响配备ProMotion显示屏的MacBookPro。...
- 启动系统时无法正常启动提示\windows\system32\winload.efi
-
启动系统时无法正常启动提示\windows\system32\winload.efi。该怎么解决? 最近有用户遇到了开机遇到的问题,是Windows未能启动。原因可能是最近更改了硬件或软件。虽然提...
- 离线部署之两种构建Ragflow镜像的方式,dify同理
-
在实际项目交付过程中,经常遇到要离线部署的问题,生产服务器无法连接外网,这时就需要先构建好ragflow镜像,然后再拷到U盘或刻盘,下面介绍两种构建ragflow镜像的方式。性能测试(网络情况好的情况...
- Go语言 error 类型详解(go语言 异常)
-
Go语言的error类型是用于处理程序运行中错误情况的核心机制。它通过显式的返回值(而非异常抛出)来管理错误,强调代码的可控性和清晰性。以下是详细说明及示例:一、error类型的基本概念内置接口...
- Mac上“闪烁的问号”错误提示如何修复?
-
现在Mac电脑的用户越来越多,Mac电脑在使用过程中也会出现系统故障。当苹果电脑无法找到系统软件时,Mac会给出一个“闪烁的问号”的标志。很多用户受到过闪烁问号这一常见的错误提示的影响,如何解决这个问...
- python散装笔记——177 sys 模块(python sys模块详解)
-
sys模块提供了访问程序运行时环境的函数和值,例如命令行参数...
- 30天自制操作系统:第一天(30天自制操作系统电子书)
-
因为咱们的目的是为了研究操作系统的组成,所以直接从系统启动的第二阶段的主引导记录开始。前提是将编译工具放在该文件目录的同级目录下,该工具为日本人川合秀实自制的编译程序,优化过的nasm编译工具。...
- 五大原因建议您现在不要升级iOS 13或iPadOS
-
今天苹果放出了iPadOS和iOS13的公测版本,任何对新版功能感兴趣的用户都可以下载安装参与测试。除非你想要率先体验Dark模式,以及使用AppleID来登陆Facebook等服务,那么外媒CN...
- Python安装包总报错?这篇解决指南让你告别pip烦恼!
-
在Python开发中,...
- 苹果提供了在M1 Mac上修复macOS重装错误的方案
-
#AppleM1芯片#在苹果新的M1Mac推出后不久,我们看到有报道称,在这些机器上恢复和重新安装macOS,可能会导致安装错误,使你的Mac无法使用。具体来说,错误信息如下:"An...
- 黑苹果卡代码篇三:常见卡代码问题,满满的干货
-
前言...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- 电脑装系统用GHOST好,还是原装版本好?老司机都是这么装的
- 苹果 iOS 14.5.1/iPadOS 14.5.1 正式版发布
- iOS 13.1.3 正式版发布 包含错误修复和改进
- 还不理解 Error 和 Exception 吗,看这篇就够了
- Linux 开发人员发现了导致 MacBook“无法启动”的 macOS 错误
- 启动系统时无法正常启动提示\windows\system32\winload.efi
- 离线部署之两种构建Ragflow镜像的方式,dify同理
- Go语言 error 类型详解(go语言 异常)
- Mac上“闪烁的问号”错误提示如何修复?
- python散装笔记——177 sys 模块(python sys模块详解)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)