TensorFlow2学习十二、使用预训练CNN进行迁移学习识别猫和狗
ztj100 2024-12-28 16:50 42 浏览 0 评论
一、说明
本文学习资源来自tensorflow官网,测试环境使用tensor conlab。
1. 本文内容
学习怎么使用预训练cnn进行迁移学习从而把猫、狗分类。
预训练模型是一个使用大量数据训练好并保存好的网络模型,典型的是大量图像数据的分类工作。我们可以使用本文中的预训练模型,也可以针对一个任务使用迁移学习客制化模型。
当一个模型是基于足够大的、足够有代表性的数据集训练出来的,那么它可以有效的工作在机器视觉中。我们可以利用这些学习好的特征,而不用再基于大量数据集重复进行训练。
本文使用两种方式客制化预训练模型:
- 特征展开:使用之前模型从新的数据集中提取有用的特征。只需要在预训练模型顶部简单的添加分类器,不需要重新训练整个模型。基础cnn已经包含了有用的分类特征。
- 微调:从一个冻结模型解除顶部的层,连接新添加的分类层和后面的基础模型。
2. 步骤
- 理解数据
- 导入包和数据,使用Keras ImageDataGenerator处理数据
- 构建模型
- 加载预训练模型(和预置权重)
- 把新的分类层堆叠到顶部
- 训练模型
- 评估模型
二、实现
1. 引入包
from __future__ import absolute_import, division, print_function, unicode_literals import os import numpy as np import matplotlib.pyplot as plt try: # %tensorflow_version only exists in Colab. %tensorflow_version 2.x except Exception: pass import tensorflow as tf keras = tf.keras
2. 数据处理
使用tensorflow_datasets加载数据
import tensorflow_datasets as tfds tfds.disable_progress_bar() # 训练集 验证集 测试集 比例8:1:1 SPLIT_WEIGHTS = (8, 1, 1) splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS) (raw_train, raw_validation, raw_test), metadata = tfds.load( 'cats_vs_dogs', split=list(splits), with_info=True, as_supervised=True) print(raw_train) print(raw_validation) print(raw_test)
显示前2个图片看看
get_label_name = metadata.features['label'].int2str for image, label in raw_train.take(2): plt.figure() plt.imshow(image) plt.title(get_label_name(label))
格式化数据
使用 tf.image 格式化数据。 图片缩放到160*160,输入通道值转换为[-1,1]
IMG_SIZE = 160 # All images will be resized to 160x160 def format_example(image, label): image = tf.cast(image, tf.float32) image = (image/127.5) - 1 image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE)) return image, label # 数据集中每一项都使用该函数处理 train = raw_train.map(format_example) validation = raw_validation.map(format_example) test = raw_test.map(format_example)
数据集分批
BATCH_SIZE = 32 SHUFFLE_BUFFER_SIZE = 1000 train_batches = train.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE) validation_batches = validation.batch(BATCH_SIZE) test_batches = test.batch(BATCH_SIZE) for image_batch, label_batch in train_batches.take(1): pass image_batch.shape
3. 从预训练cnn里创建基础模型
下面加载google创建的 MobileNet 模型,这个模型基于ImageNet,由李飞飞团队创建的大型数据集。这个数据集有1400多万数据和超过2万多个标注,与超过百万的边界框标注。这个模型有助于从我们的数据集里分辨狗和猫。
现在要确定模型哪个层用于特征提取。大多数机器学习模型是从下往上的,最后一个分类器(顶部)不是很有用。
我们将遵循常规做法,转而依赖展平操作之前的最后一层。这个层被称为“瓶颈层”。与最终/顶层相比,瓶颈层特性保留了许多通用性。
首先,实例化一个MobileNet V2模型,该模型预先加载了在ImageNet上训练的权重。
通过指定include_top=False参数,可以加载一个不包括顶部分类层的网络,这是特征提取的理想选择。
IMG_SHAPE = (IMG_SIZE, IMG_SIZE, 3) # Create the base model from the pre-trained model MobileNet V2 base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE, include_top=False, weights='imagenet') feature_batch = base_model(image_batch) print(feature_batch.shape)
特征提取操作将每个160x160x3 转换成 5x5x1280 特征块.
# 冻结层,防止在训练期间更新给定层中的权重 base_model.trainable = False base_model.summary()
4. 全局池化层GAP
GAP被认为是可以替代全连接层的一种新技术。在keras发布的经典模型中,可以看到不少模型甚至抛弃了全连接层,转而使用GAP。
在支持迁移学习方面,各个模型几乎都支持使用Global Average Pooling和Global Max Pooling(GMP)。
这里使用tf.keras.layers.GlobalAveragePooling2D将每个图片以5x5空间位置转成单个1280元素向量。
global_average_layer = tf.keras.layers.GlobalAveragePooling2D() feature_batch_average = global_average_layer(feature_batch) print(feature_batch_average.shape)
# tf.keras.layers.Dense 层将每个特征输入预测器,预测器训练使用logit,所以这里不需要指定激活函数。 # 正向数值预测分类1,负向数值预测分类0 prediction_layer = keras.layers.Dense(1) prediction_batch = prediction_layer(feature_batch_average) print(prediction_batch.shape)
5. 创建模型
model = tf.keras.Sequential([ base_model, global_average_layer, prediction_layer ])
6. 编译模型
base_learning_rate = 0.0001 model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=base_learning_rate), loss='binary_crossentropy', metrics=['accuracy']) model.summary() len(model.trainable_variables) # 2
2.5M MobileNet的变量冻结了,但还有1.2k全连接层还有2.5M训练参数,被分成权重和偏置两个tf变量对象.
7. 训练模型
num_train, num_val, num_test = ( metadata.splits['train'].num_examples*weight/10 for weight in SPLIT_WEIGHTS ) initial_epochs = 10 steps_per_epoch = round(num_train)//BATCH_SIZE validation_steps = 20 loss0,accuracy0 = model.evaluate(validation_batches, steps = validation_steps)
print("initial loss: {:.2f}".format(loss0)) print("initial accuracy: {:.2f}".format(accuracy0))
8. 拟合
history = model.fit(train_batches, epochs=initial_epochs, validation_data=validation_batches)
9. 学习曲线可视化
acc = history.history['accuracy'] val_acc = history.history['val_accuracy'] loss = history.history['loss'] val_loss = history.history['val_loss'] plt.figure(figsize=(8, 8)) plt.subplot(2, 1, 1) plt.plot(acc, label='Training Accuracy') plt.plot(val_acc, label='Validation Accuracy') plt.legend(loc='lower right') plt.ylabel('Accuracy') plt.ylim([min(plt.ylim()),1]) plt.title('Training and Validation Accuracy') plt.subplot(2, 1, 2) plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.legend(loc='upper right') plt.ylabel('Cross Entropy') plt.ylim([0,1.0]) plt.title('Training and Validation Loss') plt.xlabel('epoch') plt.show()
这里验证指标明显优于培训指标,主要原因是像tf.keras.layers.BatchNormalization和tf.keras.layers.Dropout这样的层会影响训练期间的准确性。它们在计算验证丢失时被关闭。
在较小程度上,这也是因为训练度量报告了一个epoch的平均值,而验证度量在epoch之后进行评估,因此验证度量看到的模型训练的时间稍长。
三、微调模型
上面的特征提取操作只在MobileNet V2基础模型上训练了几层。训练过程中未更新训练网络的权值。
进一步提高性能的一种方法是在训练添加的分类器的同时训练(或“微调”)预训练模型顶层的权重。训练过程将强制将权重从通用特征映射调整到与我们的数据集特定关联的特征。
注意:只有在将预先训练的模型设置为不可训练的顶级分类器之后,才能尝试此操作。
如果在预先训练的模型上添加一个随机初始化的分类器并尝试联合训练所有层,则梯度更新的幅度将过大(由于来自分类器的随机权重),并且预先训练的模型将忘记它所学到的内容。
此外应该尝试微调少数顶层,而不是整个MobileNet模型。在大多数卷积网络中,一层越高,它就越专业。前几层学习非常简单和通用的特征,这些特征可以概括为几乎所有类型的图像。当你往上走的时候,这些特性对模型所训练的数据集越来越具体。微调的目标是使这些专门特性适应新数据集,而不是覆盖泛型学习。
base_model.trainable = True # Let's take a look to see how many layers are in the base model print("Number of layers in the base model: ", len(base_model.layers)) # Fine tune from this layer onwards fine_tune_at = 100 # Freeze all the layers before the `fine_tune_at` layer for layer in base_model.layers[:fine_tune_at]: layer.trainable = False model.compile(loss='binary_crossentropy', optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10), metrics=['accuracy']) fine_tune_epochs = 10 total_epochs = initial_epochs + fine_tune_epochs history_fine = model.fit(train_batches, epochs=total_epochs, initial_epoch = history.epoch[-1], validation_data=validation_batches) # 微调后准确率可以达到98%左右 acc += history_fine.history['accuracy'] val_acc += history_fine.history['val_accuracy'] loss += history_fine.history['loss'] val_loss += history_fine.history['val_loss'] plt.figure(figsize=(8, 8)) plt.subplot(2, 1, 1) plt.plot(acc, label='Training Accuracy') plt.plot(val_acc, label='Validation Accuracy') plt.ylim([0.8, 1]) plt.plot([initial_epochs-1,initial_epochs-1], plt.ylim(), label='Start Fine Tuning') plt.legend(loc='lower right') plt.title('Training and Validation Accuracy') plt.subplot(2, 1, 2) plt.plot(loss, label='Training Loss') plt.plot(val_loss, label='Validation Loss') plt.ylim([0, 1.0]) plt.plot([initial_epochs-1,initial_epochs-1], plt.ylim(), label='Start Fine Tuning') plt.legend(loc='upper right') plt.title('Training and Validation Loss') plt.xlabel('epoch') plt.show()
总结
使用预先训练的模型进行特征提取:
- 在处理小数据集时,通常会利用在同一域中的较大数据集上训练的模型所学习的特征。这是通过实例化预先训练的模型并在上面添加一个完全连接的分类器来完成的。训练前的模型被冻结,训练过程中只更新分类器的权值。在这种情况下,卷积基提取了与每个图像相关联的所有特征,只需训练一个分类器,该分类器根据提取的特征集确定图像类别。
微调预先训练的模型
为了进一步提高性能,可能需要通过微调将预先训练的模型的顶层重新调整到新的数据集。在本例中,我们调整了权重,以便模型学习特定于数据集的高级特性。当训练数据集很大且与训练前模型所用的原始数据集非常相似时,通常建议使用此技术。
相关推荐
- 其实TensorFlow真的很水无非就这30篇熬夜练
-
好的!以下是TensorFlow需要掌握的核心内容,用列表形式呈现,简洁清晰(含表情符号,<300字):1.基础概念与环境TensorFlow架构(计算图、会话->EagerE...
- 交叉验证和超参数调整:如何优化你的机器学习模型
-
准确预测Fitbit的睡眠得分在本文的前两部分中,我获取了Fitbit的睡眠数据并对其进行预处理,将这些数据分为训练集、验证集和测试集,除此之外,我还训练了三种不同的机器学习模型并比较了它们的性能。在...
- 机器学习交叉验证全指南:原理、类型与实战技巧
-
机器学习模型常常需要大量数据,但它们如何与实时新数据协同工作也同样关键。交叉验证是一种通过将数据集分成若干部分、在部分数据上训练模型、在其余数据上测试模型的方法,用来检验模型的表现。这有助于发现过拟合...
- 深度学习中的类别激活热图可视化
-
作者:ValentinaAlto编译:ronghuaiyang导读使用Keras实现图像分类中的激活热图的可视化,帮助更有针对性...
- 超强,必会的机器学习评估指标
-
大侠幸会,在下全网同名[算法金]0基础转AI上岸,多个算法赛Top[日更万日,让更多人享受智能乐趣]构建机器学习模型的关键步骤是检查其性能,这是通过使用验证指标来完成的。选择正确的验证指...
- 机器学习入门教程-第六课:监督学习与非监督学习
-
1.回顾与引入上节课我们谈到了机器学习的一些实战技巧,比如如何处理数据、选择模型以及调整参数。今天,我们将更深入地探讨机器学习的两大类:监督学习和非监督学习。2.监督学习监督学习就像是有老师的教学...
- Python 模型部署不用愁!容器化实战,5 分钟搞定环境配置
-
你是不是也遇到过这种糟心事:花了好几天训练出的Python模型,在自己电脑上跑得顺顺当当,一放到服务器就各种报错。要么是Python版本不对,要么是依赖库冲突,折腾半天还是用不了。别再喊“我...
- 神经网络与传统统计方法的简单对比
-
传统的统计方法如...
- 自回归滞后模型进行多变量时间序列预测
-
下图显示了关于不同类型葡萄酒销量的月度多元时间序列。每种葡萄酒类型都是时间序列中的一个变量。假设要预测其中一个变量。比如,sparklingwine。如何建立一个模型来进行预测呢?一种常见的方...
- 苹果AI策略:慢哲学——科技行业的“长期主义”试金石
-
苹果AI策略的深度原创分析,结合技术伦理、商业逻辑与行业博弈,揭示其“慢哲学”背后的战略智慧:一、反常之举:AI狂潮中的“逆行者”当科技巨头深陷AI军备竞赛,苹果的克制显得格格不入:功能延期:App...
- 时间序列预测全攻略,6大模型代码实操
-
如果你对数据分析感兴趣,希望学习更多的方法论,希望听听经验分享,欢迎移步宝藏公众号...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)