无需数学背景!谷歌研究员为你解密生成式对抗网络
ztj100 2024-12-12 16:14 11 浏览 0 评论
原文来源:Towards Data Science
作者:Stefan Hosein
「雷克世界」编译:嗯~是阿童木呀、KABUDA
导语:大家都知道,自从生成式对抗网络(GAN)出现以来,便在图像处理方面有着广泛的应用。但还是有很多人对于GAN不是很了解,担心由于没有数学知识底蕴而学不会GAN。在本文中,谷歌研究员Stefan Hosein提供了一份初学者入门GAN的教程,在这份教程中,即使你没有拥有深厚的数学知识,你也能够了解什么是生成式对抗网络(GAN)。并且在阅读完这篇文章之后,你将学会如何编写一个可以创建数字的简单GAN!
类比
理解GAN的一个最为简单的方法是通过一个简单的比喻:
假设有一家商店,店主要从顾客那里购买某些种类的葡萄酒,然后再将这些葡萄酒销售出去。
然而,有些可恶的顾客为了赚取金钱而出售假酒。在这种情况下,店主必须能够区分假酒和正宗的葡萄酒。
你可以想象,在最初的时候,伪造者在试图出售假酒时可能会犯很多错误,并且店主很容易就会发现该酒不是正宗的葡萄酒。经历过这些失败之后,伪造者会继续尝试使用不同的技术来模拟真正的葡萄酒,而有些方法最终会取得成功。现在,伪造者知道某些技术已经能够躲过店主的检查,那么他就可以开始进一步对基于这些技术的假酒进行改善提升。
与此同时,店主可能会从其他店主或葡萄酒专家那里得到一些反馈,说明她所拥有的一些葡萄酒并不是原装的。这意味着店主必须改进她的判别方式,从而确定葡萄酒是伪造的还是正宗的。伪造者的目标是制造出与正宗葡萄酒无法区分的葡萄酒,而店主的目标是准确地分辨葡萄酒是否是正宗的。
可以这样说,这种循环往复的竞争正是GAN背后的主要思想。
生成式对抗网络的组成部分
通过上面的例子,我们可以提出一个GAN的体系结构。
GAN中有两个主要的组成部分:生成器和鉴别器。在上面我们所描述的例子中,店主被称为鉴别器网络,通常是一个卷积神经网络(因为GAN主要用于图像任务),主要是分配图像是真实的概率。
伪造者被称为生成式网络,并且通常也是一个卷积神经网络(具有解卷积层,deconvolution layers)。该网络接收一些噪声向量并输出一个图像。当对生成式网络进行训练时,它会学习可以对图像的哪些区域进行改进/更改,以便鉴别器将难以将其生成的图像与真实图像区分开来。
生成式网络不断地生成与真实图像更为接近的图像,而与此同时,鉴别式网络则试图确定真实图像和假图像之间的差异。最终的目标就是建立一个生成式网络,它可以生成与真实图像无法区分的图像。
用Keras编写一个简单的生成式对抗网络
现在,你已经了解什么是GAN,以及它们的主要组成部分,那么现在我们可以开始试着编写一个非常简单的代码。你可以使用Keras,如果你不熟悉这个Python库的话,则应在继续进行操作之前阅读本教程。本教程基于易于理解的GAN进行开发的。
首先,你需要做的第一件事是通过pip安装以下软件包:
- keras
- matplotlib
- tensorflow
- tqdm
你将使用matplotlib绘图,tensorflow作为Keras后端库和tqdm,以显示每个轮数(迭代)的花式进度条。
下一步是创建一个Python脚本,在这个脚本中,你首先需要导入你将要使用的所有模块和函数。在使用它们时将给出每个解释。
import
os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from keras.layers import Input
from keras.models import Model, Sequential
from keras.layers.core import Dense, Dropout
from keras.layers.advanced_activations import LeakyReLU
from keras.datasets import mnist
from keras.optimizers import Adam
from keras import initializers
你现在需要设置一些变量:
# Let Keras know that we are using tensorflow as our backend engine
os.environ["KERAS_BACKEND"] = "tensorflow"
# To make sure that we can reproduce the experiment and get the same results
np.random.seed(10)
# The dimension of our random noise vector.
random_dim = 100
在开始构建鉴别器和生成器之前,你首先应该收集数据,并对其进行预处理。你将会使用到常见的MNIST数据集,该数据集具有一组从0到9的单个数字图像。
MINST数字样本
def load_minst_data():
# load the data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize our inputs to be in the range[-1, 1]
x_train = (x_train.astype(np.float32) - 127.5)/127.5
# convert x_train with a shape of (60000, 28, 28) to (60000, 784) so we have
# 784 columns per row
x_train = x_train.reshape(60000, 784)
return (x_train, y_train, x_test, y_test)
需要注意的是,mnist.load_data()是Keras的一部分,这使得你可以轻松地将MNIST数据集导入至工作区域中。
现在,你可以开始创建你的生成器和鉴别器网络了。在这一过程中,你会使用到Adam优化器(https://machinelearningmastery.com/adam-optimization-algorithm-for-deep-learning/)。此外,你还需要创建一个带有三个隐藏层的神经网络,其激活函数为Leaky Relu(https://www.quora.com/What-are-the-advantages-of-using-Leaky-Rectified-Linear-Units-Leaky-ReLU-over-normal-ReLU-in-deep-learning)。对于鉴别器而言,你需要为其添加dropout层(dropout layers)(https://medium.com/@amarbudhiraja/https-medium-com-amarbudhiraja-learning-less-to-learn-better-dropout-in-deep-machine-learning-74334da4bfc5),以提高对未知图像的鲁棒性。
def get_optimizer():
return Adam(lr=0.0002, beta_1=0.5)
def get_generator(optimizer):
generator = Sequential()
generator.add(Dense(256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=optimizer)
return generator
def get_discriminator(optimizer):
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
return discriminator
接下来,则需要将发生器和鉴别器组合在一起!
def get_gan_network(discriminator, random_dim, generator, optimizer):
# We initially set trainable to False since we only want to train either the
# generator or discriminator at a time
discriminator.trainable = False
# gan input (noise) will be 100-dimensional vectors
gan_input = Input(shape=(random_dim,))
# the output of the generator (an image)
x = generator(gan_input)
# get the output of the discriminator (probability if the image is real or not)
gan_output = discriminator(x)
gan = Model(inputs=gan_input, outputs=gan_output)
gan.compile(loss='binary_crossentropy', optimizer=optimizer)
return gan
为了完整起见,你还可以创建一个函数,使其每训练20个轮数就对生成的图像进行1次保存。由于这不是本次课程的核心内容,因此你不必完全理解该函数。
def plot_generated_images(epoch, generator, examples=100, dim=(10, 10), figsize=(10, 10)):
noise = np.random.normal(0, 1, size=[examples, random_dim])
generated_images = generator.predict(noise)
generated_images = generated_images.reshape(examples, 28, 28)
plt.figure(figsize=figsize)
for i in range(generated_images.shape[0]):
plt.subplot(dim[0], dim[1], i+1)
plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
plt.axis('off')
plt.tight_layout()
plt.savefig('gan_generated_image_epoch_%d.png' % epoch)
你现在已经编码了大部分网络,剩下的就是训练这个网络,并查看你创建的图像。
def train(epochs=1, batch_size=128):
# Get the training and testing data
x_train, y_train, x_test, y_test = load_minst_data()
# Split the training data into batches of size 128
batch_count = x_train.shape[0] / batch_size
# Build our GAN netowrk
adam = get_optimizer()
generator = get_generator(adam)
discriminator = get_discriminator(adam)
gan = get_gan_network(discriminator, random_dim, generator, adam)
for e in xrange(1, epochs+1):
print '-'*15, 'Epoch %d' % e, '-'*15
for _ in tqdm(xrange(batch_count)):
# Get a random set of input noise and images
noise = np.random.normal(0, 1, size=[batch_size, random_dim])
image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
# Generate fake MNIST images
generated_images = generator.predict(noise)
X = np.concatenate([image_batch, generated_images])
# Labels for generated and real data
y_dis = np.zeros(2*batch_size)
# One-sided label smoothing
y_dis[:batch_size] = 0.9
# Train discriminator
discriminator.trainable = True
discriminator.train_on_batch(X, y_dis)
# Train generator
noise = np.random.normal(0, 1, size=[batch_size, random_dim])
y_gen = np.ones(batch_size)
discriminator.trainable = False
gan.train_on_batch(noise, y_gen)
if e == 1 or e % 20 == 0:
plot_generated_images(e, generator)
if __name__ == '__main__':
train(400, 128)
在训练400个轮数后,你可以查看生成的图像。在查看经过1个轮数训练后而生成的图像时,你会发现它没有任何真实结构,在查看经过40个轮数训练后而生成的图像时,你会发现数字开始成形,最后,在查看经过400个轮数训练后而生成的图像时,你会发现,除了一组数字难以辨识外,其余大多数数字都清晰可见。
训练1个轮数后的结果(上)| 训练40个轮数后的结果(中) | 训练400个轮数后的结果(下)
此代码在CPU上运行一次大约需要2分钟,这也是我们选择该代码的主要原因。你可以尝试进行更多轮数的训练,并向生成器和鉴别器中添加更多数量(种类)的层。当然,在仅使用CPU的前提下,采用更复杂和更深层的体系结构时,相应的代码运行时间也会有所延长。但也不要因此放弃尝试。
至此,你已经完成了全部的学习任务,你以一种直观的方式学习了生成式对抗网络(GAN)的基础知识!并且,你还在Keras库的协助下实现了你的第一个模型。
原文链接:https://towardsdatascience.com/demystifying-generative-adversarial-networks-c076d8db8f44
相关推荐
- 使用Python编写Ping监测程序(python 测验)
-
Ping是一种常用的网络诊断工具,它可以测试两台计算机之间的连通性;如果您需要监测某个IP地址的连通情况,可以使用Python编写一个Ping监测程序;本文将介绍如何使用Python编写Ping监测程...
- 批量ping!有了这个小工具,python再也香不了一点
-
号主:老杨丨11年资深网络工程师,更多网工提升干货,请关注公众号:网络工程师俱乐部下午好,我的网工朋友。在咱们网工的日常工作中,经常需要检测多个IP地址的连通性。不知道你是否也有这样的经历:对着电脑屏...
- python之ping主机(python获取ping结果)
-
#coding=utf-8frompythonpingimportpingforiinrange(100,255):ip='192.168.1.'+...
- 网站安全提速秘籍!Nginx配置HTTPS+反向代理实战指南
-
太好了,你直接问到重点场景了:Nginx+HTTPS+反向代理,这个组合是现代Web架构中最常见的一种部署方式。咱们就从理论原理→实操配置→常见问题排查→高级玩法一层层剖开说,...
- Vue开发中使用iframe(vue 使用iframe)
-
内容:iframe全屏显示...
- Vue3项目实践-第五篇(改造登录页-Axios模拟请求数据)
-
本文将介绍以下内容:项目中的public目录和访问静态资源文件的方法使用json文件代替http模拟请求使用Axios直接访问json文件改造登录页,配合Axios进行登录请求,并...
- Vue基础四——Vue-router配置子路由
-
我们上节课初步了解Vue-router的初步知识,也学会了基本的跳转,那我们这节课学习一下子菜单的路由方式,也叫子路由。子路由的情况一般用在一个页面有他的基础模版,然后它下面的页面都隶属于这个模版,只...
- Vue3.0权限管理实现流程【实践】(vue权限管理系统教程)
-
作者:lxcan转发链接:https://segmentfault.com/a/1190000022431839一、整体思路...
- swiper在vue中正确的使用方法(vue中如何使用swiper)
-
swiper是网页中非常强大的一款轮播插件,说是轮播插件都不恰当,因为它能做的事情太多了,swiper在vue下也是能用的,需要依赖专门的vue-swiper插件,因为vue是没有操作dom的逻辑的,...
- Vue怎么实现权限管理?控制到按钮级别的权限怎么做?
-
在Vue项目中实现权限管理,尤其是控制到按钮级别的权限控制,通常包括以下几个方面:一、权限管理的层级划分...
- 【Vue3】保姆级毫无废话的进阶到实战教程 - 01
-
作为一个React、Vue双修选手,在Vue3逐渐稳定下来之后,是时候摸摸Vue3了。Vue3的变化不可谓不大,所以,本系列主要通过对Vue3中的一些BigChanges做...
- Vue3开发极简入门(13):编程式导航路由
-
前面几节文章,写的都是配置路由。但是在实际项目中,下面这种路由导航的写法才是最常用的:比如登录页面,服务端校验成功后,跳转至系统功能页面;通过浏览器输入URL直接进入系统功能页面后,读取本地存储的To...
- vue路由同页面重定向(vue路由重定向到外部url)
-
在Vue中,可以使用路由的重定向功能来实现同页面的重定向。首先,在路由配置文件(通常是`router/index.js`)中,定义一个新的路由,用于重定向到同一个页面。例如,我们可以定义一个名为`Re...
- 那个 Vue 的路由,路由是干什么用的?
-
在Vue里,路由就像“页面导航的指挥官”,专门负责管理页面(组件)的切换和显示逻辑。简单来说,它能让单页应用(SPA)像多页应用一样实现“不同URL对应不同页面”的效果,但整个过程不会刷新网页。一、路...
- Vue3项目投屏功能开发!(vue投票功能)
-
最近接了个大屏项目,产品想在不同的显示器上展示大屏项目不同的页面,做出来的效果图大概长这样...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)