百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

使用 Pytorch 训练 AlexNet 识别5种不同花朵

ztj100 2024-10-31 16:13 20 浏览 0 评论

1 数据

1.1 准备工作

新建一个文件夹AlexNet,在文件夹AlexNet新建一个文件夹flower_data,将下载后的数据解压并放到`文件夹flower_data。

1.2 数据下载

下载 Tensorflow 的花朵图片

http://download.tensorflow.org/example_images/flower_photos.tgz

1.3 数据分类

在`文件夹AlexNet`右键打开终端

gedit spile_data.py # 将 spile_data.py 拷入保存关闭
python spile_data.py # 运行 spile_data.py

spile_data.py

import os
from shutil import copy
import random

# 生成file
def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/'+cla)

mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/'+cla)

split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num*split_rate))
    for index, image in enumerate(images):
        # 随机分配为验证集
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)
        # 非随机分配为训练集
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
    print()

print("processing done!")

2 模型

在文件夹AlexNet右键打开终端

gedit model.py # 将 model.py 拷入保存关闭

model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=5, init_weights=False):   
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(  #打包
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] 自动舍去小数点后
            nn.ReLU(inplace=True), #inplace 可以载入更大模型
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27] kernel_num为原论文一半
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            #全链接
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1) #展平   或者view()
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') #何教授方法
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)  #正态分布赋值
                nn.init.constant_(m.bias, 0)

3 训练

在文件夹AlexNet右键打开终端

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time


#device : GPU 或 CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


#数据预处理
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪为224x224
                                 transforms.RandomHorizontalFlip(), # 水平翻转
                                 transforms.ToTensor(), # 转为张量
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),# 均值和方差为0.5
    "val": transforms.Compose([transforms.Resize((224, 224)), # 重置大小
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}


batch_size = 32 # 批次大小
data_root = os.getcwd() # 获取当前路径
image_path = data_root + "/flower_data/"  # 数据路径

train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"]) # 加载训练数据集并预处理
train_num = len(train_dataset) # 训练数据集大小

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0) # 训练加载器

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"]) # 验证数据集
val_num = len(validate_dataset) # 验证数据集大小
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0) # 验证加载器

print("训练数据集大小: ",train_num,"\n") # 3306
print("验证数据集大小: ",val_num,"\n") # 364



net = AlexNet(num_classes=5, init_weights=True) # 调用模型

net.to(device)

loss_function = nn.CrossEntropyLoss() # 损失函数:交叉熵
optimizer = optim.Adam(net.parameters(), lr=0.0002) #优化器 Adam
save_path = './AlexNet.pth' # 训练参数保存路径
best_acc = 0.0 # 训练过程中最高准确率

#开始进行训练和测试,训练一轮,测试一轮
for epoch in range(10):
    # 训练部分
    print(">>开始训练: ",epoch+1)
    net.train()    #训练dropout
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad() # 梯度置0
        outputs = net(images.to(device)) 
        loss = loss_function(outputs, labels.to(device))
        loss.backward() # 反向传播
        optimizer.step()

        
        running_loss += loss.item() # 累加损失
        rate = (step + 1) / len(train_loader) # 训练进度
        a = "*" * int(rate * 50) # *数
        b = "." * int((1 - rate) * 50) # .数
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter()-t1) # 一个epoch花费的时间

    # 验证部分
    print(">>开始验证: ",epoch+1)
    net.eval()    #验证不需要dropout
    acc = 0.0  # 一个批次中分类正确个数
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            #print("outputs: \n",outputs,"\n")

            predict_y = torch.max(outputs, dim=1)[1]
            #print("predict_y: \n",predict_y,"\n")

            acc += (predict_y == val_labels.to(device)).sum().item() # 预测和标签一致,累加
        val_accurate = acc / val_num # 一个批次的准确率
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path) # 更新准确率最高的网络参数
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx 

#  {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
cla_dict = dict((val, key) for key, val in flower_list.items())

# 将字典写入 json 文件
json_str = json.dumps(cla_dict, indent=4) # 字典转json
with open('class_indices.json', 'w') as json_file: # 对class_indices.json写入操作
    json_file.write(json_str) # 写入class_indices.json

4 测试

在网上下载一张jpg格式的图片,改名为sunflower.jpg,并放在文件夹AlexNet。

在文件夹AlexNet右键打开终端

gedit predict.py # 将 predict.py 拷入保存关闭
python predict.py # 运行 predict.py

predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json


# 加载图片
img = Image.open("./sunflower.jpg")  #验证太阳花

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#img = Image.open("./roses.jpg")     #验证玫瑰花
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# 读取 class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# 调用模型
model = AlexNet(num_classes=5)
# 加载模型参数
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval() # 不用Dropout
with torch.no_grad():
    output = torch.squeeze(model(img)) # 压缩
    predict = torch.softmax(output, dim=0) # 生成概率
    predict_cla = torch.argmax(predict).numpy() # 最大值的索引
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

相关推荐

再说圆的面积-蒙特卡洛(蒙特卡洛方法求圆周率的matlab程序)

在微积分-圆的面积和周长(1)介绍微积分方法求解圆的面积,本文使用蒙特卡洛方法求解圆面积。...

python编程:如何使用python代码绘制出哪些常见的机器学习图像?

专栏推荐...

python创建分类器小结(pytorch分类数据集创建)

简介:分类是指利用数据的特性将其分成若干类型的过程。监督学习分类器就是用带标记的训练数据建立一个模型,然后对未知数据进行分类。...

matplotlib——绘制散点图(matplotlib散点图颜色和图例)

绘制散点图不同条件(维度)之间的内在关联关系观察数据的离散聚合程度...

python实现实时绘制数据(python如何绘制)

方法一importmatplotlib.pyplotaspltimportnumpyasnpimporttimefrommathimport*plt.ion()#...

简单学Python——matplotlib库3——绘制散点图

前面我们学习了用matplotlib绘制折线图,今天我们学习绘制散点图。其实简单的散点图与折线图的语法基本相同,只是作图函数由plot()变成了scatter()。下面就绘制一个散点图:import...

数据分析-相关性分析可视化(相关性分析数据处理)

前面介绍了相关性分析的原理、流程和常用的皮尔逊相关系数和斯皮尔曼相关系数,具体可以参考...

免费Python机器学习课程一:线性回归算法

学习线性回归的概念并从头开始在python中开发完整的线性回归算法最基本的机器学习算法必须是具有单个变量的线性回归算法。如今,可用的高级机器学习算法,库和技术如此之多,以至于线性回归似乎并不重要。但是...

用Python进行机器学习(2)之逻辑回归

前面介绍了线性回归,本次介绍的是逻辑回归。逻辑回归虽然名字里面带有“回归”两个字,但是它是一种分类算法,通常用于解决二分类问题,比如某个邮件是否是广告邮件,比如某个评价是否为正向的评价。逻辑回归也可以...

【Python机器学习系列】拟合和回归傻傻分不清?一文带你彻底搞懂

一、拟合和回归的区别拟合...

推荐2个十分好用的pandas数据探索分析神器

作者:俊欣来源:关于数据分析与可视化...

向量数据库:解锁大模型记忆的关键!选型指南+实战案例全解析

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在...

用Python进行机器学习(11)-主成分分析PCA

我们在机器学习中有时候需要处理很多个参数,但是这些参数有时候彼此之间是有着各种关系的,这个时候我们就会想:是否可以找到一种方式来降低参数的个数呢?这就是今天我们要介绍的主成分分析,英文是Princip...

神经网络基础深度解析:从感知机到反向传播

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在...

Python实现基于机器学习的RFM模型

CDA数据分析师出品作者:CDALevelⅠ持证人岗位:数据分析师行业:大数据...

取消回复欢迎 发表评论: