百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

从零教你写一个完整的GAN(如何写个)

ztj100 2024-10-28 21:15 21 浏览 0 评论

导言

啦啦啦,现今 GAN 算法可以算作 ML 领域下比较热门的一个方向。事实上,GAN 已经作为一种思想来渗透在 ML 的其余领域,从而做出了很多很 Amazing 的东西。比如结合卷积神经网络,可以用于生成图片。或者结合 NLP,可以生成特定风格的短句子。(比如川普风格的 twitter......)

可惜的是,网络上很多老司机开 GAN 的车最后都翻了,大多只是翻译了一篇论文,一旦涉及算法实现部分就直接放开源的实现地址,而那些开源的东东,缺少了必要的引导,实在对于新手来说很是懵逼。所以兔子哥哥带着开好车,开稳车的心态,特定来带一下各位想入门 GAN 的其他小兔兔们来飞一会。

GAN 的介绍与训练

先来阐述一下 GAN 的基本做法,这里不摆公式,因为你听完后,该怎么搭建和怎么训练你心里应该有数了。

首先,GAN 全称为 GEnerative Adversarial Nets(生成对抗网络), 其构成分为两部份:

Generator(生成器),下文简称 G

DIScriminator(辨别器), 下文简称 D。

在本文,为了方便小兔兔理解,使用一个较为简单,也是 GAN 论文提及到的例子,训练 G 生成符合指定均值和标准差的数据,在这里,我们指定MEAN=4,STD=1.5的高斯分布(正态分布)。

这货的样子大概长这样

下面是数据生成的代码:

def sample_data(size, length=100):

""" 随机mean=4 std=1.5的数据 :param size: :param length: :return: """

data =

for _ in range(size):

data.aPPend(sorted(np.random.normal(4, 1.5, length)))

return np.array(data)

在生成高斯分布的数据后,我们还对数据进行了排序,这时因为排序后的训练会相对平滑。具体原因看这个 [Generative Adversarial Nets in TensorFlow (Part I)]

这一段是生成噪音的代码,既然是噪音,那么我们只需要随机产生 0~1 的数据就好。

def random_data(size, length=100):

""" 随机生成数据 :param size: :param length: :return: """

x = np.random.random(length)

data.append(x)

return np.array(data)

随机分布的数据长这样

接下来就是开撸 GAN 了。

首先的一点就是,我们需要确定 G, 和 D 的具体结构,这里因为本文的例子太过于入门级,并不需要使用到复杂的神经网络结构,比如卷积层和递归层,这里 G 和 D 只需全连接的神经网络就好。全连接层的神经网络本质就是矩阵的花式相乘。为神马说是花式相乘呢,因为大多数时候,我们在矩阵相乘的结果后面会添加不同的激活函数。

G 和 D 分别为三层的全链接的神经网络,其中 G 的激活函数分别为,relu,sigmoid,liner,这里前两层只是因为考虑到数据的非线性转换,并没有什么特别选择这两个激活函数的原因。其次 D 的三层分别为 relu,sigmoid,sigmoid。

接下来就引出 GAN 的训练问题。GAN 的思想源于博弈论,直白一点就是套路与反套路。

先从 D 开始分析,D 作为辨别器,它的职责就是区分于真实的高斯分布和 G 生成的” 假” 高斯分布。所以很显然,针对 D 来说,其需要解决的就是传统的二分类问题。

在二分类问题中,我们习惯用交叉熵来衡量分类效果。

从公式中不难看出,在全部分类正确时,交叉熵会接近于 0,因此,我们的目标就是通过拟合 D 的参数来最小化交叉熵的值。

D 既然是传统的二分类问题,那么 D 的训练过程也很容易得出了

即先把真实数据标识为‘1’(真实分布),由生成器生成的数据标识为’0‘(生成分布),反复迭代训练 D ------------ (1)

说 G 的训练之前先来打个比方,假如一男一女在一起了,现在两人性格出现矛盾了,女生并不愿意改变,但两个人都想继续在一起,这时,唯一的方法就是男生改变了。先忽略现实生活的问题,但从举例的角度说,显然久而久之男生就会变得更加 fit 这个女生。

G 的训练也是如此:

先将 G 拼接在 D 的上方,即 G 的输出作为 D 的输入(男生女生在一起),而同时固定 D 的参数(女生不愿意改变),并将进入 G 的噪音样本标签全部改成'1'(目标是两个人继续在一起,没有其他选择),为了最小化损失函数,此时就只能改变 G 的每一层权重,反复迭代后 G 的生成能力因此得以改进。(男生更适合女生) ------------ (2)

反复迭代(1)(2),最终 G 就会得到较好的生成能力。

补充一点,在训练 D 的时候,我曾把数据直接放进去,这样的后果是最后生成的数据,能学习到高斯分布的轮廓,但标准差和均值则和真实样本相差很大。因此,这里我建议直接使用平均值和标准差作为 D 的输入。

这使得 D 在训练前需要对数据进行预处理。

def preprocess_data(x):

""" 计算每一组数据平均值和方差 :param x: :return: """

return [[np.mean(data), np.std(data)] for data in x]

G 和 D 的连接之间也需要做出处理。

# 先求出G_output3的各行平均值和方差

MEAN = tf.reduce_mean(G_output3, 1) # 平均值,但是是1D向量

MEAN_T = tf.transpose(tf.expand_dims(MEAN, 0)) # 转置

STD = tf.sqrt(tf.reduce_mean(tf.square(G_output3 - MEAN_T), 1))

DATA = tf.conCAT(1, [MEAN_T,

tf.transpose(tf.expand_dims(STD, 0))] # 拼接起来

以下是损失函数变化图:

蓝色是 D 单独作二分类问题处理时的变化

绿色是拼接 G 在 D 的上方后损失函数的变化

不难看出,两者在经历反复震荡 (互相博弈而导致),最后稳定于 0.5 附近,这时我们可以认为,G 的生成能力已经达到以假乱真,D 再也不能分出真假。

接下来的这个就是 D-G 博弈 200 次后的结果:

绿色是真实分布

蓝色是噪音原本的分布

红色是生成分布

后话

兔子哥哥的车这次就开到这里了。作为一个大三且数学能力较为一般的学生, 从比较感性的角度来描述了一次 GAN 的基本过程,有说得不对地方请各位见谅和指点。

如果各位读者需要更严格的数学公式和证明,可以阅读 GAN 的开山之作([1406.2661] Generative Adversarial Networks) , 另外本文提及的代码都可在这里找到(MashiMaroLjc/learn-GAN),有需要的童鞋也可以私信交流。

这就是全部内容了,下次心情好的话怼 DCGAN,看看能不能生成点好玩的图片,嗯~ 睡觉去~

雷锋网(公众号:雷锋网)按:本文原作者兔子老大,原文来自他的知乎专栏

雷锋网版权文章,未经授权禁止转载。详情见转载须知。

相关推荐

再说圆的面积-蒙特卡洛(蒙特卡洛方法求圆周率的matlab程序)

在微积分-圆的面积和周长(1)介绍微积分方法求解圆的面积,本文使用蒙特卡洛方法求解圆面积。...

python编程:如何使用python代码绘制出哪些常见的机器学习图像?

专栏推荐...

python创建分类器小结(pytorch分类数据集创建)

简介:分类是指利用数据的特性将其分成若干类型的过程。监督学习分类器就是用带标记的训练数据建立一个模型,然后对未知数据进行分类。...

matplotlib——绘制散点图(matplotlib散点图颜色和图例)

绘制散点图不同条件(维度)之间的内在关联关系观察数据的离散聚合程度...

python实现实时绘制数据(python如何绘制)

方法一importmatplotlib.pyplotaspltimportnumpyasnpimporttimefrommathimport*plt.ion()#...

简单学Python——matplotlib库3——绘制散点图

前面我们学习了用matplotlib绘制折线图,今天我们学习绘制散点图。其实简单的散点图与折线图的语法基本相同,只是作图函数由plot()变成了scatter()。下面就绘制一个散点图:import...

数据分析-相关性分析可视化(相关性分析数据处理)

前面介绍了相关性分析的原理、流程和常用的皮尔逊相关系数和斯皮尔曼相关系数,具体可以参考...

免费Python机器学习课程一:线性回归算法

学习线性回归的概念并从头开始在python中开发完整的线性回归算法最基本的机器学习算法必须是具有单个变量的线性回归算法。如今,可用的高级机器学习算法,库和技术如此之多,以至于线性回归似乎并不重要。但是...

用Python进行机器学习(2)之逻辑回归

前面介绍了线性回归,本次介绍的是逻辑回归。逻辑回归虽然名字里面带有“回归”两个字,但是它是一种分类算法,通常用于解决二分类问题,比如某个邮件是否是广告邮件,比如某个评价是否为正向的评价。逻辑回归也可以...

【Python机器学习系列】拟合和回归傻傻分不清?一文带你彻底搞懂

一、拟合和回归的区别拟合...

推荐2个十分好用的pandas数据探索分析神器

作者:俊欣来源:关于数据分析与可视化...

向量数据库:解锁大模型记忆的关键!选型指南+实战案例全解析

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在...

用Python进行机器学习(11)-主成分分析PCA

我们在机器学习中有时候需要处理很多个参数,但是这些参数有时候彼此之间是有着各种关系的,这个时候我们就会想:是否可以找到一种方式来降低参数的个数呢?这就是今天我们要介绍的主成分分析,英文是Princip...

神经网络基础深度解析:从感知机到反向传播

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在...

Python实现基于机器学习的RFM模型

CDA数据分析师出品作者:CDALevelⅠ持证人岗位:数据分析师行业:大数据...

取消回复欢迎 发表评论: