深度学习中的类别激活热图可视化
ztj100 2025-08-07 00:06 5 浏览 0 评论
作者:Valentina Alto
编译:ronghuaiyang
导读
使用Keras实现图像分类中的激活热图的可视化,帮助更有针对性的改进模型。
类别激活图(CAM)是一种用于计算机视觉分类任务的强大技术。它允许研究人员检查被分类的图像,并了解图像的哪些部分/像素对模型的最终输出有更大的贡献。
基本上,假设我们构建一个CNN,目标是将人的照片分类为“男人”和“女人”,然后我们给它提供一个新照片,它返回标签“男人”。有了CAM工具,我们就能看到图片的哪一部分最能激活“Man”类。如果我们想提高模型的准确性,必须了解需要修改哪些层,或者我们是否想用不同的方式预处理训练集图像,这将非常有用。
在本文中,我将向你展示这个过程背后的思想。为了达到这个目的,我会使用一个在ImageNet上预训练好的CNN, Resnet50。
我在这个实验中要用到的图像是,这只金毛猎犬:
首先,让我们在这张图上尝试一下我们预训练模型,让它返回三个最有可能的类别:
from keras.applications.resnet50 import ResNet50
from keras.preprocessing import image
from keras.applications.resnet50 import preprocess_input, decode_predictions
import numpy as npmodel = ResNet50(weights='imagenet')img_path = 'golden.jpg'
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)preds = model.predict(x)
# decode the results into a list of tuples (class, description, probability)
print('Predicted:', decode_predictions(preds, top=3)[0])
如你所见,第一个结果恰好返回了我们正在寻找的类别:Golden retriver。
现在我们的目标是识别出我们的照片中最能激活黄金标签的部分。为此,我们将使用一种称为“梯度加权类别激活映射(Grad-CAM)”的技术(官方论文:https://arxiv.org/abs/1610.02391)。
这个想法是这样的:想象我们有一个训练好的CNN,我们给它提供一个新的图像。它将为该图像返回一个类。然后,如果我们取最后一个卷积层的输出特征图,并根据输出类别对每个通道的梯度对每个通道加权,我们就得到了一个热图,它表明了输入图像中哪些部分对该类别激活程度最大。
让我们看看使用Keras的实现。首先,让我们检查一下我们预先训练过的ResNet50的结构,以确定我们想要检查哪个层。由于网络结构很长,我将在这里只显示最后的block:
from keras.utils import plot_model
plot_model(model)
让我们使用最后一个激活层activation_49来提取我们的feature map。
golden = model.output[:, np.argmax(preds[0])]
last_conv_layer = model.get_layer('activation_49')
from keras import backend as K
grads = K.gradients(golden, last_conv_layer.output)[0]
pooled_grads = K.mean(grads, axis=(0, 1, 2))
iterate = K.function([model.input], [pooled_grads, last_conv_layer.output[0]])
pooled_grads_value, conv_layer_output_value = iterate([x])
for i in range(pooled_grads.shape[0]):
conv_layer_output_value[:, :, i] *= pooled_grads_value[i]
heatmap = np.mean(conv_layer_output_value, axis=-1)
import matplotlib.pyplot as plt
heatmap = np.maximum(heatmap, 0)
heatmap /= np.max(heatmap)
plt.matshow(heatmap)
这个热图上看不出什么东西出来。因此,我们将该热图与输入图像合并如下:
import cv2
img = cv2.imread(img_path)
heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
merged= heatmap * 0.4 + imgplt.imshow(merged)
如你所见,图像的某些部分(如鼻子部分)特别的指示出了输入图像的类别。
英文原文:https://valentinaalto.medium.com/class-activation-maps-in-deep-learning-14101e2ec7e1
更多内容,请关注微信公众号“AI公园”。
- 上一篇:超强,必会的机器学习评估指标
- 下一篇:机器学习交叉验证全指南:原理、类型与实战技巧
相关推荐
- 其实TensorFlow真的很水无非就这30篇熬夜练
-
好的!以下是TensorFlow需要掌握的核心内容,用列表形式呈现,简洁清晰(含表情符号,<300字):1.基础概念与环境TensorFlow架构(计算图、会话->EagerE...
- 交叉验证和超参数调整:如何优化你的机器学习模型
-
准确预测Fitbit的睡眠得分在本文的前两部分中,我获取了Fitbit的睡眠数据并对其进行预处理,将这些数据分为训练集、验证集和测试集,除此之外,我还训练了三种不同的机器学习模型并比较了它们的性能。在...
- 机器学习交叉验证全指南:原理、类型与实战技巧
-
机器学习模型常常需要大量数据,但它们如何与实时新数据协同工作也同样关键。交叉验证是一种通过将数据集分成若干部分、在部分数据上训练模型、在其余数据上测试模型的方法,用来检验模型的表现。这有助于发现过拟合...
- 深度学习中的类别激活热图可视化
-
作者:ValentinaAlto编译:ronghuaiyang导读使用Keras实现图像分类中的激活热图的可视化,帮助更有针对性...
- 超强,必会的机器学习评估指标
-
大侠幸会,在下全网同名[算法金]0基础转AI上岸,多个算法赛Top[日更万日,让更多人享受智能乐趣]构建机器学习模型的关键步骤是检查其性能,这是通过使用验证指标来完成的。选择正确的验证指...
- 机器学习入门教程-第六课:监督学习与非监督学习
-
1.回顾与引入上节课我们谈到了机器学习的一些实战技巧,比如如何处理数据、选择模型以及调整参数。今天,我们将更深入地探讨机器学习的两大类:监督学习和非监督学习。2.监督学习监督学习就像是有老师的教学...
- Python 模型部署不用愁!容器化实战,5 分钟搞定环境配置
-
你是不是也遇到过这种糟心事:花了好几天训练出的Python模型,在自己电脑上跑得顺顺当当,一放到服务器就各种报错。要么是Python版本不对,要么是依赖库冲突,折腾半天还是用不了。别再喊“我...
- 神经网络与传统统计方法的简单对比
-
传统的统计方法如...
- 自回归滞后模型进行多变量时间序列预测
-
下图显示了关于不同类型葡萄酒销量的月度多元时间序列。每种葡萄酒类型都是时间序列中的一个变量。假设要预测其中一个变量。比如,sparklingwine。如何建立一个模型来进行预测呢?一种常见的方...
- 苹果AI策略:慢哲学——科技行业的“长期主义”试金石
-
苹果AI策略的深度原创分析,结合技术伦理、商业逻辑与行业博弈,揭示其“慢哲学”背后的战略智慧:一、反常之举:AI狂潮中的“逆行者”当科技巨头深陷AI军备竞赛,苹果的克制显得格格不入:功能延期:App...
- 时间序列预测全攻略,6大模型代码实操
-
如果你对数据分析感兴趣,希望学习更多的方法论,希望听听经验分享,欢迎移步宝藏公众号...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)