你还弄不懂的傅里叶变换,神经网络只用了30多行代码就学会了
ztj100 2024-10-28 21:15 26 浏览 0 评论
明敏 发自 凹非寺
量子位 报道 | 公众号 QbitAI
在我们的生活中,大到天体观测、小到MP3播放器上的频谱,没有傅里叶变换都无法实现。
通俗来讲,离散傅里叶变换(DFT)就是把一串复杂波形中分成不同频率成分。
比如声音,如果用声波记录仪显示声音的话,其实生活中绝大部分声音都是非常复杂、甚至杂乱无章的。
而通过傅里叶变换,就能把这些杂乱的声波转化为正弦波,也就是我们平常看到的音乐频谱图的样子。
不过在实际计算中,这个过程其实非常复杂。
如果把声波视作一个连续函数,它可以唯一表示为一堆三角函数相叠加。不过在叠加过程中,每个三角函数的加权系数不同,有的要加高一些、有的要压低一些,有的甚至不加。
傅里叶变换要找到这些三角函数以及它们各自的权重。
这不就巧了,这种找啊找的过程,像极了神经网络。
神经网络的本质其实就是逼近一个函数。
那岂不是可以用训练神经网络的方式来搞定傅里叶变换?
这还真的可行,并且最近有人在网上发布了自己训练的过程和结果。
DFT=神经网络
该怎么训练神经网络呢?这位网友给出的思路是这样的:
首先要把离散傅里叶变换(DFT)看作是一个人工神经网络,这是一个单层网络,没有bias、没有激活函数,并且对于权重有特定的值。它输出节点的数量等于傅里叶变换计算后频率的数量。
具体方法如下:
这是一个DFT:
- k表示每N个样本的循环次数;
- N表示信号的长度;
- 表示信号在样本n处的值。
一个信号可以表示为所有正弦信号的和。
yk是一个复值,它给出了信号x中频率为k的正弦信号的信息;从yk我们可以计算正弦的振幅和相位。
换成矩阵式,它就变成了这样:
这里给出了特定值k的傅里叶值。
不过通常情况下,我们要计算全频谱,即k从[0,1,…N-1]的值,这可以用一个矩阵来表示(k按列递增,n按行递增):
简化后得到:
看到这里应该还很熟悉,因为它是一个没有bias和激活函数的神经网络层。
指数矩阵包含权值,可以称之为复合傅里叶权值(Complex Fourier weights),通常情况下我们并不知道神经网络的权重,不过在这里可以。
- 不用复数
通常我们也不会在神经网络中使用复数,为了适应这种情况,就需要把矩阵的大小翻倍,使其左边部分包含实数,右边部分包含虚数。
将
带入DFT,可以得到:
然后用实部(cos形式)来表示矩阵的左半部分,用虚部(sin形式)来表示矩阵的右半部分:
简化后可以得到:
将
称为傅里叶权重;
需要注意的是,y^和y实际上包含相同的信息,但是y^
不使用复数,所以它的长度是y的两倍。
换句话说,我们可以用
或
表示振幅和相位,但是我们通常会使用
现在,就可以将傅里叶层加到网络中了。
用傅里叶权重计算傅里叶变换
现在就可以用神经网络来实现
,并用快速傅里叶变换(FFT)检查它是否正确。
import matplotlib.pyplot as plt
y_real = y[:, :signal_length]
y_imag = y[:, signal_length:]
tvals = np.arange(signal_length).reshape([-1, 1])
freqs = np.arange(signal_length).reshape([1, -1])
arg_vals = 2 * np.pi * tvals * freqs / signal_length
sinusoids = (y_real * np.cos(arg_vals) - y_imag * np.sin(arg_vals)) / signal_length
reconstructed_signal = np.sum(sinusoids, axis=1)
print('rmse:', np.sqrt(np.mean((x - reconstructed_signal)**2)))
plt.subplot(2, 1, 1)
plt.plot(x[0,:])
plt.title('Original signal')
plt.subplot(2, 1, 2)
plt.plot(reconstructed_signal)
plt.title('Signal reconstructed from sinusoids after DFT')
plt.tight_layout()
plt.show()
rmse: 2.3243522568191728e-15
得到的这个微小误差值可以证明,计算的结果是我们想要的。
- 另一种方法是重构信号:
import matplotlib.pyplot as plt
y_real = y[:, :signal_length]
y_imag = y[:, signal_length:]
tvals = np.arange(signal_length).reshape([-1, 1])
freqs = np.arange(signal_length).reshape([1, -1])
arg_vals = 2 * np.pi * tvals * freqs / signal_length
sinusoids = (y_real * np.cos(arg_vals) - y_imag * np.sin(arg_vals)) / signal_length
reconstructed_signal = np.sum(sinusoids, axis=1)
print('rmse:', np.sqrt(np.mean((x - reconstructed_signal)**2)))
plt.subplot(2, 1, 1)
plt.plot(x[0,:])
plt.title('Original signal')
plt.subplot(2, 1, 2)
plt.plot(reconstructed_signal)
plt.title('Signal reconstructed from sinusoids after DFT')
plt.tight_layout()
plt.show()
rmse: 2.3243522568191728e-15
最后可以看到,DFT后从正弦信号重建的信号和原始信号能够很好地重合。
通过梯度下降学习傅里叶变换
现在就到了让神经网络真正来学习的部分,这一步就不需要向之前那样预先计算权重值了。
首先,要用FFT来训练神经网络学习离散傅里叶变换:
import tensorflow as tf
signal_length = 32
# Initialise weight vector to train:
W_learned = tf.Variable(np.random.random([signal_length, 2 * signal_length]) - 0.5)
# Expected weights, for comparison:
W_expected = create_fourier_weights(signal_length)
losses = []
rmses = []
for i in range(1000):
# Generate a random signal each iteration:
x = np.random.random([1, signal_length]) - 0.5
# Compute the expected result using the FFT:
fft = np.fft.fft(x)
y_true = np.hstack([fft.real, fft.imag])
with tf.GradientTape() as tape:
y_pred = tf.matmul(x, W_learned)
loss = tf.reduce_sum(tf.square(y_pred - y_true))
# Train weights, via gradient descent:
W_gradient = tape.gradient(loss, W_learned)
W_learned = tf.Variable(W_learned - 0.1 * W_gradient)
losses.append(loss)
rmses.append(np.sqrt(np.mean((W_learned - W_expected)**2)))
Final loss value 1.6738563548424711e-09
Final weights' rmse value 3.1525832404710523e-06
得出结果如上,这证实了神经网络确实能够学习离散傅里叶变换。
训练网络学习DFT
除了用快速傅里叶变化的方法,还可以通过网络来重建输入信号来学习DFT。(类似于autoencoders自编码器)。
自编码器(autoencoder, AE)是一类在半监督学习和非监督学习中使用的人工神经网络(Artificial Neural Networks, ANNs),其功能是通过将输入信息作为学习目标,对输入信息进行表征学习(representation learning)。
W_learned = tf.Variable(np.random.random([signal_length, 2 * signal_length]) - 0.5)
tvals = np.arange(signal_length).reshape([-1, 1])
freqs = np.arange(signal_length).reshape([1, -1])
arg_vals = 2 * np.pi * tvals * freqs / signal_length
cos_vals = tf.cos(arg_vals) / signal_length
sin_vals = tf.sin(arg_vals) / signal_length
losses = []
rmses = []
for i in range(10000):
x = np.random.random([1, signal_length]) - 0.5
with tf.GradientTape() as tape:
y_pred = tf.matmul(x, W_learned)
y_real = y_pred[:, 0:signal_length]
y_imag = y_pred[:, signal_length:]
sinusoids = y_real * cos_vals - y_imag * sin_vals
reconstructed_signal = tf.reduce_sum(sinusoids, axis=1)
loss = tf.reduce_sum(tf.square(x - reconstructed_signal))
W_gradient = tape.gradient(loss, W_learned)
W_learned = tf.Variable(W_learned - 0.5 * W_gradient)
losses.append(loss)
rmses.append(np.sqrt(np.mean((W_learned - W_expected)**2)))
Final loss value 4.161919455121241e-22
Final weights' rmse value 0.20243339269590094
作者用这一模型进行了很多测试,最后得到的权重不像上面的例子中那样接近傅里叶权值,但是可以看到重建的信号是一致的。
换成输入振幅和相位试试看呢。
W_learned = tf.Variable(np.random.random([signal_length, 2 * signal_length]) - 0.5)
losses = []
rmses = []
for i in range(10000):
x = np.random.random([1, signal_length]) - .5
with tf.GradientTape() as tape:
y_pred = tf.matmul(x, W_learned)
y_real = y_pred[:, 0:signal_length]
y_imag = y_pred[:, signal_length:]
amplitudes = tf.sqrt(y_real**2 + y_imag**2) / signal_length
phases = tf.atan2(y_imag, y_real)
sinusoids = amplitudes * tf.cos(arg_vals + phases)
reconstructed_signal = tf.reduce_sum(sinusoids, axis=1)
loss = tf.reduce_sum(tf.square(x - reconstructed_signal))
W_gradient = tape.gradient(loss, W_learned)
W_learned = tf.Variable(W_learned - 0.5 * W_gradient)
losses.append(loss)
rmses.append(np.sqrt(np.mean((W_learned - W_expected)**2)))
Final loss value 2.2379359316633115e-21
Final weights' rmse value 0.2080118219691059
可以看到,重建信号再次一致;
不过,和此前一样,输入振幅和相位最终得到的权值也不完全等同于傅里叶权值(但非常接近)。
由此可以得出结论,虽然最后得到的权重还不是最准确的,但是也能够获得局部的最优解。
这样一来,神经网络就学会了傅里叶变换!
- 值得一提的是,这个方法目前还有疑问存在:
首先,它没有解释计算出的权值和真正的傅里叶权值相差多少;
而且,也没有说明将傅里叶层放到模型中能带来哪些益处。
原文链接:
https://sidsite.com/posts/fourier-nets/
— 完 —
量子位 QbitAI · 头条号签约
关注我们,第一时间获知前沿科技动态
相关推荐
- Jquery 详细用法
-
1、jQuery介绍(1)jQuery是什么?是一个js框架,其主要思想是利用jQuery提供的选择器查找要操作的节点,然后将找到的节点封装成一个jQuery对象。封装成jQuery对象的目的有...
- 前端开发79条知识点汇总
-
1.css禁用鼠标事件2.get/post的理解和他们之间的区别http超文本传输协议(HTTP)的设计目的是保证客户机与服务器之间的通信。HTTP的工作方式是客户机与服务器之间的请求-应答协议。...
- js基础面试题92-130道题目
-
92.说说你对作用域链的理解参考答案:作用域链的作用是保证执行环境里有权访问的变量和函数是有序的,作用域链的变量只能向上访问,变量访问到window对象即被终止,作用域链向下访问变量是不被允许的。...
- Web前端必备基础知识点,百万网友:牛逼
-
1、Web中的常见攻击方式1.SQL注入------常见的安全性问题。解决方案:前端页面需要校验用户的输入数据(限制用户输入的类型、范围、格式、长度),不能只靠后端去校验用户数据。一来可以提高后端处理...
- 事件——《JS高级程序设计》
-
一、事件流1.事件流描述的是从页面中接收事件的顺序2.事件冒泡(eventbubble):事件从开始时由最具体的元素(就是嵌套最深的那个节点)开始,逐级向上传播到较为不具体的节点(就是Docu...
- 前端开发中79条不可忽视的知识点汇总
-
过往一些不足的地方,通过博客,好好总结一下。1.css禁用鼠标事件...
- Chrome 开发工具之Network
-
经常会听到比如"为什么我的js代码没执行啊?","我明明发送了请求,为什么反应?","我这个网站怎么加载的这么慢?"这类的问题,那么问题既然存在,就需要去解决它,需要解决它,首先我们得找对导致问题的原...
- 轻量级 React.js 虚拟美化滚动条组件RScroll
-
前几天有给大家分享一个Vue自定义滚动条组件VScroll。今天再分享一个最新开发的ReactPC端模拟滚动条组件RScroll。...
- 一文解读JavaScript事件对象和表单对象
-
前言相信做网站对JavaScript再熟悉不过了,它是一门脚本语言,不同于Python的是,它是一门浏览器脚本语言,而Python则是服务器脚本语言,我们不光要会Python,还要会JavaScrip...
- Python函数参数黑科技:*args与**kwargs深度解析
-
90%的Python程序员不知道,可变参数设计竟能决定函数的灵活性和扩展性!掌握这些技巧,让你的函数适应任何场景!一、函数参数设计的三大进阶技巧...
- 深入理解Python3密码学:详解PyCrypto库加密、解密与数字签名
-
在现代计算领域,信息安全逐渐成为焦点话题。密码学,作为信息保护的关键技术之一,允许我们加密(保密)和解密(解密)数据。...
- 阿里Nacos惊爆安全漏洞,火速升级!(附修复建议)
-
前言好,我是threedr3am,我发现nacos最新版本1.4.1对于User-Agent绕过安全漏洞的serverIdentitykey-value修复机制,依然存在绕过问题,在nacos开启了...
- Python模块:zoneinfo时区支持详解
-
一、知识导图二、知识讲解(一)zoneinfo模块概述...
- Golang开发的一些注意事项(一)
-
1.channel关闭后读的问题当channel关闭之后再去读取它,虽然不会引发panic,但会直接得到零值,而且ok的值为false。packagemainimport"...
- Python鼠标与键盘自动化指南:从入门到进阶——键盘篇
-
`pynput`是一个用于控制和监控鼠标和键盘的Python库...
你 发表评论:
欢迎- 一周热门
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)