百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

使用TensorFlow创建能够图像重建的自编码器模型

ztj100 2024-11-10 13:13 13 浏览 0 评论



想象你正在解决一个拼图游戏。你已经完成了大部分。假设您需要在一幅几乎完成的图片中间修复一块。你需要从盒子里选择一块,它既适合空间,又能完成整个画面。

我相信你很快就能做到。但是你的大脑是怎么做到的呢?

首先,它会分析空槽周围的图片(在这里你需要固定拼图的一块)。如果图片中有一棵树,你会寻找绿色的部分(这是显而易见的!)所以,简而言之,我们的大脑能够通过知道图像周围的环境来预测图像(它将适合放入槽中)。

在本教程中,我们的模型将执行类似的任务。它将学习图像的上下文,然后利用学习到的上下文预测图像的一部分(缺失的部分)。

在这篇文章之前,我们先看一下代码实现

我建议您在另一个选项卡中打开这个笔记本(TF实现),这样您就可以直观地了解发生了什么。

colab/drive/1zFe9TmMCK2ldUOsVXenvpbNY2FLrLh5k#scrollTo=UXjElGK

问题

我们希望我们的模型能预测图像的一部分。给定一个有部分缺失图像(只有0的图像阵列的一部分),我们的模型将预测原始图像是完整的。

因此,我们的模型将利用它在训练中学习到的上下文重建图像中缺失的部分。


数据

我们将为任务选择一个域。我们选择了一些山地图像,它们是Puneet Bansal在Kaggle上的 Intel Image Classification数据集的一部分。

为什么只有山脉的图像?

在这里,我们选择属于某个特定域的图像。如果我们选择的数据集中有更广泛图像,我们的模型将不能很好地执行。因此,我们将其限制在一个域内。

使用wget下载我在GitHub上托管的数据

!wget https://github.com/shubham0204/Dataset_Archives/blob/master/mountain_images.zip?raw=true -O images.zip 
!unzip images.zip

为了生成训练数据,我们将遍历数据集中的每个图像,并对其执行以下任务,


首先,我们将使用PIL.Image.open()读取图像文件。使用np.asarray()将这个图像对象转换为一个NumPy数组。

确定窗口大小。这是正方形的边长这是从原始图像中得到的。

在[ 0 , image_dim — window_size ]范围内生成2个随机数。image_dim是我们的方形输入图像的大小。

这两个数字(称为px和py)是从原始图像剪裁的位置。选择图像数组的一部分,并将其替换为零数组。

代码如下

x = []
y = []
input_size = ( 228 , 228 , 3 )

# Take out a square region of side 50 px.
window_size = 50

# Store the original images as target images.
for name in os.listdir( 'mountain_images/' ):
    image = Image.open( 'mountain_images/{}'.format( name ) ).resize( input_size[0:2] )
    image = np.asarray( image ).astype( np.uint8 )
    y.append( image )

for name in os.listdir( 'mountain_images/' ):
    image = Image.open( 'mountain_images/{}'.format( name ) ).resize( input_size[0:2] )
    image = np.asarray( image ).astype( np.uint8 )
    # Generate random X and Y coordinates within the image bounds.
    px , py = random.randint( 0 , input_size[0] - window_size ) , random.randint( 0 , input_size[0] - window_size )
    # Take that part of the image and replace it with a zero array. This makes the "missing" part of the image.
    image[ px : px + window_size , py : py + window_size , 0:3 ] = np.zeros( ( window_size , window_size , 3 ) )
    # Append it to an array
    x.append( image )
    
#  Normalize the images
x = np.array( x ) / 255
y = np.array( y ) / 255

# Train test split
x_train, x_test, y_train, y_test = train_test_split( x , y , test_size=0.2 )

自动编码器模型与跳连接

我们添加跳转连接到我们的自动编码器模型。这些跳过连接提供了更好的上采样。通过使用最大池层,许多空间信息会在编码过程中丢失。为了从它的潜在表示(由编码器产生)重建图像,我们添加了跳过连接,它将信息从编码器带到解码器。

alpha = 0.2

inputs = Input( shape=input_size )
conv1 = Conv2D( 32 , kernel_size=( 3 , 3 ) , strides=1 )( inputs )
relu1 = LeakyReLU( alpha )( conv1 )
conv2 = Conv2D( 32 , kernel_size=( 3 , 3 ) , strides=1 )( relu1 )
relu2 = LeakyReLU( alpha )( conv2 )
maxpool1 = MaxPooling2D()( relu2 )

conv3 = Conv2D( 64 , kernel_size=( 3 , 3 ) , strides=1 )( maxpool1 )
relu3 = LeakyReLU( alpha )( conv3 )
conv4 = Conv2D( 64 , kernel_size=( 3 , 3 ) , strides=1 )( relu3 )
relu4 = LeakyReLU( alpha )( conv4 )
maxpool2 = MaxPooling2D()( relu4 )

conv5 = Conv2D( 128 , kernel_size=( 3 , 3 ) , strides=1 )( maxpool2 )
relu5 = LeakyReLU( alpha )( conv5 )
conv6 = Conv2D( 128 , kernel_size=( 3 , 3 ) , strides=1 )( relu5 )
relu6 = LeakyReLU( alpha )( conv6 )
maxpool3 = MaxPooling2D()( relu6 )

conv7 = Conv2D( 256 , kernel_size=( 1 , 1 ) , strides=1 )( maxpool3 )
relu7 = LeakyReLU( alpha )( conv7 )
conv8 = Conv2D( 256 , kernel_size=( 1 , 1 ) , strides=1 )( relu7 )
relu8 = LeakyReLU( alpha )( conv8 )

upsample1 = UpSampling2D()( relu8 )
concat1 = Concatenate()([ upsample1 , conv6 ])
convtranspose1 = Conv2DTranspose( 128 , kernel_size=( 3 , 3 ) , strides=1)( concat1 )
relu9 = LeakyReLU( alpha )( convtranspose1 )
convtranspose2 = Conv2DTranspose( 128 , kernel_size=( 3 , 3 ) , strides=1  )( relu9 )
relu10 = LeakyReLU( alpha )( convtranspose2 )

upsample2 = UpSampling2D()( relu10 )
concat2 = Concatenate()([ upsample2 , conv4 ])
convtranspose3 = Conv2DTranspose( 64 , kernel_size=( 3 , 3 ) , strides=1)( concat2 )
relu11 = LeakyReLU( alpha )( convtranspose3 )
convtranspose4 = Conv2DTranspose( 64 , kernel_size=( 3 , 3 ) , strides=1 )( relu11 )
relu12 = LeakyReLU( alpha )( convtranspose4 )

upsample3 = UpSampling2D()( relu12 )
concat3 = Concatenate()([ upsample3 , conv2 ])
convtranspose5 = Conv2DTranspose( 32 , kernel_size=( 3 , 3 ) , strides=1)( concat3 )
relu13 = LeakyReLU( alpha )( convtranspose5 )
convtranspose6 = Conv2DTranspose( 3 , kernel_size=( 3 , 3 ) , strides=1 , activation='relu' )( relu13 )

model = tf.keras.models.Model( inputs , convtranspose6 )
model.compile( loss='mse' , optimizer='adam' , metrics=[ 'mse' ] )

最后,训练我们的自动编码器模型,

model.fit( x_train , y_train , epochs=150 , batch_size=25 , validation_data=( x_test , y_test ) )


结论

以上结果是在少数测试图像上得到的。我们观察到模型几乎已经学会了如何填充黑盒!但我们仍然可以分辨出盒子在原始图像中的位置。这样,我们就可以建立一个模型来预测图像缺失的部分。

这里我们只是用了一个简单的模型来作为样例,如果我们要推广到现实生活中,就需要使用更大的数据集和更深的网络,例如可以使用现有的sota模型,加上imagenet的图片进行训练。


作者 Shubham Panchal

deephub 翻译组

相关推荐

Vue 技术栈(全家桶)(vue technology)

Vue技术栈(全家桶)尚硅谷前端研究院第1章:Vue核心Vue简介官网英文官网:https://vuejs.org/中文官网:https://cn.vuejs.org/...

vue 基础- nextTick 的使用场景(vue的nexttick这个方法有什么用)

前言《vue基础》系列是再次回炉vue记的笔记,除了官网那部分知识点外,还会加入自己的一些理解。(里面会有部分和官网相同的文案,有经验的同学择感兴趣的阅读)在开发时,是不是遇到过这样的场景,响应...

vue3 组件初始化流程(vue组件初始化顺序)

学习完成响应式系统后,咋们来看看vue3组件的初始化流程既然是看vue组件的初始化流程,咋们先来创建基本的代码,跑跑流程(在app.vue中写入以下内容,来跑流程)...

vue3优雅的设置element-plus的table自动滚动到底部

场景我是需要在table最后添加一行数据,然后把滚动条滚动到最后。查网上的解决方案都是读取html结构,暴力的去获取,虽能解决问题,但是不喜欢这种打补丁的解决方案,我想着官方应该有相关的定义,于是就去...

Vue3为什么推荐使用ref而不是reactive

为什么推荐使用ref而不是reactivereactive本身具有很大局限性导致使用过程需要额外注意,如果忽视这些问题将对开发造成不小的麻烦;ref更像是vue2时代optionapi的data的替...

9、echarts 在 vue 中怎么引用?(必会)

首先我们初始化一个vue项目,执行vueinitwebpackechart,接着我们进入初始化的项目下。安装echarts,npminstallecharts-S//或...

无所不能,将 Vue 渲染到嵌入式液晶屏

该文章转载自公众号@前端时刻,https://mp.weixin.qq.com/s/WDHW36zhfNFVFVv4jO2vrA前言...

vue-element-admin 增删改查(五)(vue-element-admin怎么用)

此篇幅比较长,涉及到的小知识点也比较多,一定要耐心看完,记住学东西没有耐心可不行!!!一、添加和修改注:添加和编辑用到了同一个组件,也就是此篇文章你能学会如何封装组件及引用组件;第二能学会async和...

最全的 Vue 面试题+详解答案(vue面试题知识点大全)

前言本文整理了...

基于 vue3.0 桌面端朋友圈/登录验证+60s倒计时

今天给大家分享的是Vue3聊天实例中的朋友圈的实现及登录验证和倒计时操作。先上效果图这个是最新开发的vue3.x网页端聊天项目中的朋友圈模块。用到了ElementPlus...

不来看看这些 VUE 的生命周期钩子函数?| 原力计划

作者|huangfuyk责编|王晓曼出品|CSDN博客VUE的生命周期钩子函数:就是指在一个组件从创建到销毁的过程自动执行的函数,包含组件的变化。可以分为:创建、挂载、更新、销毁四个模块...

Vue3.5正式上线,父传子props用法更丝滑简洁

前言Vue3.5在2024-09-03正式上线,目前在Vue官网显最新版本已经是Vue3.5,其中主要包含了几个小改动,我留意到日常最常用的改动就是props了,肯定是用Vue3的人必用的,所以针对性...

Vue 3 生命周期完整指南(vue生命周期及使用)

Vue2和Vue3中的生命周期钩子的工作方式非常相似,我们仍然可以访问相同的钩子,也希望将它们能用于相同的场景。...

救命!这 10 个 Vue3 技巧藏太深了!性能翻倍 + 摸鱼神器全揭秘

前端打工人集合!是不是经常遇到这些崩溃瞬间:Vue3项目越写越卡,组件通信像走迷宫,复杂逻辑写得脑壳疼?别慌!作为在一线摸爬滚打多年的老前端,今天直接甩出10个超实用的Vue3实战技巧,手把...

怎么在 vue 中使用 form 清除校验状态?

在Vue中使用表单验证时,经常需要清除表单的校验状态。下面我将介绍一些方法来清除表单的校验状态。1.使用this.$refs...

取消回复欢迎 发表评论: