使用Tensorflow开发一维生成对抗网络
ztj100 2024-11-21 00:30 19 浏览 0 评论
生成式对抗网络是一种用于训练生成器模型的深度学习体系结构。GAN由两个模型组成,一个称为生成器(Generator),另一个称为判别器(Discriminator)。顾名思义,生成器生成新样本,判别器负责对生成的样本进行真伪分类。
GAN实际如何运作的?
判别器模型的性能用于更新生成器和判别器本身的网络权重。生成器实际上从未看到过数据,而是根据判别器的性能不断地进行调整,更具体地说,是根据从判别器传回来的误差梯度进行调整。生成器逐渐学会通过产生与真实样本完全相同的样本来欺骗判别器。
在这篇文章中,我们将选择一个简单的一维函数来直观地理解GAN。本文分为5个部分:
- 选择一个一维函数
- 实现判别器模型
- 实现生成器模型
- 训练GAN模型
- 性能评估
1.一维函数
我们需要选择一个一维函数来制作模型。一维函数的形式为
y = f(x),其中x是输入,y是对应的输出
为简单起见,我将使用函数y = x2。您可以自由选择任何函数。我们将保持输入在-0.5和+0.5之间。下面给出了一个计算输入的简单函数 :
# Importing Required Packages
import numpy as np
from numpy.random import rand,randn
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
# n : number of samples
# type : real / fake
def get_real_samples(n):
X1 = rand(n)-0.5
X2=X1*X1
X1=X1.reshape(n,1)
X2=X2.reshape(n,1)
X=np.hstack((X1,X2))
y=np.ones((n,1))
return X,y
该函数简单地接受N个随机值,并将每个值减去0.5,以便将输入范围保持在-0.5和+0.5之间。当为real时y=1,当为fake时y=0。。
2.判别器模型
判别器只是一个简单的分类模型,它可以预测样本是real还是fake。判别器将两个实数值的样本作为输入,并输出样本是real还是fake。我们处理的问题非常简单,所以我们不需要非常复杂的神经网络,我们将只采用一个隐藏层,其中有25个节点。您可以自由地试验节点数或层数,以提高生成器的准确性。我们将对隐藏层使用ReLu激活,对输出层使用sigmoid激活。Python实现如下:
# Make discriminator model (classifies real/fake inputs)
def disc_model(input_dim=2):
model=Sequential()
#hidden layer
model.add(Dense(25,activation='relu',input_dim=input_dim))
#outputlayer
model.add(Dense(1,activation='sigmoid'))
#compile model
model.compile(loss='binary_crossentropy',optimizer='adam',metrics=['accuracy'])
return model
3.生成器模型
对于生成器,我们将噪声输入提供给生成器,此噪声输入也称为潜在变量。
潜在变量是潜在空间中的隐藏变量或未观察到的变量,潜在空间是这些变量的多维空间。
直到我们的生成器受到训练并赋予这些点意义,该潜在空间才有意义,这些点被映射到判别器的输入。我们将定义一个3维的潜在空间(可以更改维数),并实验生成器的行为和准确度如何变化。我们将对潜在空间中的每个变量使用高斯分布。生成器使用一个隐藏层,该隐藏层将由15个具有ReLu激活函数的神经元组成。输出层将由两个神经元组成,这两个神经元将连接到判别器层的输入。Python实现如下:
# This function generates fake points with help of generator
def get_fake_samples(gen,noise_dim,n):
#generating random points from latent space
x_input = noise_points(noise_dim,n)
# X is the output from the generator
X = gen.predict(x_input)
# Since these are fake samples so y=0
y=np.zeros((n,1))
return X,y
# generate noise points from latent space (latent variables)
def noise_points(noise_dim,n):
noise = randn(n*noise_dim)
noise=noise.reshape(n,noise_dim)
return noise
# Make generator model (generates fake inputs)
def gen_model(input_dim,output_dim=2):
model=Sequential()
model.add(Dense(15,activation='relu',input_dim=input_dim))
model.add(Dense(output_dim,activation='linear'))
return model
4.训练GAN模型
训练GAN模型的方法有很多,最简单的方法是创建一个新的模型,该模型由生成器和判别器两部分组成。我们只是在逻辑上封装了生成器和判别器网络。我们将把GAN模型作为一个整体进行训练,这样来自判别器的反向传播误差也会更新生成器的权重。如果判别器能够很好的进行分类,那么生成器的权重将更新地更多;如果判别器不能很好的进行分类,那么生成器的权重将更新得少一些。这样,在生成器和判别器之间就形成了一种对抗关系。Python代码如下:
# Logical model for connecting generator and discriminator
def gan_model(disc,gen):
disc.trainable=False
model=Sequential()
model.add(gen)
model.add(disc)
model.compile(loss='binary_crossentropy',optimizer='adam')
return model
判别器模型的可训练属性被设置为false,这样就可以仅对standalone模型进行训练。
现在我们只剩下对GAN模型进行整体训练了。我们将编写一个函数来做这个的事情。该函数将运行10000个epochs,每运行2000个epochs,它将评估判别器和生成器的性能。Python代码如下:
def train(g_model,d_model,gan_model,noise_dim,epochs=10000,batch_size=256,n_eval=2000):
half_batch=batch_size//2
for i in range(epochs):
x_real,y_real = get_real_samples(half_batch)
x_fake,y_fake = get_fake_samples(g_model,noise_dim,half_batch)
d_model.train_on_batch(x_real,y_real)
d_model.train_on_batch(x_fake,y_fake)
x_gan = noise_points(noise_dim,batch_size)
y = np.ones((batch_size,1))
gan_model.train_on_batch(x_gan,y)
if (i+1) % n_eval ==0:
show_performance(i+1,g_model,d_model,noise_dim)
5.评估性能
在每隔一定的epochs之后,我们将调用show_performance函数,该函数将从生成器中获取真实样本和虚假样本并预测结果。我们还将在散点图上绘制结果,以便我们可以查看GAN的性能。Python实现如下:
# Model evaluation function
def show_performance(epoch,g_model,d_model,noise_dim,n=100):
x_real,y_real = get_real_samples(n)
_,real_acc = d_model.evaluate(x_real,y_real,verbose=0)
x_fake,y_fake = get_fake_samples(g_model,noise_dim,n)
_,fake_acc = d_model.evaluate(x_fake,y_fake,verbose=0)
print(epoch,real_acc,fake_acc)
plt.figure(figsize=(20,10))
plt.scatter(x_real[:,0],x_real[:,1],color='red')
plt.scatter(x_fake[:,0],x_fake[:,1],color='blue')
noise_dim=5
gen = gen_model(noise_dim)
disc = disc_model()
gan = gan_model(disc,gen)
train(gen,disc,gan,noise_dim)
在epoch = 2000之后,我们得到了散点图如下,您的图可能会有所不同。
红点表示real点,蓝点表示生成器生成的点。我们可以看到,蓝点已开始呈y =x2的形状。
如果我们继续进行10000个epochs,您将得到类似下面的图像。您可以尝试使用更多个epochs(例如15000或20000个epochs)来获得更好的准确性。
现在我们可以看到,我们已经从生成器中得到了一个更确定的样本,我们可以说生成器已经学习并拟合了这个函数。也就是说,仅仅通过误差梯度,生成器就学会了这个函数。
相关推荐
- sharding-jdbc实现`分库分表`与`读写分离`
-
一、前言本文将基于以下环境整合...
- 三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么
-
在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...
- MySQL8行级锁_mysql如何加行级锁
-
MySQL8行级锁版本:8.0.34基本概念...
- mysql使用小技巧_mysql使用入门
-
1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...
- MySQL/MariaDB中如何支持全部的Unicode?
-
永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...
- 聊聊 MySQL Server 可执行注释,你懂了吗?
-
前言MySQLServer当前支持如下3种注释风格:...
- MySQL系列-源码编译安装(v5.7.34)
-
一、系统环境要求...
- MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了
-
对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...
- MySQL字符问题_mysql中字符串的位置
-
中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...
- 深圳尚学堂:mysql基本sql语句大全(三)
-
数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...
- MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?
-
大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...
- 一文讲清怎么利用Python Django实现Excel数据表的导入导出功能
-
摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...
- 用DataX实现两个MySQL实例间的数据同步
-
DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...
- MySQL数据库知识_mysql数据库基础知识
-
MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...
- 如何为MySQL中的JSON字段设置索引
-
背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
【VueTorrent】一款吊炸天的qBittorrent主题,人人都可用
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
pytorch中的 scatter_()函数使用和详解
-
与 Java 17 相比,Java 21 究竟有多快?
-
基于TensorRT_LLM的大模型推理加速与OpenAI兼容服务优化
-
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)