百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

[深度学习] Pytorch模型转换为onnx模型笔记

ztj100 2025-01-29 19:16 20 浏览 0 评论


本文主要介绍将pytorch模型准确导出为可用的onnx模型。以方便OpenCV Dnn,NCNN,MNN,TensorRT等框架调用。所有代码见:?Python-Study-Notes??

文章目录

  • 1 使用说明
  • 1.1 读取模型
  • 1.2 检测图像
  • 1.3 导出为onnx模型
  • 1.4 模型测试
  • 1.5 模型简化
  • 1.6 全部代码
  • 2 参考

1 使用说明

本文示例为调用pytorch预训练的mobilenetv2模型,将其导出为onnx模型。主要步骤如下:

  1. 读取模型
  2. 检测图像
  3. 导出为onnx模型
  4. 模型测试
  5. 模型简化
# 需要调用的头文件
import torch
from torchvision import models
import cv2
import numpy as np
from torchsummary import summary
import onnxruntime
from onnxsim import simplify
import onnx
from matplotlib import pyplot as plt

# 判断使用CPU还是GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

1.1 读取模型

该部分主要为调用训练好的模型。主要内容如下

  1. 直接读取预训练模型
  2. 将模型转换为推理模型
  3. 查看模型的结构
# ----- 1 读取模型
print("----- 1 读取模型 -----")
# 载入模型并读取权重
model = models.mobilenet_v2(pretrained=True)
# 将模型转换为推理模式
model.eval()
# 查看模型的结构,(3,224,224)为模型的图像输入
summary(model, (3, 224, 224))
----- 1 读取模型 -----
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1         [-1, 32, 112, 112]             864
       BatchNorm2d-2         [-1, 32, 112, 112]              64
             ReLU6-3         [-1, 32, 112, 112]               0
            Conv2d-4         [-1, 32, 112, 112]             288
       BatchNorm2d-5         [-1, 32, 112, 112]              64
             ReLU6-6         [-1, 32, 112, 112]               0
            Conv2d-7         [-1, 16, 112, 112]             512
       BatchNorm2d-8         [-1, 16, 112, 112]              32
  InvertedResidual-9         [-1, 16, 112, 112]               0
           Conv2d-10         [-1, 96, 112, 112]           1,536
      BatchNorm2d-11         [-1, 96, 112, 112]             192
            ReLU6-12         [-1, 96, 112, 112]               0
           Conv2d-13           [-1, 96, 56, 56]             864
      BatchNorm2d-14           [-1, 96, 56, 56]             192
            ReLU6-15           [-1, 96, 56, 56]               0
           Conv2d-16           [-1, 24, 56, 56]           2,304
      BatchNorm2d-17           [-1, 24, 56, 56]              48
 InvertedResidual-18           [-1, 24, 56, 56]               0
           Conv2d-19          [-1, 144, 56, 56]           3,456
      BatchNorm2d-20          [-1, 144, 56, 56]             288
            ReLU6-21          [-1, 144, 56, 56]               0
           Conv2d-22          [-1, 144, 56, 56]           1,296
      BatchNorm2d-23          [-1, 144, 56, 56]             288
            ReLU6-24          [-1, 144, 56, 56]               0
           Conv2d-25           [-1, 24, 56, 56]           3,456
      BatchNorm2d-26           [-1, 24, 56, 56]              48
 InvertedResidual-27           [-1, 24, 56, 56]               0
           Conv2d-28          [-1, 144, 56, 56]           3,456
      BatchNorm2d-29          [-1, 144, 56, 56]             288
            ReLU6-30          [-1, 144, 56, 56]               0
           Conv2d-31          [-1, 144, 28, 28]           1,296
      BatchNorm2d-32          [-1, 144, 28, 28]             288
            ReLU6-33          [-1, 144, 28, 28]               0
           Conv2d-34           [-1, 32, 28, 28]           4,608
      BatchNorm2d-35           [-1, 32, 28, 28]              64
 InvertedResidual-36           [-1, 32, 28, 28]               0
           Conv2d-37          [-1, 192, 28, 28]           6,144
      BatchNorm2d-38          [-1, 192, 28, 28]             384
            ReLU6-39          [-1, 192, 28, 28]               0
           Conv2d-40          [-1, 192, 28, 28]           1,728
      BatchNorm2d-41          [-1, 192, 28, 28]             384
            ReLU6-42          [-1, 192, 28, 28]               0
           Conv2d-43           [-1, 32, 28, 28]           6,144
      BatchNorm2d-44           [-1, 32, 28, 28]              64
 InvertedResidual-45           [-1, 32, 28, 28]               0
           Conv2d-46          [-1, 192, 28, 28]           6,144
      BatchNorm2d-47          [-1, 192, 28, 28]             384
            ReLU6-48          [-1, 192, 28, 28]               0
           Conv2d-49          [-1, 192, 28, 28]           1,728
      BatchNorm2d-50          [-1, 192, 28, 28]             384
            ReLU6-51          [-1, 192, 28, 28]               0
           Conv2d-52           [-1, 32, 28, 28]           6,144
      BatchNorm2d-53           [-1, 32, 28, 28]              64
 InvertedResidual-54           [-1, 32, 28, 28]               0
           Conv2d-55          [-1, 192, 28, 28]           6,144
      BatchNorm2d-56          [-1, 192, 28, 28]             384
            ReLU6-57          [-1, 192, 28, 28]               0
           Conv2d-58          [-1, 192, 14, 14]           1,728
      BatchNorm2d-59          [-1, 192, 14, 14]             384
            ReLU6-60          [-1, 192, 14, 14]               0
           Conv2d-61           [-1, 64, 14, 14]          12,288
      BatchNorm2d-62           [-1, 64, 14, 14]             128
 InvertedResidual-63           [-1, 64, 14, 14]               0
           Conv2d-64          [-1, 384, 14, 14]          24,576
      BatchNorm2d-65          [-1, 384, 14, 14]             768
            ReLU6-66          [-1, 384, 14, 14]               0
           Conv2d-67          [-1, 384, 14, 14]           3,456
      BatchNorm2d-68          [-1, 384, 14, 14]             768
            ReLU6-69          [-1, 384, 14, 14]               0
           Conv2d-70           [-1, 64, 14, 14]          24,576
      BatchNorm2d-71           [-1, 64, 14, 14]             128
 InvertedResidual-72           [-1, 64, 14, 14]               0
           Conv2d-73          [-1, 384, 14, 14]          24,576
      BatchNorm2d-74          [-1, 384, 14, 14]             768
            ReLU6-75          [-1, 384, 14, 14]               0
           Conv2d-76          [-1, 384, 14, 14]           3,456
      BatchNorm2d-77          [-1, 384, 14, 14]             768
            ReLU6-78          [-1, 384, 14, 14]               0
           Conv2d-79           [-1, 64, 14, 14]          24,576
      BatchNorm2d-80           [-1, 64, 14, 14]             128
 InvertedResidual-81           [-1, 64, 14, 14]               0
           Conv2d-82          [-1, 384, 14, 14]          24,576
      BatchNorm2d-83          [-1, 384, 14, 14]             768
            ReLU6-84          [-1, 384, 14, 14]               0
           Conv2d-85          [-1, 384, 14, 14]           3,456
      BatchNorm2d-86          [-1, 384, 14, 14]             768
            ReLU6-87          [-1, 384, 14, 14]               0
           Conv2d-88           [-1, 64, 14, 14]          24,576
      BatchNorm2d-89           [-1, 64, 14, 14]             128
 InvertedResidual-90           [-1, 64, 14, 14]               0
           Conv2d-91          [-1, 384, 14, 14]          24,576
      BatchNorm2d-92          [-1, 384, 14, 14]             768
            ReLU6-93          [-1, 384, 14, 14]               0
           Conv2d-94          [-1, 384, 14, 14]           3,456
      BatchNorm2d-95          [-1, 384, 14, 14]             768
            ReLU6-96          [-1, 384, 14, 14]               0
           Conv2d-97           [-1, 96, 14, 14]          36,864
      BatchNorm2d-98           [-1, 96, 14, 14]             192
 InvertedResidual-99           [-1, 96, 14, 14]               0
          Conv2d-100          [-1, 576, 14, 14]          55,296
     BatchNorm2d-101          [-1, 576, 14, 14]           1,152
           ReLU6-102          [-1, 576, 14, 14]               0
          Conv2d-103          [-1, 576, 14, 14]           5,184
     BatchNorm2d-104          [-1, 576, 14, 14]           1,152
           ReLU6-105          [-1, 576, 14, 14]               0
          Conv2d-106           [-1, 96, 14, 14]          55,296
     BatchNorm2d-107           [-1, 96, 14, 14]             192
InvertedResidual-108           [-1, 96, 14, 14]               0
          Conv2d-109          [-1, 576, 14, 14]          55,296
     BatchNorm2d-110          [-1, 576, 14, 14]           1,152
           ReLU6-111          [-1, 576, 14, 14]               0
          Conv2d-112          [-1, 576, 14, 14]           5,184
     BatchNorm2d-113          [-1, 576, 14, 14]           1,152
           ReLU6-114          [-1, 576, 14, 14]               0
          Conv2d-115           [-1, 96, 14, 14]          55,296
     BatchNorm2d-116           [-1, 96, 14, 14]             192
InvertedResidual-117           [-1, 96, 14, 14]               0
          Conv2d-118          [-1, 576, 14, 14]          55,296
     BatchNorm2d-119          [-1, 576, 14, 14]           1,152
           ReLU6-120          [-1, 576, 14, 14]               0
          Conv2d-121            [-1, 576, 7, 7]           5,184
     BatchNorm2d-122            [-1, 576, 7, 7]           1,152
           ReLU6-123            [-1, 576, 7, 7]               0
          Conv2d-124            [-1, 160, 7, 7]          92,160
     BatchNorm2d-125            [-1, 160, 7, 7]             320
InvertedResidual-126            [-1, 160, 7, 7]               0
          Conv2d-127            [-1, 960, 7, 7]         153,600
     BatchNorm2d-128            [-1, 960, 7, 7]           1,920
           ReLU6-129            [-1, 960, 7, 7]               0
          Conv2d-130            [-1, 960, 7, 7]           8,640
     BatchNorm2d-131            [-1, 960, 7, 7]           1,920
           ReLU6-132            [-1, 960, 7, 7]               0
          Conv2d-133            [-1, 160, 7, 7]         153,600
     BatchNorm2d-134            [-1, 160, 7, 7]             320
InvertedResidual-135            [-1, 160, 7, 7]               0
          Conv2d-136            [-1, 960, 7, 7]         153,600
     BatchNorm2d-137            [-1, 960, 7, 7]           1,920
           ReLU6-138            [-1, 960, 7, 7]               0
          Conv2d-139            [-1, 960, 7, 7]           8,640
     BatchNorm2d-140            [-1, 960, 7, 7]           1,920
           ReLU6-141            [-1, 960, 7, 7]               0
          Conv2d-142            [-1, 160, 7, 7]         153,600
     BatchNorm2d-143            [-1, 160, 7, 7]             320
InvertedResidual-144            [-1, 160, 7, 7]               0
          Conv2d-145            [-1, 960, 7, 7]         153,600
     BatchNorm2d-146            [-1, 960, 7, 7]           1,920
           ReLU6-147            [-1, 960, 7, 7]               0
          Conv2d-148            [-1, 960, 7, 7]           8,640
     BatchNorm2d-149            [-1, 960, 7, 7]           1,920
           ReLU6-150            [-1, 960, 7, 7]               0
          Conv2d-151            [-1, 320, 7, 7]         307,200
     BatchNorm2d-152            [-1, 320, 7, 7]             640
InvertedResidual-153            [-1, 320, 7, 7]               0
          Conv2d-154           [-1, 1280, 7, 7]         409,600
     BatchNorm2d-155           [-1, 1280, 7, 7]           2,560
           ReLU6-156           [-1, 1280, 7, 7]               0
         Dropout-157                 [-1, 1280]               0
          Linear-158                 [-1, 1000]       1,281,000
================================================================
Total params: 3,504,872
Trainable params: 3,504,872
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.57
Forward/backward pass size (MB): 152.87
Params size (MB): 13.37
Estimated Total Size (MB): 166.81
----------------------------------------------------------------

1.2 检测图像

该部分主要为检测图像,查看模型结果。一般来说pytorch导出的onnx模型都是用于C++调用,所以基于OpenCV直接读取图像,进行图像通道转换以及图像归一化以模拟实际C++调用情况,而不是用pillow和pytorch的transform。通常C++提供的图像都是经由OpenCV调用而来。主要内容如下:

  1. 基于OpenCV读取图像,进行通道转换
  2. 将图像进行归一化
  3. 进行模型推理,查看结果
# ----- 2 检测图像
print("----- 2 检测图像 -----")
# 待检测图像路径 
img_path = './image/rabbit.jpg'

# 读取图像
img = cv2.imread(img_path)
# 图像通道转换
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# 展示图像
plt.imshow(img)
plt.show()
# 图像大小重置为模型输入图像大小
img = cv2.resize(img, (224, 224))

# 图像归一化
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
img = np.array((img / 255.0 - mean) / std, dtype=np.float32)

# 图像通道转换
img = img.transpose([2, 0, 1])
# 获得pytorch需要的输入图像格式NCHW
img_ = torch.from_numpy(img).unsqueeze(0)
img_ = img_.to(device)
# 推理
outputs = model(img_)

# 得到预测结果,并且按概率从大到小排序
_, indices = torch.sort(outputs, descending=True)
# 返回top5每个预测标签的百分数
percentage = torch.nn.functional.softmax(outputs, dim=1)[0] * 100
print(["预测标签为: {},预测概率为:{};".format(idx, percentage[idx].item()) for idx in indices[0][:5]])

# 保存/载入整个pytorch模型
# torch.save(model, 'model.ckpt')
# model = torch.load('model.ckpt')

# 仅仅保存/载入pytorch模型的参数
# torch.save(model.state_dict(), 'params.ckpt')
# model.load_state_dict(torch.load('params.ckpt'))
----- 2 检测图像 -----
['预测标签为: 331,预测概率为:54.409969329833984;', '预测标签为: 330,预测概率为:33.62083435058594;', '预测标签为: 332,预测概率为:11.84182071685791;', '预测标签为: 263,预测概率为:0.05221949517726898;', '预测标签为: 264,预测概率为:0.027525480836629868;']

1.3 导出为onnx模型

该部分主要为导出onnx模型,两行代码就可以搞定,onnx模型导出路径为当前目录下mobilenet_v2.onnx。具体如下:

x = torch.rand(1, 3, 224, 224)
torch_out = torch.onnx._export(model, x, output_name, export_params=True,
                               input_names=["input"], output_names=["output"])
# ---- 3 导出为onnx模型
print("----- 3 导出为onnx模型 -----")
# An example input you would normally provide to your model's forward() method
# x为输入图像,格式为pytorch的NCHW格式;1为图像数一般不需要修改;3为通道数;224,224为图像高宽;
x = torch.rand(1, 3, 224, 224)
# 模型输出名
output_name = "mobilenet_v2.onnx"
# Export the model
# 导出为onnx模型
# model为模型,x为模型输入,"mobilenet_v2.onnx"为onnx输出名,export_params表示是否保存模型参数
# input_names为onnx模型输入节点名字,需要输入列表
# output_names为onnx模型输出节点名字,需要输入列表;如果是多输出修改为output_names=["output1","output2"]
torch_out = torch.onnx._export(model, x, output_name, export_params=True,
                               input_names=["input"], output_names=["output"])
print("模型导出成功")
----- 3 导出为onnx模型 -----
模型导出成功

1.4 模型测试

该部分主要为测试模型,一般可以跳过,不需要这部分代码,通常模型转换不会出错。另外onnx模型可以通过??Netron??查看结构。

# ---- 4 模型测试(可跳过)
print("----- 4 模型测试 -----")


# 可以跳过该步骤,一般不会有问题

# 检查输出
def check_onnx_output(filename, input_data, torch_output):
    session = onnxruntime.InferenceSession(filename)
    input_name = session.get_inputs()[0].name
    result = session.run([], {input_name: input_data.numpy()})
    for test_result, gold_result in zip(result, torch_output.values()):
        np.testing.assert_almost_equal(
            gold_result.cpu().numpy(), test_result, decimal=3,
        )
    return result


# 检查模型
def check_onnx_model(model, onnx_filename, input_image):
    with torch.no_grad():
        torch_out = {"output": model(input_image)}
    check_onnx_output(onnx_filename, input_image, torch_out)
    onnx_model = onnx.load(onnx_filename)
    onnx.checker.check_model(onnx_model)
    print("模型测试成功")
    return onnx_model

# 检测导出的onnx模型是否完整
# 一般出现问题程序直接报错,不过很少出现问题
onnx_model = check_onnx_model(model, output_name, x)
----- 4 模型测试 -----
模型测试成功

1.5 模型简化

一般来说导出后的onnx模型会有一堆冗余操作,需要简化。推荐使用??onnx-simplifier???进行onnx模型简化。onnx简化模型导出路径为当前目录下mobilenet_v2.onnxsim.onnx
调用onnx-simplifier有三种办法:

  1. 调用代码,调用onnx-simplifier的simplify接口
  2. 命令行简化,直接输入python3 -m onnxsim input_onnx_model output_onnx_model
  3. 在线调用,调用??onnx-simplifier???作者的??https://convertmodel.com/??直接进行模型简化。

具体来说推荐第三种在线使用,第三种在线调用方便,还能将onnx模型转换为ncnn,mnn等模型格式。

P.S. onnx-simplifier对于高版本pytorch不那么支持,转换可能失败,所以设置skip_fuse_bn=True跳过融合bn层。这种情况下onnx-simplifier转换出来的onnx模型可能比转换前的模型大,原因是补充了shape信息。

# ----- 5 模型简化
print("----- 5 模型简化 -----")
# 基于onnx-simplifier简化模型,https://github.com/daquexian/onnx-simplifier
# 也可以命令行输入python3 -m onnxsim input_onnx_model output_onnx_model
# 或者使用在线网站直接转换https://convertmodel.com/

# 输出模型名
filename = output_name + "sim.onnx"
# 简化模型
# 设置skip_fuse_bn=True表示跳过融合bn层,pytorch高版本融合bn层会出错
simplified_model, check = simplify(onnx_model, skip_fuse_bn=True)
onnx.save_model(simplified_model, filename)
onnx.checker.check_model(simplified_model)
# 如果出错
assert check, "简化模型失败"
print("模型简化成功")
----- 5 模型简化 -----
模型简化成功

1.6 全部代码

全部工程代码如下

# -*- coding: utf-8 -*-
"""
Created on Tue Dec  8 19:44:42 2020

@author: luohenyueji
"""

import torch
from torchvision import models
import cv2
import numpy as np
from torchsummary import summary
import onnxruntime
from onnxsim import simplify
import onnx
from matplotlib import pyplot as plt

# 判断使用CPU还是GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ----- 1 读取模型
print("----- 1 读取模型 -----")
# 载入模型并读取权重
model = models.mobilenet_v2(pretrained=True)
# 将模型转换为推理模式
model.eval()
# 查看模型的结构,(3,224,224)为模型的图像输入
# summary(model, (3, 224, 224))

# ----- 2 检测图像
print("----- 2 检测图像 -----")
# 待检测图像路径 
img_path = './image/rabbit.jpg'

# 读取图像
img = cv2.imread(img_path)
# 图像通道转换
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# 展示图像
# plt.imshow(img)
# plt.show()
# 图像大小重置为模型输入图像大小
img = cv2.resize(img, (224, 224))

# 图像归一化
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
img = np.array((img / 255.0 - mean) / std, dtype=np.float32)

# 图像通道转换
img = img.transpose([2, 0, 1])
# 获得pytorch需要的输入图像格式NCHW
img_ = torch.from_numpy(img).unsqueeze(0)
img_ = img_.to(device)
# 推理
outputs = model(img_)

# 得到预测结果,并且按概率从大到小排序
_, indices = torch.sort(outputs, descending=True)
# 返回top5每个预测标签的百分数
percentage = torch.nn.functional.softmax(outputs, dim=1)[0] * 100
print(["预测标签为: {},预测概率为:{};".format(idx, percentage[idx].item()) for idx in indices[0][:5]])

# 保存/载入整个pytorch模型
# torch.save(model, 'model.ckpt')
# model = torch.load('model.ckpt')

# 仅仅保存/载入pytorch模型的参数
# torch.save(model.state_dict(), 'params.ckpt')
# model.load_state_dict(torch.load('params.ckpt'))

# ---- 3 导出为onnx模型
print("----- 3 导出为onnx模型 -----")
# An example input you would normally provide to your model's forward() method
# x为输入图像,格式为pytorch的NCHW格式;1为图像数一般不需要修改;3为通道数;224,224为图像高宽;
x = torch.rand(1, 3, 224, 224)
# 模型输出名
output_name = "mobilenet_v2.onnx"
# Export the model
# 导出为onnx模型
# model为模型,x为模型输入,"mobilenet_v2.onnx"为onnx输出名,export_params表示是否保存模型参数
# input_names为onnx模型输入节点名字,需要输入列表
# output_names为onnx模型输出节点名字,需要输入列表;如果是多输出修改为output_names=["output1","output2"]
torch_out = torch.onnx._export(model, x, output_name, export_params=True,
                               input_names=["input"], output_names=["output"])
print("模型导出成功")

# ---- 4 模型测试(可跳过)
print("----- 4 模型测试 -----")


# 可以跳过该步骤,一般不会有问题

# 检查输出
def check_onnx_output(filename, input_data, torch_output):
    session = onnxruntime.InferenceSession(filename)
    input_name = session.get_inputs()[0].name
    result = session.run([], {input_name: input_data.numpy()})
    for test_result, gold_result in zip(result, torch_output.values()):
        np.testing.assert_almost_equal(
            gold_result.cpu().numpy(), test_result, decimal=3,
        )
    return result


# 检查模型
def check_onnx_model(model, onnx_filename, input_image):
    with torch.no_grad():
        torch_out = {"output": model(input_image)}
    check_onnx_output(onnx_filename, input_image, torch_out)
    onnx_model = onnx.load(onnx_filename)
    onnx.checker.check_model(onnx_model)
    print("模型测试成功")
    return onnx_model


# 检测导出的onnx模型是否完整
# 一般出现问题程序直接报错,不过很少出现问题
onnx_model = check_onnx_model(model, output_name, x)

# ----- 5 模型简化
print("----- 5 模型简化 -----")
# 基于onnx-simplifier简化模型,https://github.com/daquexian/onnx-simplifier
# 也可以命令行输入python3 -m onnxsim input_onnx_model output_onnx_model
# 或者使用在线网站直接转换https://convertmodel.com/

# 输出模型名
filename = output_name + "sim.onnx"
# 简化模型
# 设置skip_fuse_bn=True表示跳过融合bn层,pytorch高版本融合bn层会出错
simplified_model, check = simplify(onnx_model, skip_fuse_bn=True)
onnx.save_model(simplified_model, filename)
onnx.checker.check_model(simplified_model)
# 如果出错
assert check, "简化模型失败"
print("模型简化成功")
----- 1 读取模型 -----
----- 2 检测图像 -----
['预测标签为: 331,预测概率为:54.409969329833984;', '预测标签为: 330,预测概率为:33.62083435058594;', '预测标签为: 332,预测概率为:11.84182071685791;', '预测标签为: 263,预测概率为:0.05221949517726898;', '预测标签为: 264,预测概率为:0.027525480836629868;']
----- 3 导出为onnx模型 -----
模型导出成功
----- 4 模型测试 -----
模型测试成功
----- 5 模型简化 -----
模型简化成功

2 参考

  • ??Netron??
  • ??use ncnn with pytorch or onnx??
  • ??PyTorch to CoreML model conversion??
  • ??onnx-simplifier??
  • ??https://convertmodel.com/??

相关推荐

其实TensorFlow真的很水无非就这30篇熬夜练

好的!以下是TensorFlow需要掌握的核心内容,用列表形式呈现,简洁清晰(含表情符号,<300字):1.基础概念与环境TensorFlow架构(计算图、会话->EagerE...

交叉验证和超参数调整:如何优化你的机器学习模型

准确预测Fitbit的睡眠得分在本文的前两部分中,我获取了Fitbit的睡眠数据并对其进行预处理,将这些数据分为训练集、验证集和测试集,除此之外,我还训练了三种不同的机器学习模型并比较了它们的性能。在...

机器学习交叉验证全指南:原理、类型与实战技巧

机器学习模型常常需要大量数据,但它们如何与实时新数据协同工作也同样关键。交叉验证是一种通过将数据集分成若干部分、在部分数据上训练模型、在其余数据上测试模型的方法,用来检验模型的表现。这有助于发现过拟合...

深度学习中的类别激活热图可视化

作者:ValentinaAlto编译:ronghuaiyang导读使用Keras实现图像分类中的激活热图的可视化,帮助更有针对性...

超强,必会的机器学习评估指标

大侠幸会,在下全网同名[算法金]0基础转AI上岸,多个算法赛Top[日更万日,让更多人享受智能乐趣]构建机器学习模型的关键步骤是检查其性能,这是通过使用验证指标来完成的。选择正确的验证指...

机器学习入门教程-第六课:监督学习与非监督学习

1.回顾与引入上节课我们谈到了机器学习的一些实战技巧,比如如何处理数据、选择模型以及调整参数。今天,我们将更深入地探讨机器学习的两大类:监督学习和非监督学习。2.监督学习监督学习就像是有老师的教学...

Python教程(三十八):机器学习基础

...

Python 模型部署不用愁!容器化实战,5 分钟搞定环境配置

你是不是也遇到过这种糟心事:花了好几天训练出的Python模型,在自己电脑上跑得顺顺当当,一放到服务器就各种报错。要么是Python版本不对,要么是依赖库冲突,折腾半天还是用不了。别再喊“我...

超全面讲透一个算法模型,高斯核!!

...

神经网络与传统统计方法的简单对比

传统的统计方法如...

AI 基础知识从0.1到0.2——用“房价预测”入门机器学习全流程

...

自回归滞后模型进行多变量时间序列预测

下图显示了关于不同类型葡萄酒销量的月度多元时间序列。每种葡萄酒类型都是时间序列中的一个变量。假设要预测其中一个变量。比如,sparklingwine。如何建立一个模型来进行预测呢?一种常见的方...

苹果AI策略:慢哲学——科技行业的“长期主义”试金石

苹果AI策略的深度原创分析,结合技术伦理、商业逻辑与行业博弈,揭示其“慢哲学”背后的战略智慧:一、反常之举:AI狂潮中的“逆行者”当科技巨头深陷AI军备竞赛,苹果的克制显得格格不入:功能延期:App...

时间序列预测全攻略,6大模型代码实操

如果你对数据分析感兴趣,希望学习更多的方法论,希望听听经验分享,欢迎移步宝藏公众号...

AI 基础知识从 0.4 到 0.5—— 计算机视觉之光 CNN

...

取消回复欢迎 发表评论: