使用Tensorflow开发一维生成对抗网络
ztj100 2024-11-21 00:30 15 浏览 0 评论
生成式对抗网络是一种用于训练生成器模型的深度学习体系结构。GAN由两个模型组成,一个称为生成器(Generator),另一个称为判别器(Discriminator)。顾名思义,生成器生成新样本,判别器负责对生成的样本进行真伪分类。
GAN实际如何运作的?
判别器模型的性能用于更新生成器和判别器本身的网络权重。生成器实际上从未看到过数据,而是根据判别器的性能不断地进行调整,更具体地说,是根据从判别器传回来的误差梯度进行调整。生成器逐渐学会通过产生与真实样本完全相同的样本来欺骗判别器。
在这篇文章中,我们将选择一个简单的一维函数来直观地理解GAN。本文分为5个部分:
- 选择一个一维函数
- 实现判别器模型
- 实现生成器模型
- 训练GAN模型
- 性能评估
1.一维函数
我们需要选择一个一维函数来制作模型。一维函数的形式为
y = f(x),其中x是输入,y是对应的输出
为简单起见,我将使用函数y = x2。您可以自由选择任何函数。我们将保持输入在-0.5和+0.5之间。下面给出了一个计算输入的简单函数 :
# Importing Required Packages
import numpy as np
from numpy.random import rand,randn
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
# n : number of samples
# type : real / fake
def get_real_samples(n):
X1 = rand(n)-0.5
X2=X1*X1
X1=X1.reshape(n,1)
X2=X2.reshape(n,1)
X=np.hstack((X1,X2))
y=np.ones((n,1))
return X,y
该函数简单地接受N个随机值,并将每个值减去0.5,以便将输入范围保持在-0.5和+0.5之间。当为real时y=1,当为fake时y=0。。
2.判别器模型
判别器只是一个简单的分类模型,它可以预测样本是real还是fake。判别器将两个实数值的样本作为输入,并输出样本是real还是fake。我们处理的问题非常简单,所以我们不需要非常复杂的神经网络,我们将只采用一个隐藏层,其中有25个节点。您可以自由地试验节点数或层数,以提高生成器的准确性。我们将对隐藏层使用ReLu激活,对输出层使用sigmoid激活。Python实现如下:
# Make discriminator model (classifies real/fake inputs)
def disc_model(input_dim=2):
model=Sequential()
#hidden layer
model.add(Dense(25,activation='relu',input_dim=input_dim))
#outputlayer
model.add(Dense(1,activation='sigmoid'))
#compile model
model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
return model
3.生成器模型
对于生成器,我们将噪声输入提供给生成器,此噪声输入也称为潜在变量。
潜在变量是潜在空间中的隐藏变量或未观察到的变量,潜在空间是这些变量的多维空间。
直到我们的生成器受到训练并赋予这些点意义,该潜在空间才有意义,这些点被映射到判别器的输入。我们将定义一个3维的潜在空间(可以更改维数),并实验生成器的行为和准确度如何变化。我们将对潜在空间中的每个变量使用高斯分布。生成器使用一个隐藏层,该隐藏层将由15个具有ReLu激活函数的神经元组成。输出层将由两个神经元组成,这两个神经元将连接到判别器层的输入。Python实现如下:
# This function generates fake points with help of generator
def get_fake_samples(gen,noise_dim,n):
#generating random points from latent space
x_input = noise_points(noise_dim,n)
# X is the output from the generator
X = gen.predict(x_input)
# Since these are fake samples so y=0
y=np.zeros((n,1))
return X,y
# generate noise points from latent space (latent variables)
def noise_points(noise_dim,n):
noise = randn(n*noise_dim)
noise=noise.reshape(n,noise_dim)
return noise
# Make generator model (generates fake inputs)
def gen_model(input_dim,output_dim=2):
model=Sequential()
model.add(Dense(15,activation='relu',input_dim=input_dim))
model.add(Dense(output_dim,activation='linear'))
return model
4.训练GAN模型
训练GAN模型的方法有很多,最简单的方法是创建一个新的模型,该模型由生成器和判别器两部分组成。我们只是在逻辑上封装了生成器和判别器网络。我们将把GAN模型作为一个整体进行训练,这样来自判别器的反向传播误差也会更新生成器的权重。如果判别器能够很好的进行分类,那么生成器的权重将更新地更多;如果判别器不能很好的进行分类,那么生成器的权重将更新得少一些。这样,在生成器和判别器之间就形成了一种对抗关系。Python代码如下:
# Logical model for connecting generator and discriminator
def gan_model(disc,gen):
disc.trainable=False
model=Sequential()
model.add(gen)
model.add(disc)
model.compile(loss='binary_crossentropy',optimizer='adam')
return model
判别器模型的可训练属性被设置为false,这样就可以仅对standalone模型进行训练。
现在我们只剩下对GAN模型进行整体训练了。我们将编写一个函数来做这个的事情。该函数将运行10000个epochs,每运行2000个epochs,它将评估判别器和生成器的性能。Python代码如下:
def train(g_model,d_model,gan_model,noise_dim,epochs=10000,batch_size=256,n_eval=2000):
half_batch=batch_size//2
for i in range(epochs):
x_real,y_real = get_real_samples(half_batch)
x_fake,y_fake = get_fake_samples(g_model,noise_dim,half_batch)
d_model.train_on_batch(x_real,y_real)
d_model.train_on_batch(x_fake,y_fake)
x_gan = noise_points(noise_dim,batch_size)
y = np.ones((batch_size,1))
gan_model.train_on_batch(x_gan,y)
if (i+1) % n_eval ==0:
show_performance(i+1,g_model,d_model,noise_dim)
5.评估性能
在每隔一定的epochs之后,我们将调用show_performance函数,该函数将从生成器中获取真实样本和虚假样本并预测结果。我们还将在散点图上绘制结果,以便我们可以查看GAN的性能。Python实现如下:
# Model evaluation function
def show_performance(epoch,g_model,d_model,noise_dim,n=100):
x_real,y_real = get_real_samples(n)
_,real_acc = d_model.evaluate(x_real,y_real,verbose=0)
x_fake,y_fake = get_fake_samples(g_model,noise_dim,n)
_,fake_acc = d_model.evaluate(x_fake,y_fake,verbose=0)
print(epoch,real_acc,fake_acc)
plt.figure(figsize=(20,10))
plt.scatter(x_real[:,0],x_real[:,1],color='red')
plt.scatter(x_fake[:,0],x_fake[:,1],color='blue')
noise_dim=5
gen = gen_model(noise_dim)
disc = disc_model()
gan = gan_model(disc,gen)
train(gen,disc,gan,noise_dim)
在epoch = 2000之后,我们得到了散点图如下,您的图可能会有所不同。
红点表示real点,蓝点表示生成器生成的点。我们可以看到,蓝点已开始呈y =x2的形状。
如果我们继续进行10000个epochs,您将得到类似下面的图像。您可以尝试使用更多个epochs(例如15000或20000个epochs)来获得更好的准确性。
现在我们可以看到,我们已经从生成器中得到了一个更确定的样本,我们可以说生成器已经学习并拟合了这个函数。也就是说,仅仅通过误差梯度,生成器就学会了这个函数。
相关推荐
- 30天学会Python编程:16. Python常用标准库使用教程
-
16.1collections模块16.1.1高级数据结构16.1.2示例...
- 强烈推荐!Python 这个宝藏库 re 正则匹配
-
Python的re模块(RegularExpression正则表达式)提供各种正则表达式的匹配操作。...
- Python爬虫中正则表达式的用法,只讲如何应用,不讲原理
-
Python爬虫:正则的用法(非原理)。大家好,这节课给大家讲正则的实际用法,不讲原理,通俗易懂的讲如何用正则抓取内容。·导入re库,这里是需要从html这段字符串中提取出中间的那几个文字。实例一个对...
- Python数据分析实战-正则提取文本的URL网址和邮箱(源码和效果)
-
实现功能:Python数据分析实战-利用正则表达式提取文本中的URL网址和邮箱...
- python爬虫教程之爬取当当网 Top 500 本五星好评书籍
-
我们使用requests和re来写一个爬虫作为一个爱看书的你(说的跟真的似的)怎么能发现好书呢?所以我们爬取当当网的前500本好五星评书籍怎么样?ok接下来就是学习python的正确姿...
- 深入理解re模块:Python中的正则表达式神器解析
-
在Python中,"re"是一个强大的模块,用于处理正则表达式(regularexpressions)。正则表达式是一种强大的文本模式匹配工具,用于在字符串中查找、替换或提取特定模式...
- 如何使用正则表达式和 Python 匹配不以模式开头的字符串
-
需要在Python中使用正则表达式来匹配不以给定模式开头的字符串吗?如果是这样,你可以使用下面的语法来查找所有的字符串,除了那些不以https开始的字符串。r"^(?!https).*&...
- 先Mark后用!8分钟读懂 Python 性能优化
-
从本文总结了Python开发时,遇到的性能优化问题的定位和解决。概述:性能优化的原则——优化需要优化的部分。性能优化的一般步骤:首先,让你的程序跑起来结果一切正常。然后,运行这个结果正常的代码,看看它...
- Python“三步”即可爬取,毋庸置疑
-
声明:本实例仅供学习,切忌遵守robots协议,请不要使用多线程等方式频繁访问网站。#第一步导入模块importreimportrequests#第二步获取你想爬取的网页地址,发送请求,获取网页内...
- 简单学Python——re库(正则表达式)2(split、findall、和sub)
-
1、split():分割字符串,返回列表语法:re.split('分隔符','目标字符串')例如:importrere.split(',','...
- Lavazza拉瓦萨再度牵手上海大师赛
-
阅读此文前,麻烦您点击一下“关注”,方便您进行讨论和分享。Lavazza拉瓦萨再度牵手上海大师赛标题:2024上海大师赛:网球与咖啡的浪漫邂逅在2024年的上海劳力士大师赛上,拉瓦萨咖啡再次成为官...
- ArkUI-X构建Android平台AAR及使用
-
本教程主要讲述如何利用ArkUI-XSDK完成AndroidAAR开发,实现基于ArkTS的声明式开发范式在android平台显示。包括:1.跨平台Library工程开发介绍...
- Deepseek写歌详细教程(怎样用deepseek写歌功能)
-
以下为结合DeepSeek及相关工具实现AI写歌的详细教程,涵盖作词、作曲、演唱全流程:一、核心流程三步法1.AI生成歌词-打开DeepSeek(网页/APP/API),使用结构化提示词生成歌词:...
- “AI说唱解说影视”走红,“零基础入行”靠谱吗?本报记者实测
-
“手里翻找冻鱼,精心的布局;老漠却不言语,脸上带笑意……”《狂飙》剧情被写成歌词,再配上“科目三”背景音乐的演唱,这段1分钟30秒的视频受到了无数网友的点赞。最近一段时间随着AI技术的发展,说唱解说影...
- AI音乐制作神器揭秘!3款工具让你秒变高手
-
在音乐创作的领域里,每个人都有一颗想要成为大师的心。但是面对复杂的乐理知识和繁复的制作过程,许多人的热情被一点点消磨。...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- 30天学会Python编程:16. Python常用标准库使用教程
- 强烈推荐!Python 这个宝藏库 re 正则匹配
- Python爬虫中正则表达式的用法,只讲如何应用,不讲原理
- Python数据分析实战-正则提取文本的URL网址和邮箱(源码和效果)
- python爬虫教程之爬取当当网 Top 500 本五星好评书籍
- 深入理解re模块:Python中的正则表达式神器解析
- 如何使用正则表达式和 Python 匹配不以模式开头的字符串
- 先Mark后用!8分钟读懂 Python 性能优化
- Python“三步”即可爬取,毋庸置疑
- 简单学Python——re库(正则表达式)2(split、findall、和sub)
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)