百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

使用Tensorflow开发一维生成对抗网络

ztj100 2024-11-21 00:30 12 浏览 0 评论

生成式对抗网络是一种用于训练生成器模型的深度学习体系结构。GAN由两个模型组成,一个称为生成器(Generator),另一个称为判别器(Discriminator)。顾名思义,生成器生成新样本,判别器负责对生成的样本进行真伪分类。

GAN实际如何运作的?

判别器模型的性能用于更新生成器和判别器本身的网络权重。生成器实际上从未看到过数据,而是根据判别器的性能不断地进行调整,更具体地说,是根据从判别器传回来的误差梯度进行调整。生成器逐渐学会通过产生与真实样本完全相同的样本来欺骗判别器。

在这篇文章中,我们将选择一个简单的一维函数来直观地理解GAN。本文分为5个部分:

  1. 选择一个一维函数
  2. 实现判别器模型
  3. 实现生成器模型
  4. 训练GAN模型
  5. 性能评估

1.一维函数

我们需要选择一个一维函数来制作模型。一维函数的形式为

y = f(x),其中x是输入,y是对应的输出

为简单起见,我将使用函数y = x2。您可以自由选择任何函数。我们将保持输入在-0.5和+0.5之间。下面给出了一个计算输入的简单函数 :

# Importing Required Packages
import numpy as np
from numpy.random import rand,randn
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

# n : number of samples
# type : real / fake
def get_real_samples(n):
    X1 = rand(n)-0.5
    X2=X1*X1
    X1=X1.reshape(n,1)
    X2=X2.reshape(n,1)
    X=np.hstack((X1,X2))
    y=np.ones((n,1))
    return X,y

该函数简单地接受N个随机值,并将每个值减去0.5,以便将输入范围保持在-0.5和+0.5之间。当为real时y=1,当为fake时y=0。。

2.判别器模型

判别器只是一个简单的分类模型,它可以预测样本是real还是fake。判别器将两个实数值的样本作为输入,并输出样本是real还是fake。我们处理的问题非常简单,所以我们不需要非常复杂的神经网络,我们将只采用一个隐藏层,其中有25个节点。您可以自由地试验节点数或层数,以提高生成器的准确性。我们将对隐藏层使用ReLu激活,对输出层使用sigmoid激活。Python实现如下:

# Make discriminator model (classifies real/fake inputs)
def disc_model(input_dim=2):
    model=Sequential()
    #hidden layer
    model.add(Dense(25,activation='relu',input_dim=input_dim))
    #outputlayer
    model.add(Dense(1,activation='sigmoid'))
    #compile model
    model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
    return model

3.生成器模型

对于生成器,我们将噪声输入提供给生成器,此噪声输入也称为潜在变量。

潜在变量是潜在空间中的隐藏变量或未观察到的变量,潜在空间是这些变量的多维空间。

直到我们的生成器受到训练并赋予这些点意义,该潜在空间才有意义,这些点被映射到判别器的输入。我们将定义一个3维的潜在空间(可以更改维数),并实验生成器的行为和准确度如何变化。我们将对潜在空间中的每个变量使用高斯分布。生成器使用一个隐藏层,该隐藏层将由15个具有ReLu激活函数的神经元组成。输出层将由两个神经元组成,这两个神经元将连接到判别器层的输入。Python实现如下:

# This function generates fake points with help of generator
def get_fake_samples(gen,noise_dim,n):
    #generating random points from latent space
    x_input = noise_points(noise_dim,n)
    # X is the output from the generator
    X = gen.predict(x_input)
    # Since these are fake samples so y=0
    y=np.zeros((n,1))
    return X,y

# generate noise points from latent space (latent variables)
def noise_points(noise_dim,n):
    noise = randn(n*noise_dim)
    noise=noise.reshape(n,noise_dim)
    return noise

# Make generator model (generates fake inputs)
def gen_model(input_dim,output_dim=2):
    model=Sequential()
    model.add(Dense(15,activation='relu',input_dim=input_dim))
    model.add(Dense(output_dim,activation='linear'))
    return model

4.训练GAN模型

训练GAN模型的方法有很多,最简单的方法是创建一个新的模型,该模型由生成器和判别器两部分组成。我们只是在逻辑上封装了生成器和判别器网络。我们将把GAN模型作为一个整体进行训练,这样来自判别器的反向传播误差也会更新生成器的权重。如果判别器能够很好的进行分类,那么生成器的权重将更新地更多;如果判别器不能很好的进行分类,那么生成器的权重将更新得少一些。这样,在生成器和判别器之间就形成了一种对抗关系。Python代码如下:

# Logical model for connecting generator and discriminator
def gan_model(disc,gen):
    disc.trainable=False
    model=Sequential()
    model.add(gen)
    model.add(disc)
    model.compile(loss='binary_crossentropy',optimizer='adam')
    return model

判别器模型的可训练属性被设置为false,这样就可以仅对standalone模型进行训练。

现在我们只剩下对GAN模型进行整体训练了。我们将编写一个函数来做这个的事情。该函数将运行10000个epochs,每运行2000个epochs,它将评估判别器和生成器的性能。Python代码如下:

def train(g_model,d_model,gan_model,noise_dim,epochs=10000,batch_size=256,n_eval=2000):
    half_batch=batch_size//2
    for i in range(epochs):
        x_real,y_real = get_real_samples(half_batch)
        x_fake,y_fake = get_fake_samples(g_model,noise_dim,half_batch)
        d_model.train_on_batch(x_real,y_real)
        d_model.train_on_batch(x_fake,y_fake)
        x_gan = noise_points(noise_dim,batch_size)
        y = np.ones((batch_size,1))
        gan_model.train_on_batch(x_gan,y)
        if (i+1) % n_eval ==0:
            show_performance(i+1,g_model,d_model,noise_dim)

5.评估性能

在每隔一定的epochs之后,我们将调用show_performance函数,该函数将从生成器中获取真实样本和虚假样本并预测结果。我们还将在散点图上绘制结果,以便我们可以查看GAN的性能。Python实现如下:

# Model evaluation function
def show_performance(epoch,g_model,d_model,noise_dim,n=100):
    x_real,y_real = get_real_samples(n)
    _,real_acc = d_model.evaluate(x_real,y_real,verbose=0)
    x_fake,y_fake = get_fake_samples(g_model,noise_dim,n)
    _,fake_acc = d_model.evaluate(x_fake,y_fake,verbose=0)
    print(epoch,real_acc,fake_acc)
    plt.figure(figsize=(20,10))
    plt.scatter(x_real[:,0],x_real[:,1],color='red')
    plt.scatter(x_fake[:,0],x_fake[:,1],color='blue')

noise_dim=5
gen = gen_model(noise_dim)
disc = disc_model()
gan = gan_model(disc,gen)
train(gen,disc,gan,noise_dim)

在epoch = 2000之后,我们得到了散点图如下,您的图可能会有所不同。

红点表示real点,蓝点表示生成器生成的点。我们可以看到,蓝点已开始呈y =x2的形状。

如果我们继续进行10000个epochs,您将得到类似下面的图像。您可以尝试使用更多个epochs(例如15000或20000个epochs)来获得更好的准确性。

现在我们可以看到,我们已经从生成器中得到了一个更确定的样本,我们可以说生成器已经学习并拟合了这个函数。也就是说,仅仅通过误差梯度,生成器就学会了这个函数。

相关推荐

利用navicat将postgresql转为mysql

导航"拿来主义"吃得亏自己动手,丰衣足食...

Navicat的详细教程「偷偷收藏」(navicatlite)

Navicat是一套快速、可靠并价格适宜的数据库管理工具,适用于三种平台:Windows、macOS及Linux。可以用来对本机或远程的MySQL、SQLServer、SQLite、...

Linux系统安装SQL Server数据库(linux安装数据库命令)

一、官方说明...

Navicat推出免费数据库管理软件Premium Lite

IT之家6月26日消息,Navicat推出一款免费的数据库管理开发工具——NavicatPremiumLite,针对入门级用户,支持基础的数据库管理和协同合作功能。▲Navicat...

Docker安装部署Oracle/Sql Server

一、Docker安装Oracle12cOracle简介...

Docker安装MS SQL Server并使用Navicat远程连接

...

Web性能的计算方式与优化方案(二)

通过前面《...

网络入侵检测系统之Suricata(十四)——匹配流程

其实规则的匹配流程和加载流程是强相关的,你如何组织规则那么就会采用该种数据结构去匹配,例如你用radixtree组织海量ip规则,那么匹配的时候也是采用bittest确定前缀节点,然后逐一左右子树...

使用deepseek写一个图片转换代码(deepnode处理图片)

写一个photoshop代码,要求:可以将文件夹里面的图片都处理成CMYK模式。软件版本:photoshop2022,然后生成的代码如下://Photoshop2022CMYK批量转换专业版脚...

AI助力AUTOCAD,生成LSP插件(ai里面cad插件怎么使用)

以下是用AI生成的,用AUTOLISP语言编写的cad插件,分享给大家:一、将单线偏移为双线;;;;;;;;;;;;;;;;;;;;;;单线变双线...

Core Audio音频基础概述(core 音乐)

1、CoreAudioCoreAudio提供了数字音频服务为iOS与OSX,它提供了一系列框架去处理音频....

BlazorUI 组件库——反馈与弹层 (1)

组件是前端的基础。组件库也是前端框架的核心中的重点。组件库中有一个重要的板块:反馈与弹层!反馈与弹层在组件形态上,与Button、Input类等嵌入界面的组件有所不同,通常以层的形式出现。本篇文章...

怎样创建一个Xcode插件(xcode如何新建一个main.c)

译者:@yohunl译者注:原文使用的是xcode6.3.2,我翻译的时候,使用的是xcode7.2.1,经过验证,本部分中说的依然是有效的.在文中你可以学习到一系列的技能,非常值得一看.这些技能不单...

让SSL/TLS协议流行起来:深度解读SSL/TLS实现1

一前言SSL/TLS协议是网络安全通信的重要基石,本系列将简单介绍SSL/TLS协议,主要关注SSL/TLS协议的安全性,特别是SSL规范的正确实现。本系列的文章大体分为3个部分:SSL/TLS协...

社交软件开发6-客户端开发-ios端开发验证登陆部分

欢迎订阅我的头条号:一点热上一节说到,Android客户端的开发,主要是编写了,如何使用Androidstudio如何创建一个Android项目,已经使用gradle来加载第三方库,并且使用了异步...

取消回复欢迎 发表评论: