百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

使用Tensorflow开发一维生成对抗网络

ztj100 2024-11-21 00:30 19 浏览 0 评论

生成式对抗网络是一种用于训练生成器模型的深度学习体系结构。GAN由两个模型组成,一个称为生成器(Generator),另一个称为判别器(Discriminator)。顾名思义,生成器生成新样本,判别器负责对生成的样本进行真伪分类。

GAN实际如何运作的?

判别器模型的性能用于更新生成器和判别器本身的网络权重。生成器实际上从未看到过数据,而是根据判别器的性能不断地进行调整,更具体地说,是根据从判别器传回来的误差梯度进行调整。生成器逐渐学会通过产生与真实样本完全相同的样本来欺骗判别器。

在这篇文章中,我们将选择一个简单的一维函数来直观地理解GAN。本文分为5个部分:

  1. 选择一个一维函数
  2. 实现判别器模型
  3. 实现生成器模型
  4. 训练GAN模型
  5. 性能评估

1.一维函数

我们需要选择一个一维函数来制作模型。一维函数的形式为

y = f(x),其中x是输入,y是对应的输出

为简单起见,我将使用函数y = x2。您可以自由选择任何函数。我们将保持输入在-0.5和+0.5之间。下面给出了一个计算输入的简单函数 :

# Importing Required Packages
import numpy as np
from numpy.random import rand,randn
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

# n : number of samples
# type : real / fake
def get_real_samples(n):
    X1 = rand(n)-0.5
    X2=X1*X1
    X1=X1.reshape(n,1)
    X2=X2.reshape(n,1)
    X=np.hstack((X1,X2))
    y=np.ones((n,1))
    return X,y

该函数简单地接受N个随机值,并将每个值减去0.5,以便将输入范围保持在-0.5和+0.5之间。当为real时y=1,当为fake时y=0。。

2.判别器模型

判别器只是一个简单的分类模型,它可以预测样本是real还是fake。判别器将两个实数值的样本作为输入,并输出样本是real还是fake。我们处理的问题非常简单,所以我们不需要非常复杂的神经网络,我们将只采用一个隐藏层,其中有25个节点。您可以自由地试验节点数或层数,以提高生成器的准确性。我们将对隐藏层使用ReLu激活,对输出层使用sigmoid激活。Python实现如下:

# Make discriminator model (classifies real/fake inputs)
def disc_model(input_dim=2):
    model=Sequential()
    #hidden layer
    model.add(Dense(25,activation='relu',input_dim=input_dim))
    #outputlayer
    model.add(Dense(1,activation='sigmoid'))
    #compile model
    model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
    return model

3.生成器模型

对于生成器,我们将噪声输入提供给生成器,此噪声输入也称为潜在变量。

潜在变量是潜在空间中的隐藏变量或未观察到的变量,潜在空间是这些变量的多维空间。

直到我们的生成器受到训练并赋予这些点意义,该潜在空间才有意义,这些点被映射到判别器的输入。我们将定义一个3维的潜在空间(可以更改维数),并实验生成器的行为和准确度如何变化。我们将对潜在空间中的每个变量使用高斯分布。生成器使用一个隐藏层,该隐藏层将由15个具有ReLu激活函数的神经元组成。输出层将由两个神经元组成,这两个神经元将连接到判别器层的输入。Python实现如下:

# This function generates fake points with help of generator
def get_fake_samples(gen,noise_dim,n):
    #generating random points from latent space
    x_input = noise_points(noise_dim,n)
    # X is the output from the generator
    X = gen.predict(x_input)
    # Since these are fake samples so y=0
    y=np.zeros((n,1))
    return X,y

# generate noise points from latent space (latent variables)
def noise_points(noise_dim,n):
    noise = randn(n*noise_dim)
    noise=noise.reshape(n,noise_dim)
    return noise

# Make generator model (generates fake inputs)
def gen_model(input_dim,output_dim=2):
    model=Sequential()
    model.add(Dense(15,activation='relu',input_dim=input_dim))
    model.add(Dense(output_dim,activation='linear'))
    return model

4.训练GAN模型

