百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

【Python深度学习系列】网格搜索神经网络超参数:丢弃率dropout

ztj100 2024-12-01 07:02 17 浏览 0 评论

这是我的第271篇原创文章。

一、引言

在深度学习中,超参数是指在训练模型时需要手动设置的参数,它们通常不能通过训练数据自动学习得到。超参数的选择对于模型的性能至关重要,因此在进行深度学习实验时,超参数调优通常是一个重要的步骤。常见的超参数包括:

  • model.add()
    • neurons(隐含层神经元数量)
    • init_mode(初始权重方法)
    • activation(激活函数)
    • dropout(丢弃率)
  • model.compile()
    • loss(损失函数)
    • optimizer(优化器)
      • learning rate(学习率)
      • momentum(动量)
      • weight decay(权重衰减系数)
  • model.fit()
    • batch size(批量大小)
    • epochs(迭代次数)

一般来说,可以通过手动调优、网格搜索(Grid Search)、随机搜索(Random Search)、自动调参算法方式进行超参数调优。dropout是指在网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。注意是暂时,对于随机梯度下降来说,由于是随机丢弃,故而每一个mini-batch都在训练不同的网络。考虑调节 dropout 为了正则化,为了限制过拟合以及提供模型的泛化能力。本文采用网格搜索选择丢弃率dropout。

二、实现过程

2.1 准备数据

dataset:

dataset = pd.read_csv("data.csv", header=None)
dataset = pd.DataFrame(dataset)
print(dataset)

2.2 数据划分

# 切分数据为输入 X 和输出 Y
X = dataset.iloc[:,0:8]
Y = dataset.iloc[:,8]
# 为了复现,设置随机种子
seed = 7
np.random.seed(seed)
random.set_seed(seed)

2.3 创建模型

需要定义个网格的架构函数create_model,create_model里面的参数要在KerasClassifier这个对象里面存在而且参数名要一致。

def create_model(dropout_rate):
    # 创建模型
    model = Sequential()
    model.add(Dense(50, input_shape=(8, ), kernel_initializer='uniform', activation='relu'))
    model.add(Dropout(dropout_rate))
    model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))
    # 编译模型
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
model = KerasClassifier(model=create_model, epochs=100, batch_size=80, verbose=0, dropout_rate=0.2)

这里使用了scikeras库的KerasClassifier类来定义一个分类器,这里由于KerasClassifier没有定义初始化权重的参数,需要自定义一个表示丢弃率的参数dropout_rate,并赋默认值为0.2。

2.4 定义网格搜索参数

param_grid = {'dropout_rate': [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]}

param_grid是一个字典,key是超参数名称,这里的名称必须要在KerasClassifier这个对象里面存在而且参数名要一致。value是key可取的值,也就是要尝试的方案。

2.5 进行参数搜索

from sklearn.model_selection import GridSearchCV
grid = GridSearchCV(estimator=model,  param_grid=param_grid)
grid_result = grid.fit(X, Y)

使用sklearn里面的GridSearchCV类进行参数搜索,传入模型和网格参数。

2.6 总结搜索结果

print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
means = grid_result.cv_results_['mean_test_score']
stds = grid_result.cv_results_['std_test_score']
params = grid_result.cv_results_['params']
for mean, stdev, param in zip(means, stds, params):
    print("%f (%f) with: %r" % (mean, stdev, param))

结果:

经过网格搜索,丢弃率的最优选择是0.2。

作者简介: 读研期间发表6篇SCI数据算法相关论文,目前在某研究院从事数据算法相关研究工作,结合自身科研实践经历持续分享关于Python、数据分析、特征工程、机器学习、深度学习、人工智能系列基础知识与案例。关注gzh:数据杂坛,获取数据和源码学习更多内容。

原文链接:

【Python深度学习系列】网格搜索神经网络超参数:丢弃率dropout(案例+源码)

相关推荐

其实TensorFlow真的很水无非就这30篇熬夜练

好的!以下是TensorFlow需要掌握的核心内容,用列表形式呈现,简洁清晰(含表情符号,<300字):1.基础概念与环境TensorFlow架构(计算图、会话->EagerE...

交叉验证和超参数调整:如何优化你的机器学习模型

准确预测Fitbit的睡眠得分在本文的前两部分中,我获取了Fitbit的睡眠数据并对其进行预处理,将这些数据分为训练集、验证集和测试集,除此之外,我还训练了三种不同的机器学习模型并比较了它们的性能。在...

机器学习交叉验证全指南:原理、类型与实战技巧

机器学习模型常常需要大量数据,但它们如何与实时新数据协同工作也同样关键。交叉验证是一种通过将数据集分成若干部分、在部分数据上训练模型、在其余数据上测试模型的方法,用来检验模型的表现。这有助于发现过拟合...

深度学习中的类别激活热图可视化

作者:ValentinaAlto编译:ronghuaiyang导读使用Keras实现图像分类中的激活热图的可视化,帮助更有针对性...

超强,必会的机器学习评估指标

大侠幸会,在下全网同名[算法金]0基础转AI上岸,多个算法赛Top[日更万日,让更多人享受智能乐趣]构建机器学习模型的关键步骤是检查其性能,这是通过使用验证指标来完成的。选择正确的验证指...

机器学习入门教程-第六课:监督学习与非监督学习

1.回顾与引入上节课我们谈到了机器学习的一些实战技巧,比如如何处理数据、选择模型以及调整参数。今天,我们将更深入地探讨机器学习的两大类:监督学习和非监督学习。2.监督学习监督学习就像是有老师的教学...

Python教程(三十八):机器学习基础

...

Python 模型部署不用愁!容器化实战,5 分钟搞定环境配置

你是不是也遇到过这种糟心事:花了好几天训练出的Python模型,在自己电脑上跑得顺顺当当,一放到服务器就各种报错。要么是Python版本不对,要么是依赖库冲突,折腾半天还是用不了。别再喊“我...

超全面讲透一个算法模型,高斯核!!

...

神经网络与传统统计方法的简单对比

传统的统计方法如...

AI 基础知识从0.1到0.2——用“房价预测”入门机器学习全流程

...

自回归滞后模型进行多变量时间序列预测

下图显示了关于不同类型葡萄酒销量的月度多元时间序列。每种葡萄酒类型都是时间序列中的一个变量。假设要预测其中一个变量。比如,sparklingwine。如何建立一个模型来进行预测呢?一种常见的方...

苹果AI策略:慢哲学——科技行业的“长期主义”试金石

苹果AI策略的深度原创分析,结合技术伦理、商业逻辑与行业博弈,揭示其“慢哲学”背后的战略智慧:一、反常之举:AI狂潮中的“逆行者”当科技巨头深陷AI军备竞赛,苹果的克制显得格格不入:功能延期:App...

时间序列预测全攻略,6大模型代码实操

如果你对数据分析感兴趣,希望学习更多的方法论,希望听听经验分享,欢迎移步宝藏公众号...

AI 基础知识从 0.4 到 0.5—— 计算机视觉之光 CNN

...

取消回复欢迎 发表评论: