百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

TensorFlow2学习十二、使用预训练CNN进行迁移学习识别猫和狗

ztj100 2024-12-28 16:50 19 浏览 0 评论

一、说明

本文学习资源来自tensorflow官网,测试环境使用tensor conlab。

1. 本文内容

学习怎么使用预训练cnn进行迁移学习从而把猫、狗分类。
预训练模型是一个使用大量数据训练好并保存好的网络模型,典型的是大量图像数据的分类工作。我们可以使用本文中的预训练模型,也可以针对一个任务使用迁移学习客制化模型。
当一个模型是基于足够大的、足够有代表性的数据集训练出来的,那么它可以有效的工作在机器视觉中。我们可以利用这些学习好的特征,而不用再基于大量数据集重复进行训练。
本文使用两种方式客制化预训练模型:

  1. 特征展开:使用之前模型从新的数据集中提取有用的特征。只需要在预训练模型顶部简单的添加分类器,不需要重新训练整个模型。基础cnn已经包含了有用的分类特征。
  2. 微调:从一个冻结模型解除顶部的层,连接新添加的分类层和后面的基础模型。

2. 步骤

  • 理解数据
  • 导入包和数据,使用Keras ImageDataGenerator处理数据
  • 构建模型
  • 加载预训练模型(和预置权重)
  • 把新的分类层堆叠到顶部
  • 训练模型
  • 评估模型

二、实现

1. 引入包

from __future__ import absolute_import, division, print_function, unicode_literals
import os
import numpy as np
import matplotlib.pyplot as plt
try:
 # %tensorflow_version only exists in Colab.
 %tensorflow_version 2.x
except Exception:
 pass
import tensorflow as tf

keras = tf.keras

2. 数据处理

使用tensorflow_datasets加载数据

import tensorflow_datasets as tfds
tfds.disable_progress_bar()
# 训练集 验证集 测试集 比例8:1:1
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
 'cats_vs_dogs', split=list(splits),
 with_info=True, as_supervised=True)

print(raw_train)
print(raw_validation)
print(raw_test)

显示前2个图片看看

get_label_name = metadata.features['label'].int2str

for image, label in raw_train.take(2):
 plt.figure()
 plt.imshow(image)
 plt.title(get_label_name(label))

格式化数据

使用 tf.image 格式化数据。 图片缩放到160*160,输入通道值转换为[-1,1]

IMG_SIZE = 160 # All images will be resized to 160x160

def format_example(image, label):
 image = tf.cast(image, tf.float32)
 image = (image/127.5) - 1
 image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))
 return image, label

# 数据集中每一项都使用该函数处理
train = raw_train.map(format_example)
validation = raw_validation.map(format_example)
test = raw_test.map(format_example)

数据集分批

BATCH_SIZE = 32
SHUFFLE_BUFFER_SIZE = 1000
train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
validation_batches = validation.batch(BATCH_SIZE)
test_batches = test.batch(BATCH_SIZE)
for image_batch, label_batch in train_batches.take(1):
 pass

image_batch.shape

3. 从预训练cnn里创建基础模型

下面加载google创建的 MobileNet 模型,这个模型基于ImageNet,由李飞飞团队创建的大型数据集。这个数据集有1400多万数据和超过2万多个标注,与超过百万的边界框标注。这个模型有助于从我们的数据集里分辨狗和猫。
现在要确定模型哪个层用于特征提取。大多数机器学习模型是从下往上的,最后一个分类器(顶部)不是很有用。
我们将遵循常规做法,转而依赖展平操作之前的最后一层。这个层被称为“瓶颈层”。与最终/顶层相比,瓶颈层特性保留了许多通用性。

首先,实例化一个MobileNet V2模型,该模型预先加载了在ImageNet上训练的权重。
通过指定include_top=False参数,可以加载一个不包括顶部分类层的网络,这是特征提取的理想选择。

IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3)

# Create the base model from the pre-trained model MobileNet V2
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
 include_top=False,
 weights='imagenet')

feature_batch = base_model(image_batch)
print(feature_batch.shape)


特征提取操作将每个160x160x3 转换成 5x5x1280 特征块.

# 冻结层,防止在训练期间更新给定层中的权重
base_model.trainable = False
base_model.summary()

4. 全局池化层GAP

GAP被认为是可以替代全连接层的一种新技术。在keras发布的经典模型中,可以看到不少模型甚至抛弃了全连接层,转而使用GAP。
在支持迁移学习方面,各个模型几乎都支持使用Global Average Pooling和Global Max Pooling(GMP)。
这里使用tf.keras.layers.GlobalAveragePooling2D将每个图片以5x5空间位置转成单个1280元素向量。

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
# tf.keras.layers.Dense 层将每个特征输入预测器,预测器训练使用logit,所以这里不需要指定激活函数。
# 正向数值预测分类1,负向数值预测分类0
prediction_layer = keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

5. 创建模型

model = tf.keras.Sequential([
 base_model,
 global_average_layer,
 prediction_layer
])

6. 编译模型

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate),
 loss='binary_crossentropy',
 metrics=['accuracy'])
model.summary()
len(model.trainable_variables) # 2

2.5M MobileNet的变量冻结了,但还有1.2k全连接层还有2.5M训练参数,被分成权重和偏置两个tf变量对象.

7. 训练模型

num_train, num_val, num_test = (
 metadata.splits['train'].num_examples*weight/10
 for weight in SPLIT_WEIGHTS
)
initial_epochs = 10
steps_per_epoch = round(num_train)//BATCH_SIZE
validation_steps = 20

loss0,accuracy0 = model.evaluate(validation_batches, steps = validation_steps)
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))

8. 拟合

history = model.fit(train_batches,
 epochs=initial_epochs,
 validation_data=validation_batches)

9. 学习曲线可视化

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()


这里验证指标明显优于培训指标,主要原因是像tf.keras.layers.BatchNormalization和tf.keras.layers.Dropout这样的层会影响训练期间的准确性。它们在计算验证丢失时被关闭。

在较小程度上,这也是因为训练度量报告了一个epoch的平均值,而验证度量在epoch之后进行评估,因此验证度量看到的模型训练的时间稍长。

三、微调模型

上面的特征提取操作只在MobileNet V2基础模型上训练了几层。训练过程中未更新训练网络的权值。

进一步提高性能的一种方法是在训练添加的分类器的同时训练(或“微调”)预训练模型顶层的权重。训练过程将强制将权重从通用特征映射调整到与我们的数据集特定关联的特征。
注意:只有在将预先训练的模型设置为不可训练的顶级分类器之后,才能尝试此操作。
如果在预先训练的模型上添加一个随机初始化的分类器并尝试联合训练所有层,则梯度更新的幅度将过大(由于来自分类器的随机权重),并且预先训练的模型将忘记它所学到的内容。

此外应该尝试微调少数顶层,而不是整个MobileNet模型。在大多数卷积网络中,一层越高,它就越专业。前几层学习非常简单和通用的特征,这些特征可以概括为几乎所有类型的图像。当你往上走的时候,这些特性对模型所训练的数据集越来越具体。微调的目标是使这些专门特性适应新数据集,而不是覆盖泛型学习。

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
 layer.trainable = False

model.compile(loss='binary_crossentropy',
 optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
 metrics=['accuracy'])

fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs

history_fine = model.fit(train_batches,
 epochs=total_epochs,
 initial_epoch = history.epoch[-1],
 validation_data=validation_batches)

# 微调后准确率可以达到98%左右

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
 plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
 plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

总结

使用预先训练的模型进行特征提取:

  • 在处理小数据集时,通常会利用在同一域中的较大数据集上训练的模型所学习的特征。这是通过实例化预先训练的模型并在上面添加一个完全连接的分类器来完成的。训练前的模型被冻结,训练过程中只更新分类器的权值。在这种情况下,卷积基提取了与每个图像相关联的所有特征,只需训练一个分类器,该分类器根据提取的特征集确定图像类别。

微调预先训练的模型

为了进一步提高性能,可能需要通过微调将预先训练的模型的顶层重新调整到新的数据集。在本例中,我们调整了权重,以便模型学习特定于数据集的高级特性。当训练数据集很大且与训练前模型所用的原始数据集非常相似时,通常建议使用此技术。

相关推荐

干货 | 各大船公司VGM提交流程(msc船运公司提单查询)

VGM(VerifiedGrossMass)要来了,大外总管一本正经来给大家分享下各大船公司提交VGM流程。1,赫伯罗特(简称HPL)首先要注册账户第一,登录进入—选择product------...

如何修改图片详细信息?分享三个简单方法

如何修改图片详细信息?分享三个简单方法我们知道图片的详细信息里面包含了很多属性,有图片的创建时间,修改时间,地理位置,拍摄时间,还有图片的描述等信息。有时候为了一些特殊场景的需要我们需要对这些信息进行...

实用方法分享:没有图像处理软件,怎么将一张照片做成九宫格?

在发朋友圈时,如果把自己的照片做成九宫格,是不是更显得高大上?可能你问,是不是要借助图片处理软件,在这里,我肯定告诉你,不需要!!!你可能要问,那怎么实现呢?下面你看我是怎么做的,一句代码都不写,只是...

扫描档PDF也能变身“最强大脑”?RAG技术解锁尘封的知识宝藏!

尊敬的诸位!我是一名物联网工程师。关注我,持续分享最新物联网与AI资讯和开发实战。期望与您携手探寻物联网与AI的无尽可能。今天有网友问我扫描档的PDF文件能否做知识库,其实和普通pdf处理起来差异...

这两个Python库,轻而易举就能实现MP4与GIF格式互转,太好用了

mp4转gif的原理其实很简单,就是将mp4文件的帧读出来,然后合并成一张gif图。用cv2和PIL这两个库就可以轻松搞定。importglobimportcv2fromPILimpo...

python图片处理之图片切割(python把图片切割成固定大小的子图)

python图片切割在很多项目中都会用到,比如验证码的识别、目标检测、定点切割等,本文给大家带来python的两种切割方式:fromPILimportImage"""...

python+selenium+pytesseract识别图片验证码

一、selenium截取验证码#私信小编01即可获取大量Python学习资源#私信小编01即可获取大量Python学习资源#私信小编01即可获取大量Python学习资源importjso...

如何使用python裁剪图片?(python图片截取)

如何使用python裁剪图片如上图所示,这是一张包含了各类象棋棋子的图片。我们需要将其中每一个棋子都裁剪出来,此时可以利用python的...

Python rembg 库去除图片背景(python 删除图片)

rembg是一个强大的Python库,用于自动去除图片背景。它基于深度学习模型(如U^2-Net),能够高效地将前景物体从背景中分离,生成透明背景的PNG图像。本教程将带你从安装到实际应用...

「python脚本」批量修改图片尺寸&视频安帧提取

【python脚本】批量修改图片尺寸#-*-coding:utf-8-*-"""CreatedonThuAug2316:06:352018@autho...

有趣的EXCEL&vba作图(vba画图表)

还记不记得之前有个日本老爷爷用EXCEL绘图,美轮美奂,可谓是心思巧妙。我是没有那样的艺术细胞,不过咱有自己的方式,用代码作图通过vba代码将指定的图片写入excel工作表中,可不是插入图片哦解题思...

怎么做到的?用python制作九宫格图片,太棒了

1.应用场景当初的想法是:想把一张图切割成九等份,发布到微信朋友圈,切割出来的图片,上传到朋友圈,发现微信不按照我排列的序号来排版。这样的结果是很耗时间的。让我深思,能不能有一种,直接拼接成一张...

Python-连续图片合成视频(python多张图叠加为一张)

前言很多时候,我们需要将图片直接转成视频。下面介绍用python中的OpenCV将进行多张图合成视频。cv2安装不要直接用pipinstallcv2,这会报错。有很多人建议用打开window自带的...

如何把多个文件夹里的图片提取出来?文件夹整理合并工具

在项目管理中,团队成员可能会将项目相关的图片资料分散存储在不同的文件夹中,以便于分类和阶段性管理。然而,当项目进入汇报或总结阶段时,需要将所有相关图片整合到一个位置,以便于制作演示文稿、报告或进行项目...

超简单!为图片和 PDF 上去掉水印(pdf图片和水印是一体,怎么去除)

作者:某某白米饭...

取消回复欢迎 发表评论: