使用Tensorflow开发一维生成对抗网络
ztj100 2024-11-21 00:30 12 浏览 0 评论
生成式对抗网络是一种用于训练生成器模型的深度学习体系结构。GAN由两个模型组成,一个称为生成器(Generator),另一个称为判别器(Discriminator)。顾名思义,生成器生成新样本,判别器负责对生成的样本进行真伪分类。
GAN实际如何运作的?
判别器模型的性能用于更新生成器和判别器本身的网络权重。生成器实际上从未看到过数据,而是根据判别器的性能不断地进行调整,更具体地说,是根据从判别器传回来的误差梯度进行调整。生成器逐渐学会通过产生与真实样本完全相同的样本来欺骗判别器。
在这篇文章中,我们将选择一个简单的一维函数来直观地理解GAN。本文分为5个部分:
- 选择一个一维函数
- 实现判别器模型
- 实现生成器模型
- 训练GAN模型
- 性能评估
1.一维函数
我们需要选择一个一维函数来制作模型。一维函数的形式为
y = f(x),其中x是输入,y是对应的输出
为简单起见,我将使用函数y = x2。您可以自由选择任何函数。我们将保持输入在-0.5和+0.5之间。下面给出了一个计算输入的简单函数 :
# Importing Required Packages
import numpy as np
from numpy.random import rand,randn
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
# n : number of samples
# type : real / fake
def get_real_samples(n):
X1 = rand(n)-0.5
X2=X1*X1
X1=X1.reshape(n,1)
X2=X2.reshape(n,1)
X=np.hstack((X1,X2))
y=np.ones((n,1))
return X,y
该函数简单地接受N个随机值,并将每个值减去0.5,以便将输入范围保持在-0.5和+0.5之间。当为real时y=1,当为fake时y=0。。
2.判别器模型
判别器只是一个简单的分类模型,它可以预测样本是real还是fake。判别器将两个实数值的样本作为输入,并输出样本是real还是fake。我们处理的问题非常简单,所以我们不需要非常复杂的神经网络,我们将只采用一个隐藏层,其中有25个节点。您可以自由地试验节点数或层数,以提高生成器的准确性。我们将对隐藏层使用ReLu激活,对输出层使用sigmoid激活。Python实现如下:
# Make discriminator model (classifies real/fake inputs)
def disc_model(input_dim=2):
model=Sequential()
#hidden layer
model.add(Dense(25,activation='relu',input_dim=input_dim))
#outputlayer
model.add(Dense(1,activation='sigmoid'))
#compile model
model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
return model
3.生成器模型
对于生成器,我们将噪声输入提供给生成器,此噪声输入也称为潜在变量。
潜在变量是潜在空间中的隐藏变量或未观察到的变量,潜在空间是这些变量的多维空间。
直到我们的生成器受到训练并赋予这些点意义,该潜在空间才有意义,这些点被映射到判别器的输入。我们将定义一个3维的潜在空间(可以更改维数),并实验生成器的行为和准确度如何变化。我们将对潜在空间中的每个变量使用高斯分布。生成器使用一个隐藏层,该隐藏层将由15个具有ReLu激活函数的神经元组成。输出层将由两个神经元组成,这两个神经元将连接到判别器层的输入。Python实现如下:
# This function generates fake points with help of generator
def get_fake_samples(gen,noise_dim,n):
#generating random points from latent space
x_input = noise_points(noise_dim,n)
# X is the output from the generator
X = gen.predict(x_input)
# Since these are fake samples so y=0
y=np.zeros((n,1))
return X,y
# generate noise points from latent space (latent variables)
def noise_points(noise_dim,n):
noise = randn(n*noise_dim)
noise=noise.reshape(n,noise_dim)
return noise
# Make generator model (generates fake inputs)
def gen_model(input_dim,output_dim=2):
model=Sequential()
model.add(Dense(15,activation='relu',input_dim=input_dim))
model.add(Dense(output_dim,activation='linear'))
return model
4.训练GAN模型
训练GAN模型的方法有很多,最简单的方法是创建一个新的模型,该模型由生成器和判别器两部分组成。我们只是在逻辑上封装了生成器和判别器网络。我们将把GAN模型作为一个整体进行训练,这样来自判别器的反向传播误差也会更新生成器的权重。如果判别器能够很好的进行分类,那么生成器的权重将更新地更多;如果判别器不能很好的进行分类,那么生成器的权重将更新得少一些。这样,在生成器和判别器之间就形成了一种对抗关系。Python代码如下:
# Logical model for connecting generator and discriminator
def gan_model(disc,gen):
disc.trainable=False
model=Sequential()
model.add(gen)
model.add(disc)
model.compile(loss='binary_crossentropy',optimizer='adam')
return model
判别器模型的可训练属性被设置为false,这样就可以仅对standalone模型进行训练。
现在我们只剩下对GAN模型进行整体训练了。我们将编写一个函数来做这个的事情。该函数将运行10000个epochs,每运行2000个epochs,它将评估判别器和生成器的性能。Python代码如下:
def train(g_model,d_model,gan_model,noise_dim,epochs=10000,batch_size=256,n_eval=2000):
half_batch=batch_size//2
for i in range(epochs):
x_real,y_real = get_real_samples(half_batch)
x_fake,y_fake = get_fake_samples(g_model,noise_dim,half_batch)
d_model.train_on_batch(x_real,y_real)
d_model.train_on_batch(x_fake,y_fake)
x_gan = noise_points(noise_dim,batch_size)
y = np.ones((batch_size,1))
gan_model.train_on_batch(x_gan,y)
if (i+1) % n_eval ==0:
show_performance(i+1,g_model,d_model,noise_dim)
5.评估性能
在每隔一定的epochs之后,我们将调用show_performance函数,该函数将从生成器中获取真实样本和虚假样本并预测结果。我们还将在散点图上绘制结果,以便我们可以查看GAN的性能。Python实现如下:
# Model evaluation function
def show_performance(epoch,g_model,d_model,noise_dim,n=100):
x_real,y_real = get_real_samples(n)
_,real_acc = d_model.evaluate(x_real,y_real,verbose=0)
x_fake,y_fake = get_fake_samples(g_model,noise_dim,n)
_,fake_acc = d_model.evaluate(x_fake,y_fake,verbose=0)
print(epoch,real_acc,fake_acc)
plt.figure(figsize=(20,10))
plt.scatter(x_real[:,0],x_real[:,1],color='red')
plt.scatter(x_fake[:,0],x_fake[:,1],color='blue')
noise_dim=5
gen = gen_model(noise_dim)
disc = disc_model()
gan = gan_model(disc,gen)
train(gen,disc,gan,noise_dim)
在epoch = 2000之后,我们得到了散点图如下,您的图可能会有所不同。
红点表示real点,蓝点表示生成器生成的点。我们可以看到,蓝点已开始呈y =x2的形状。
如果我们继续进行10000个epochs,您将得到类似下面的图像。您可以尝试使用更多个epochs(例如15000或20000个epochs)来获得更好的准确性。
现在我们可以看到,我们已经从生成器中得到了一个更确定的样本,我们可以说生成器已经学习并拟合了这个函数。也就是说,仅仅通过误差梯度,生成器就学会了这个函数。
相关推荐
- 利用navicat将postgresql转为mysql
-
导航"拿来主义"吃得亏自己动手,丰衣足食...
- Navicat的详细教程「偷偷收藏」(navicatlite)
-
Navicat是一套快速、可靠并价格适宜的数据库管理工具,适用于三种平台:Windows、macOS及Linux。可以用来对本机或远程的MySQL、SQLServer、SQLite、...
- Linux系统安装SQL Server数据库(linux安装数据库命令)
-
一、官方说明...
- Navicat推出免费数据库管理软件Premium Lite
-
IT之家6月26日消息,Navicat推出一款免费的数据库管理开发工具——NavicatPremiumLite,针对入门级用户,支持基础的数据库管理和协同合作功能。▲Navicat...
- Docker安装部署Oracle/Sql Server
-
一、Docker安装Oracle12cOracle简介...
- Web性能的计算方式与优化方案(二)
-
通过前面《...
- 网络入侵检测系统之Suricata(十四)——匹配流程
-
其实规则的匹配流程和加载流程是强相关的,你如何组织规则那么就会采用该种数据结构去匹配,例如你用radixtree组织海量ip规则,那么匹配的时候也是采用bittest确定前缀节点,然后逐一左右子树...
- 使用deepseek写一个图片转换代码(deepnode处理图片)
-
写一个photoshop代码,要求:可以将文件夹里面的图片都处理成CMYK模式。软件版本:photoshop2022,然后生成的代码如下://Photoshop2022CMYK批量转换专业版脚...
- AI助力AUTOCAD,生成LSP插件(ai里面cad插件怎么使用)
-
以下是用AI生成的,用AUTOLISP语言编写的cad插件,分享给大家:一、将单线偏移为双线;;;;;;;;;;;;;;;;;;;;;;单线变双线...
- Core Audio音频基础概述(core 音乐)
-
1、CoreAudioCoreAudio提供了数字音频服务为iOS与OSX,它提供了一系列框架去处理音频....
- BlazorUI 组件库——反馈与弹层 (1)
-
组件是前端的基础。组件库也是前端框架的核心中的重点。组件库中有一个重要的板块:反馈与弹层!反馈与弹层在组件形态上,与Button、Input类等嵌入界面的组件有所不同,通常以层的形式出现。本篇文章...
- 怎样创建一个Xcode插件(xcode如何新建一个main.c)
-
译者:@yohunl译者注:原文使用的是xcode6.3.2,我翻译的时候,使用的是xcode7.2.1,经过验证,本部分中说的依然是有效的.在文中你可以学习到一系列的技能,非常值得一看.这些技能不单...
- 让SSL/TLS协议流行起来:深度解读SSL/TLS实现1
-
一前言SSL/TLS协议是网络安全通信的重要基石,本系列将简单介绍SSL/TLS协议,主要关注SSL/TLS协议的安全性,特别是SSL规范的正确实现。本系列的文章大体分为3个部分:SSL/TLS协...
- 社交软件开发6-客户端开发-ios端开发验证登陆部分
-
欢迎订阅我的头条号:一点热上一节说到,Android客户端的开发,主要是编写了,如何使用Androidstudio如何创建一个Android项目,已经使用gradle来加载第三方库,并且使用了异步...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- 利用navicat将postgresql转为mysql
- Navicat的详细教程「偷偷收藏」(navicatlite)
- Linux系统安装SQL Server数据库(linux安装数据库命令)
- Navicat推出免费数据库管理软件Premium Lite
- Docker安装部署Oracle/Sql Server
- Docker安装MS SQL Server并使用Navicat远程连接
- Web性能的计算方式与优化方案(二)
- 网络入侵检测系统之Suricata(十四)——匹配流程
- 使用deepseek写一个图片转换代码(deepnode处理图片)
- AI助力AUTOCAD,生成LSP插件(ai里面cad插件怎么使用)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)