百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

使用 Pytorch 训练 AlexNet 识别5种不同花朵

ztj100 2024-10-31 16:13 15 浏览 0 评论

1 数据

1.1 准备工作

新建一个文件夹AlexNet,在文件夹AlexNet新建一个文件夹flower_data,将下载后的数据解压并放到`文件夹flower_data。

1.2 数据下载

下载 Tensorflow 的花朵图片

http://download.tensorflow.org/example_images/flower_photos.tgz

1.3 数据分类

在`文件夹AlexNet`右键打开终端

gedit spile_data.py # 将 spile_data.py 拷入保存关闭
python spile_data.py # 运行 spile_data.py

spile_data.py

import os
from shutil import copy
import random

# 生成file
def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/'+cla)

mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/'+cla)

split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num*split_rate))
    for index, image in enumerate(images):
        # 随机分配为验证集
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)
        # 非随机分配为训练集
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
    print()

print("processing done!")

2 模型

在文件夹AlexNet右键打开终端

gedit model.py # 将 model.py 拷入保存关闭

model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=5, init_weights=False):   
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(  #打包
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] 自动舍去小数点后
            nn.ReLU(inplace=True), #inplace 可以载入更大模型
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27] kernel_num为原论文一半
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            #全链接
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1) #展平   或者view()
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') #何教授方法
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)  #正态分布赋值
                nn.init.constant_(m.bias, 0)

3 训练

在文件夹AlexNet右键打开终端

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time


#device : GPU 或 CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


#数据预处理
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪为224x224
                                 transforms.RandomHorizontalFlip(), # 水平翻转
                                 transforms.ToTensor(), # 转为张量
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),# 均值和方差为0.5
    "val": transforms.Compose([transforms.Resize((224, 224)), # 重置大小
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}


batch_size = 32 # 批次大小
data_root = os.getcwd() # 获取当前路径
image_path = data_root + "/flower_data/"  # 数据路径

train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"]) # 加载训练数据集并预处理
train_num = len(train_dataset) # 训练数据集大小

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0) # 训练加载器

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"]) # 验证数据集
val_num = len(validate_dataset) # 验证数据集大小
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0) # 验证加载器

print("训练数据集大小: ",train_num,"\n") # 3306
print("验证数据集大小: ",val_num,"\n") # 364



net = AlexNet(num_classes=5, init_weights=True) # 调用模型

net.to(device)

loss_function = nn.CrossEntropyLoss() # 损失函数:交叉熵
optimizer = optim.Adam(net.parameters(), lr=0.0002) #优化器 Adam
save_path = './AlexNet.pth' # 训练参数保存路径
best_acc = 0.0 # 训练过程中最高准确率

#开始进行训练和测试,训练一轮,测试一轮
for epoch in range(10):
    # 训练部分
    print(">>开始训练: ",epoch+1)
    net.train()    #训练dropout
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad() # 梯度置0
        outputs = net(images.to(device)) 
        loss = loss_function(outputs, labels.to(device))
        loss.backward() # 反向传播
        optimizer.step()

        
        running_loss += loss.item() # 累加损失
        rate = (step + 1) / len(train_loader) # 训练进度
        a = "*" * int(rate * 50) # *数
        b = "." * int((1 - rate) * 50) # .数
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter()-t1) # 一个epoch花费的时间

    # 验证部分
    print(">>开始验证: ",epoch+1)
    net.eval()    #验证不需要dropout
    acc = 0.0  # 一个批次中分类正确个数
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            #print("outputs: \n",outputs,"\n")

            predict_y = torch.max(outputs, dim=1)[1]
            #print("predict_y: \n",predict_y,"\n")

            acc += (predict_y == val_labels.to(device)).sum().item() # 预测和标签一致,累加
        val_accurate = acc / val_num # 一个批次的准确率
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path) # 更新准确率最高的网络参数
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx 

#  {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
cla_dict = dict((val, key) for key, val in flower_list.items())

# 将字典写入 json 文件
json_str = json.dumps(cla_dict, indent=4) # 字典转json
with open('class_indices.json', 'w') as json_file: # 对class_indices.json写入操作
    json_file.write(json_str) # 写入class_indices.json

4 测试

在网上下载一张jpg格式的图片,改名为sunflower.jpg,并放在文件夹AlexNet。

在文件夹AlexNet右键打开终端

gedit predict.py # 将 predict.py 拷入保存关闭
python predict.py # 运行 predict.py

predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json


# 加载图片
img = Image.open("./sunflower.jpg")  #验证太阳花

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#img = Image.open("./roses.jpg")     #验证玫瑰花
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# 读取 class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# 调用模型
model = AlexNet(num_classes=5)
# 加载模型参数
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval() # 不用Dropout
with torch.no_grad():
    output = torch.squeeze(model(img)) # 压缩
    predict = torch.softmax(output, dim=0) # 生成概率
    predict_cla = torch.argmax(predict).numpy() # 最大值的索引
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

相关推荐

如何将数据仓库迁移到阿里云 AnalyticDB for PostgreSQL

阿里云AnalyticDBforPostgreSQL(以下简称ADBPG,即原HybridDBforPostgreSQL)为基于PostgreSQL内核的MPP架构的实时数据仓库服务,可以...

Python数据分析:探索性分析

写在前面如果你忘记了前面的文章,可以看看加深印象:Python数据处理...

CSP-J/S冲奖第21天:插入排序

...

C++基础语法梳理:算法丨十大排序算法(二)

本期是C++基础语法分享的第十六节,今天给大家来梳理一下十大排序算法后五个!归并排序...

C 语言的标准库有哪些

C语言的标准库并不是一个单一的实体,而是由一系列头文件(headerfiles)组成的集合。每个头文件声明了一组相关的函数、宏、类型和常量。程序员通过在代码中使用#include<...

[深度学习] ncnn安装和调用基础教程

1介绍ncnn是腾讯开发的一个为手机端极致优化的高性能神经网络前向计算框架,无第三方依赖,跨平台,但是通常都需要protobuf和opencv。ncnn目前已在腾讯多款应用中使用,如QQ,Qzon...

用rust实现经典的冒泡排序和快速排序

1.假设待排序数组如下letmutarr=[5,3,8,4,2,7,1];...

ncnn+PPYOLOv2首次结合!全网最详细代码解读来了

编辑:好困LRS【新智元导读】今天给大家安利一个宝藏仓库miemiedetection,该仓库集合了PPYOLO、PPYOLOv2、PPYOLOE三个算法pytorch实现三合一,其中的PPYOL...

C++特性使用建议

1.引用参数使用引用替代指针且所有不变的引用参数必须加上const。在C语言中,如果函数需要修改变量的值,参数必须为指针,如...

Qt4/5升级到Qt6吐血经验总结V202308

00:直观总结增加了很多轮子,同时原有模块拆分的也更细致,估计为了方便拓展个管理。把一些过度封装的东西移除了(比如同样的功能有多个函数),保证了只有一个函数执行该功能。把一些Qt5中兼容Qt4的方法废...

到底什么是C++11新特性,请看下文

C++11是一个比较大的更新,引入了很多新特性,以下是对这些特性的详细解释,帮助您快速理解C++11的内容1.自动类型推导(auto和decltype)...

掌握C++11这些特性,代码简洁性、安全性和性能轻松跃升!

C++11(又称C++0x)是C++编程语言的一次重大更新,引入了许多新特性,显著提升了代码简洁性、安全性和性能。以下是主要特性的分类介绍及示例:一、核心语言特性1.自动类型推导(auto)编译器自...

经典算法——凸包算法

凸包算法(ConvexHull)一、概念与问题描述凸包是指在平面上给定一组点,找到包含这些点的最小面积或最小周长的凸多边形。这个多边形没有任何内凹部分,即从一个多边形内的任意一点画一条线到多边形边界...

一起学习c++11——c++11中的新增的容器

c++11新增的容器1:array当时的初衷是希望提供一个在栈上分配的,定长数组,而且可以使用stl中的模板算法。array的用法如下:#include<string>#includ...

C++ 编程中的一些最佳实践

1.遵循代码简洁原则尽量避免冗余代码,通过模块化设计、清晰的命名和良好的结构,让代码更易于阅读和维护...

取消回复欢迎 发表评论: