深度信号处理:利用卷积神经网络测量距离
ztj100 2025-01-01 23:49 48 浏览 0 评论
在信号处理中,有时需要测量信号某些特征(例如峰)之间的水平距离。 一个很好的例子就是解释心电图(ECG),这在很大程度上取决于测量距离。 我们将考虑下图中只有两个峰的平滑信号的一个样例。
解决这个问题很简单,可以通过找到峰值,然后减去它们的X坐标来测量它们之间的水平距离来解决。这可以通过使用可用的工具和库有效地完成。然而,我们的目标是训练一个神经网络来预测两个峰之间的距离。一旦我们证明了神经网络可以处理这一任务,我们就可以在更复杂的端到端学习任务中重用相同的架构,而测量距离只是学习更复杂关系的一种手段。这源于深度学习的理念,即我们应该尝试让神经网络学习特征,而不是让工程师手工编码特征并希望这些特征是最相关的。如果我们能证明神经网络可以学习距离特征,我们就可以在更复杂的网络中使用它,在这些网络中,最终结果将取决于距离以外的许多其他因素。这些任务的典型例子是解释心电图或天文数据。
生成数据
在我们的实验中,我们将使用生成训练和测试数据的生成器函数生成如上图所示的信号。
def get_signal_generator(batch_size, n_points, mean_distance, std_distance,
mean_width, std_width):
def generate_one():
first = np.random.uniform(0, n_points /4)
second = first + np.random.normal(mean_distance * n_points, std_distance * n_points)
if second > n_points * 0.95:
second = n_points * 0.95
distance = second - first
first_width = max(np.random.normal(mean_width * n_points, std_width * n_points), n_points * std_width)
second_width = max(np.random.normal(mean_width * n_points, std_width * n_points), n_points * std_width)
data_range = np.arange(n_points)
signal = norm.pdf(data_range, first, first_width) + \
norm.pdf(data_range, second, second_width)
return signal, distance
def generate():
sanity=100000
for _ in range(sanity):
all_data = [generate_one() for _ in range(batch_size)]
yield np.vstack([element[0] for element in all_data]), np.vstack([element[1] for element in all_data])
return generate()
这是一个python生成器函数,意味着它使用yield关键字而不是return。 每次在生成器上调用next()函数时,都会产生下一个结果。此功能生成正好具有两个峰值的信号。所有信号的长度完全相同。第一个峰值的位置均匀分布在信号的第一象限中,但是第二个峰值的位置呈正态分布,但是我们还要确保它不会超出范围。峰的宽度也呈正态分布。我们分批返回峰,这对神经网络的训练和评估很有用。
请注意,此生成器实际上会生成无限量的数据! 因此,对于我们的示例可以尝试实现尽可能高的精度。
找到峰值
现在我们有了生成器函数,我们可以使用标准信号处理库来找到峰值之间的距离。我们将使用scipy库和函数find_peaks()来查找峰值。我们使用R2评分来评估模型。如下图所示,我们得到了近乎完美的分数,预测误差主要是由于数字舍入误差造成的。
def predict_distance(batch):
def compute_distance(row):
peaks = find_peaks(row)[0]
if len(peaks) < 2:
return 0
return abs(peaks[1] - peaks[0])
return np.vstack([compute_distance(batch[i,:]) for i in range(batch.shape[0])])
np.random.seed(2128506)
data_generator = get_signal_generator(500, 1000, 0.7, 0.1, 0.03, 0.01)
batch_x, batch_y = next(data_generator)
predictions = predict_distance(batch_x)
print('Baseline performance: ', r2_score(batch_y, predictions))
Baseline performance: 0.9999812121197582
使用CNN来测量距离
在设计神经网络时,想象一个人类操作员会做什么通常是很有用的。在我们的例子中,操作是测量,测量的工具是一把尺子。在我们的例子中,我们使用一个一维卷积层来模拟标尺,并将内核大小设置为信号的最大长度。这样做原因是,如果层的值从0,1,2,3,4,…当乘以信号,它将准确地给我们的位置的峰值。我们使用了两个滤波器来测量两个峰值的位置,然后添加两个全连接层,让神经网络学习如何获取这两个测量值之间的差异。
我们使用Tensorflow和Keras实现神经网络。 请注意,由于Conv1D需要三维张量,因此我们添加了一个Reshape图层,该图层添加了尺寸1的第三个尺寸。批量尺寸是隐式假定的。 对于卷积层,我们不使用任何激活函数,因为我们希望该层的行为类似于标尺。 注意,我们不使用任何下采样机制(最大池化或平均池化)。 我相信这些不是必需的,实际上会降低精度,因为它们会使测量变得不那么精确。 在将数据发送到Dense层之前,我们添加Flatten层以将尺寸(批处理尺寸除外)折叠为单个尺寸,因为这是Dense层所期望的。
model = Sequential([
Input(shape=(1000,)),
Reshape((-1,1)),
Conv1D(filters=2, kernel_size=1000, activation=None),
Flatten(),
Dense(16, activation='relu'),
Dense(16, activation='relu'),
Dense(1)
])
model.compile(optimizer=Adam(lr=0.7), loss='mse')
model.summary()
结果如下:
Model: "sequential_26"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
reshape_26 (Reshape) (None, 1000, 1) 0
_________________________________________________________________
conv1d_33 (Conv1D) (None, 1, 2) 2002
_________________________________________________________________
flatten_25 (Flatten) (None, 2) 0
_________________________________________________________________
dense_63 (Dense) (None, 16) 48
_________________________________________________________________
dense_64 (Dense) (None, 16) 272
_________________________________________________________________
dense_65 (Dense) (None, 1) 17
=================================================================
Total params: 2,339
Trainable params: 2,339
Non-trainable params: 0
_________________________________________________________________
正如我们所看到的,这个模型只有2339个参数,所以它是一个非常简单的模型。我们在50个伦茨内训练模型,但是我们也添加了早期停止回调,以便在模型停止改进时停止执行。我们添加另一个回调函数TerminateOnNaN,如果梯度或损失变成NaN,它将停止训练过程。我们将一个生成器函数传递给fit()方法。这是现在推荐的将数据传递给模型的方法,特别是当数据量很大时。在我们的例子中,生成器不断地生成随机的例子,实际上是无限的例子!由于我们的生成器函数没有epoch的概念,我们需要定义一个参数stepsperepoch,否则模型将认为所有批次都属于第一个epoch,并且训练将永远不会结束。
np.random.seed(2128506)
tf.random.set_seed(2128506)
data_generator = get_signal_generator(500, 1000, 0.7, 0.1, 0.03, 0.01)
model.fit(data_generator, epochs=50, steps_per_epoch=100, callbacks=[EarlyStopping(monitor='loss'), TerminateOnNaN()])
训练过程如下:
Epoch 1/50
100/100 [==============================] - 16s 157ms/step - loss: 55010.6915
Epoch 2/50
100/100 [==============================] - 16s 161ms/step - loss: 186.3050
Epoch 3/50
100/100 [==============================] - 16s 160ms/step - loss: 89.9977
Epoch 4/50
100/100 [==============================] - 16s 159ms/step - loss: 229.8199
下面我们看一下结果:
time.sleep(1)
batch_x, batch_y = next(data_generator)
predictions = model.predict(batch_x)
print('R^2 score: ', r2_score(batch_y, np.squeeze(predictions)))
R^2 score: 0.996036173273703
在训练一个模型后,我们看到测试集的分数确实令人印象深刻。虽然我们承认进一步改进算法是可能的,但我们得到的结果证明了我们的简单方法确实有效。
总结
在设计一个神经网络时,想象人类的感知和认知是如何工作的往往是成功的关键。我们如何产生高层次的特征和概念通常能够指导我们进行神经网络的架构搭建。这种方法的示例之一是注意力机制,注意力机制是根据我们根据阅读的文本进行归纳总结时的注意力进行建模的。在这个问题中,代表了人类活动指导神经网络构建的另一个示例。 尽管使用CNN来测量距离(与Attention机制一样)本身并没有用,但是只要我们相信水平距离起作用,就可以将此结构并入更大的神经网络来解决更复杂的任务。
本文代码:github/mlarionov/deep-signal-example/blob/main/two-peaks.ipynb
作者:Michael Larionov, PhD
deephub翻译组
相关推荐
- sharding-jdbc实现`分库分表`与`读写分离`
-
一、前言本文将基于以下环境整合...
- 三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么
-
在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...
- MySQL8行级锁_mysql如何加行级锁
-
MySQL8行级锁版本:8.0.34基本概念...
- mysql使用小技巧_mysql使用入门
-
1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...
- MySQL/MariaDB中如何支持全部的Unicode?
-
永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...
- 聊聊 MySQL Server 可执行注释,你懂了吗?
-
前言MySQLServer当前支持如下3种注释风格:...
- MySQL系列-源码编译安装(v5.7.34)
-
一、系统环境要求...
- MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了
-
对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...
- MySQL字符问题_mysql中字符串的位置
-
中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...
- 深圳尚学堂:mysql基本sql语句大全(三)
-
数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...
- MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?
-
大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...
- 一文讲清怎么利用Python Django实现Excel数据表的导入导出功能
-
摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...
- 用DataX实现两个MySQL实例间的数据同步
-
DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...
- MySQL数据库知识_mysql数据库基础知识
-
MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...
- 如何为MySQL中的JSON字段设置索引
-
背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
【VueTorrent】一款吊炸天的qBittorrent主题,人人都可用
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
pytorch中的 scatter_()函数使用和详解
-
与 Java 17 相比,Java 21 究竟有多快?
-
基于TensorRT_LLM的大模型推理加速与OpenAI兼容服务优化
-
- 最近发表
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)