训练GAN模型的方法有很多,最简单的方法是创建一个新的模型,该模型由生成器和判别器两部分组成。我们只是在逻辑上封装了生成器和判别器网络。我们将把GAN模型作为一个整体进行训练,这样来自判别器的反向传播误差也会更新生成器的权重。如果判别器能够很好的进行分类,那么生成器的权重将更新地更多;如果判别器不能很好的进行分类,那么生成器的权重将更新得少一些。这样,在生成器和判别器之间就形成了一种对抗关系。Python代码如下:

# Logical model for connecting generator and discriminator
def gan_model(disc,gen):
    disc.trainable=False
    model=Sequential()
    model.add(gen)
    model.add(disc)
    model.compile(loss='binary_crossentropy',optimizer='adam')
    return model

判别器模型的可训练属性被设置为false,这样就可以仅对standalone模型进行训练。

现在我们只剩下对GAN模型进行整体训练了。我们将编写一个函数来做这个的事情。该函数将运行10000个epochs,每运行2000个epochs,它将评估判别器和生成器的性能。Python代码如下:

def train(g_model,d_model,gan_model,noise_dim,epochs=10000,batch_size=256,n_eval=2000):
    half_batch=batch_size//2
    for i in range(epochs):
        x_real,y_real = get_real_samples(half_batch)
        x_fake,y_fake = get_fake_samples(g_model,noise_dim,half_batch)
        d_model.train_on_batch(x_real,y_real)
        d_model.train_on_batch(x_fake,y_fake)
        x_gan = noise_points(noise_dim,batch_size)
        y = np.ones((batch_size,1))
        gan_model.train_on_batch(x_gan,y)
        if (i+1) % n_eval ==0:
            show_performance(i+1,g_model,d_model,noise_dim)

5.评估性能

在每隔一定的epochs之后,我们将调用show_performance函数,该函数将从生成器中获取真实样本和虚假样本并预测结果。我们还将在散点图上绘制结果,以便我们可以查看GAN的性能。Python实现如下:

# Model evaluation function
def show_performance(epoch,g_model,d_model,noise_dim,n=100):
    x_real,y_real = get_real_samples(n)
    _,real_acc = d_model.evaluate(x_real,y_real,verbose=0)
    x_fake,y_fake = get_fake_samples(g_model,noise_dim,n)
    _,fake_acc = d_model.evaluate(x_fake,y_fake,verbose=0)
    print(epoch,real_acc,fake_acc)
    plt.figure(figsize=(20,10))
    plt.scatter(x_real[:,0],x_real[:,1],color='red')
    plt.scatter(x_fake[:,0],x_fake[:,1],color='blue')

noise_dim=5
gen = gen_model(noise_dim)
disc = disc_model()
gan = gan_model(disc,gen)
train(gen,disc,gan,noise_dim)

在epoch = 2000之后,我们得到了散点图如下,您的图可能会有所不同。

红点表示real点,蓝点表示生成器生成的点。我们可以看到,蓝点已开始呈y =x2的形状。

如果我们继续进行10000个epochs,您将得到类似下面的图像。您可以尝试使用更多个epochs(例如15000或20000个epochs)来获得更好的准确性。

现在我们可以看到,我们已经从生成器中得到了一个更确定的样本,我们可以说生成器已经学习并拟合了这个函数。也就是说,仅仅通过误差梯度,生成器就学会了这个函数。

相关推荐

sharding-jdbc实现`分库分表`与`读写分离`

一、前言本文将基于以下环境整合...

三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么

在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...

MySQL8行级锁_mysql如何加行级锁

MySQL8行级锁版本:8.0.34基本概念...

mysql使用小技巧_mysql使用入门

1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...

MySQL/MariaDB中如何支持全部的Unicode?

永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...

聊聊 MySQL Server 可执行注释,你懂了吗?

前言MySQLServer当前支持如下3种注释风格:...

MySQL系列-源码编译安装(v5.7.34)

一、系统环境要求...

MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了

对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...

MySQL字符问题_mysql中字符串的位置

中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...

深圳尚学堂:mysql基本sql语句大全(三)

数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...

MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?

大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...

一文讲清怎么利用Python Django实现Excel数据表的导入导出功能

摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...

用DataX实现两个MySQL实例间的数据同步

DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...

MySQL数据库知识_mysql数据库基础知识

MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...

如何为MySQL中的JSON字段设置索引

背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...

取消回复欢迎 发表评论: