孪生网络如何识别面部相似度?有这篇PyTorch实例教程就够啦!
ztj100 2024-12-19 17:55 37 浏览 0 评论
原文来源:hackernoon
作者:Harshvardhan Gupta
「机器人圈」编译:嗯~阿童木呀
这是两篇文章的第二篇,为了更系统地了解与掌握该教程,在阅读本文之前,建议最好通读机器人圈的前一篇文章。
在前一篇文章中,我们讨论了小样本数据旨在解决的主要问题类别,以及孪生网络之所以能够成为解决这个问题优良选择的原因。首先,我们来重温一个特殊的损失函数,它能够在数据对中计算两个的图像相似度。我们现在将在PyTorch中实施我们之前所讨论的全部内容。
你可以在本文末尾查看完整的代码链接
架构
我们将使用的是标准卷积神经网络(CNN)架构,在每个卷积层之后使用批量归一化,然后dropout。
代码片段:孪生网络架构:
class SiameseNetwork(nn.Module):
def __init__(self):
super(SiameseNetwork, self).__init__()
self.cnn1 = nn.Sequential(
nn.ReflectionPad2d(1),
nn.Conv2d(1, 4, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(4),
nn.Dropout2d(p=.2),
nn.ReflectionPad2d(1),
nn.Conv2d(4, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),
nn.Dropout2d(p=.2),
nn.ReflectionPad2d(1),
nn.Conv2d(8, 8, kernel_size=3),
nn.ReLU(inplace=True),
nn.BatchNorm2d(8),
nn.Dropout2d(p=.2),
)
self.fc1 = nn.Sequential(
nn.Linear(8*100*100, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 500),
nn.ReLU(inplace=True),
nn.Linear(500, 5)
)
def forward_once(self, x):
output = self.cnn1(x)
output = output.view(output.size()[0], -1)
output = self.fc1(output)
return output
def forward(self, input1, input2):
output1 = self.forward_once(input1)
output2 = self.forward_once(input2)
return output1, output2
其实这个网络并没有什么特别之处,它可以接收一个100px * 100px的输入,并且在卷积层之后具有3个完全连接的层。
但是其他孪生网络如何呢?
在上篇文章中,我展示了一对网络是如何处理数据对中的每个图像的。但在这篇文章中,只有一个网络。因为两个网络的权重是相同的,所以我们使用一个模型并连续地给它提供两个图像。之后,我们使用两个图像来计算损失值,然后再返回传播。这样可以节省大量的内存,绝对不会影响其他指标(如精确度)。
对比损失
我们将对比损失定义为
等式1.0
我们将Dw(也就欧氏距离)定义为:
等式1.1
Gw是我们网络的一个图像的输出。
PyTorch中的对比损失看起来是这样的:
代码片段:默认边际价值为2的对比损失:
class ContrastiveLoss(torch.nn.Module):
"""
Contrastive loss function.
Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
"""
def __init__(self, margin=2.0):
super(ContrastiveLoss, self).__init__()
self.margin = margin
def forward(self, output1, output2, label):
euclidean_distance = F.pairwise_distance(output1, output2)
loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
return loss_contrastive
数据集
在上一篇文章中,我想使用MNIST,但有些读者建议我使用我在同篇文章中所讨论的面部相似性样本。因此,我决定从MNIST / OmniGlot切换到AT&T面部数据集。
数据集包含40名测试对象的不同角度的图像。我从训练中挑选出3名测试对象的图像,以测试我们的模型。
不同类的样本图像
一名测试对象的所有样本图像
数据加载
我们的架构需要一个输入对,以及标签(类似/不相似)。因此,我创建了自己的自定义数据加载器来完成这项工作。它使用图像文件夹从文件夹中读取图像。这意味着你可以将其用于任何你所希望的数据集中。
代码段:该数据集生成一对图像和相似性标签。如果图像来自同一个类,标签将为0,否则为1:
class SiameseNetworkDataset(Dataset):
def __init__(self,imageFolderDataset,transform=None,should_invert=True):
self.imageFolderDataset = imageFolderDataset
self.transform = transform
self.should_invert = should_invert
def __getitem__(self,index):
img0_tuple = random.choice(self.imageFolderDataset.imgs)
#we need to make sure approx 50% of images are in the same class
should_get_same_class = random.randint(0,1)
if should_get_same_class:
while True:
#keep looping till the same class image is found
img1_tuple = random.choice(self.imageFolderDataset.imgs)
if img0_tuple[1]==img1_tuple[1]:
break
else:
img1_tuple = random.choice(self.imageFolderDataset.imgs)
img0 = Image.open(img0_tuple[0])
img1 = Image.open(img1_tuple[0])
img0 = img0.convert("L")
img1 = img1.convert("L")
if self.should_invert:
img0 = PIL.ImageOps.invert(img0)
img1 = PIL.ImageOps.invert(img1)
if self.transform is not None:
img0 = self.transform(img0)
img1 = self.transform(img1)
return img0, img1 , torch.from_numpy(np.array([int(img1_tuple[1]!=img0_tuple[1])],dtype=np.float32))
def __len__(self):
return len(self.imageFolderDataset.imgs)
孪生网络数据集生成一对图像,以及它们的相似性标签(如果为真,则为0,否则为1)。 为了防止不平衡,我将保证几乎一半的图像是同一个类的,而另一半则不是。
训练孪生网络
孪生网络的训练过程如下:
1.通过网络传递图像对的第一张图像。
2.通过网络传递图像对的第二张图像。
3.使用1和2中的输出来计算损失。
4.返回传播损失计算梯度
5.使用优化器更新权重。我们将用Adam来进行演示:
代码片段:训练孪生网络:
net = SiameseNetwork().cuda()
criterion = ContrastiveLoss()
optimizer = optim.Adam(net.parameters(),lr = 0.0005 )
counter = []
loss_history = []
iteration_number= 0
for epoch in range(0,Config.train_number_epochs):
for i, data in enumerate(train_dataloader,0):
img0, img1 , label = data
img0, img1 , label = Variable(img0).cuda(), Variable(img1).cuda() , Variable(label).cuda()
output1,output2 = net(img0,img1)
optimizer.zero_grad()
loss_contrastive = criterion(output1,output2,label)
loss_contrastive.backward()
optimizer.step()
if i %10 == 0 :
print("Epoch number {}\n Current loss {}\n".format(epoch,loss_contrastive.data[0]))
iteration_number +=10
counter.append(iteration_number)
loss_history.append(loss_contrastive.data[0])
show_plot(counter,loss_history)
网络使用Adam,以0.0005的学习率进行了100次的迭代训练。随着时间的变化,损失曲线如下所示:
随时间变化的损失值曲线,x轴代表迭代次数
测试网络
我们从训练中抽取了3个测试对象,用于评估我们的模型性能。
为了计算相似度,我们只是计算Dw(公式1.1)。其中距离直接对应于图像对之间的不相似度。Dw的高值表示较高的不相似度。
代码片段:通过计算模型输出之间的距离来评估模型:
folder_dataset_test = dset.ImageFolder(root=Config.testing_dir)
siamese_dataset = SiameseNetworkDataset(imageFolderDataset=folder_dataset_test,
transform=transforms.Compose([transforms.Scale((100,100)),
transforms.ToTensor()
])
,should_invert=False)
test_dataloader = DataLoader(siamese_dataset,num_workers=6,batch_size=1,shuffle=True)
dataiter = iter(test_dataloader)
x0,_,_ = next(dataiter)
for i in range(10):
_,x1,label2 = next(dataiter)
concatenated = torch.cat((x0,x1),0)
output1,output2 = net(Variable(x0).cuda(),Variable(x1).cuda())
euclidean_distance = F.pairwise_distance(output1, output2)
imshow(torchvision.utils.make_grid(concatenated),'Dissimilarity: {:.2f}'.format(euclidean_distance.cpu().data.numpy()[0][0]))
模型的一些输出,较低的值表示相似度较高,较高的值表示相似度较低。
果然,实验结果相当不错。网络能够做到从图像中分辨出同一个人,即使这些图像是从不同角度拍摄的。与此同时,该网络在辨别不同的图像方面也做得非常好。
结论
我们讨论并实施了用一个孪生网络从面部图像对中进行面部识别的实验。当某个特定面部的训练样本很少(或只有一个)时,使用这个办法是非常有效的。我们使用一个有辨别性的损失函数来训练神经网络。
相关推荐
- SpringBoot整合SpringSecurity+JWT
-
作者|Sans_https://juejin.im/post/5da82f066fb9a04e2a73daec一.说明SpringSecurity是一个用于Java企业级应用程序的安全框架,主要包含...
- 「计算机毕设」一个精美的JAVA博客系统源码分享
-
前言大家好,我是程序员it分享师,今天给大家带来一个精美的博客系统源码!可以自己买一个便宜的云服务器,当自己的博客网站,记录一下自己学习的心得。开发技术博客系统源码基于SpringBoot,shiro...
- springboot教务管理系统+微信小程序云开发附带源码
-
今天给大家分享的程序是基于springboot的管理,前端是小程序,系统非常的nice,不管是学习还是毕设都非常的靠谱。本系统主要分为pc端后台管理和微信小程序端,pc端有三个角色:管理员、学生、教师...
- SpringBoot+LayUI后台管理系统开发脚手架
-
源码获取方式:关注,转发之后私信回复【源码】即可免费获取到!项目简介本项目本着避免重复造轮子的原则,建立一套快速开发JavaWEB项目(springboot-mini),能满足大部分后台管理系统基础开...
- Spring Boot的Security安全控制——认识SpringSecurity!
-
SpringBoot的Security安全控制在Web项目开发中,安全控制是非常重要的,不同的人配置不同的权限,这样的系统才安全。最常见的权限框架有Shiro和SpringSecurity。Shi...
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
-
前言不得不佩服SpringBoot的生态如此强大,今天给大家推荐几款优秀的后台管理系统,小伙伴们再也不用从头到尾撸一个项目了。SmartAdmin...
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
-
SpringBoot算是目前Java领域最火的技术栈了,除了书呢?当然就是开源项目了,今天整理15个开源领域非常不错的SpringBoot项目供大家学习,参考。高富帅的路上只能帮你到这里了,...
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
-
前言推荐这个项目是因为使用手册部署手册非常...
- 2021年超详细的java学习路线总结—纯干货分享
-
本文整理了java开发的学习路线和相关的学习资源,非常适合零基础入门java的同学,希望大家在学习的时候,能够节省时间。纯干货,良心推荐!第一阶段:Java基础...
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
-
jeecg-boot学习总结及使用心得1.jeecg-boot是一个真正前后端分离的模版项目,便于二次开发,使用的都是较流行的新技术,后端技术主要有spring-boot2.x、shiro、Myb...
- 后勤集团原料管理系统springboot+Layui+MybatisPlus+Shiro源代码
-
本项目为前几天收费帮学妹做的一个项目,JavaEEJSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。一、项目描述后勤集团原料管理系统spr...
- 白卷开源SpringBoot+Vue的前后端分离入门项目
-
简介白卷是一个简单的前后端分离项目,主要采用Vue.js+SpringBoot技术栈开发。除了用作入门练习,作者还希望该项目可以作为一些常见Web项目的脚手架,帮助大家简化搭建网站的流程。...
- Spring Security 自动踢掉前一个登录用户,一个配置搞定
-
登录成功后,自动踢掉前一个登录用户,松哥第一次见到这个功能,就是在扣扣里边见到的,当时觉得挺好玩的。自己做开发后,也遇到过一模一样的需求,正好最近的SpringSecurity系列正在连载,就结...
- 收藏起来!这款开源在线考试系统,我爱了
-
大家好,我是为广大程序员兄弟操碎了心的小编,每天推荐一个小工具/源码,装满你的收藏夹,每天分享一个小技巧,让你轻松节省开发效率,实现不加班不熬夜不掉头发,是我的目标!今天小编推荐一款基于Spr...
- Shiro框架:认证和授权原理(shiro权限认证流程)
-
优质文章,及时送达前言Shiro作为解决权限问题的常用框架,常用于解决认证、授权、加密、会话管理等场景。本文将对Shiro的认证和授权原理进行介绍:Shiro可以做什么?、Shiro是由什么组成的?举...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- SpringBoot整合SpringSecurity+JWT
- 「计算机毕设」一个精美的JAVA博客系统源码分享
- springboot教务管理系统+微信小程序云开发附带源码
- SpringBoot+LayUI后台管理系统开发脚手架
- Spring Boot的Security安全控制——认识SpringSecurity!
- 前同事2024年接私活已入百万,都是用这几个开源的SpringBoot项目
- 值得学习的15 个优秀开源的 Spring Boot 学习项目
- 开发企业官网就用这个基于SpringBoot的CMS系统,真香
- 2021年超详细的java学习路线总结—纯干货分享
- jeecg-boot学习总结及使用心得(jeecgboot简单吗)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)