百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术分类 > 正文

使用 Pytorch 训练 AlexNet 识别5种不同花朵

ztj100 2024-10-31 16:13 35 浏览 0 评论

1 数据

1.1 准备工作

新建一个文件夹AlexNet,在文件夹AlexNet新建一个文件夹flower_data,将下载后的数据解压并放到`文件夹flower_data。

1.2 数据下载

下载 Tensorflow 的花朵图片

http://download.tensorflow.org/example_images/flower_photos.tgz

1.3 数据分类

在`文件夹AlexNet`右键打开终端

gedit spile_data.py # 将 spile_data.py 拷入保存关闭
python spile_data.py # 运行 spile_data.py

spile_data.py

import os
from shutil import copy
import random

# 生成file
def mkfile(file):
    if not os.path.exists(file):
        os.makedirs(file)


file = 'flower_data/flower_photos'
flower_class = [cla for cla in os.listdir(file) if ".txt" not in cla]
mkfile('flower_data/train')
for cla in flower_class:
    mkfile('flower_data/train/'+cla)

mkfile('flower_data/val')
for cla in flower_class:
    mkfile('flower_data/val/'+cla)

split_rate = 0.1
for cla in flower_class:
    cla_path = file + '/' + cla + '/'
    images = os.listdir(cla_path)
    num = len(images)
    eval_index = random.sample(images, k=int(num*split_rate))
    for index, image in enumerate(images):
        # 随机分配为验证集
        if image in eval_index:
            image_path = cla_path + image
            new_path = 'flower_data/val/' + cla
            copy(image_path, new_path)
        # 非随机分配为训练集
        else:
            image_path = cla_path + image
            new_path = 'flower_data/train/' + cla
            copy(image_path, new_path)
        print("\r[{}] processing [{}/{}]".format(cla, index+1, num), end="")  # processing bar
    print()

print("processing done!")

2 模型

在文件夹AlexNet右键打开终端

gedit model.py # 将 model.py 拷入保存关闭

model.py

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=5, init_weights=False):   
        super(AlexNet, self).__init__()
        self.features = nn.Sequential(  #打包
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55] 自动舍去小数点后
            nn.ReLU(inplace=True), #inplace 可以载入更大模型
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27] kernel_num为原论文一半
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),
            #全链接
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1) #展平   或者view()
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') #何教授方法
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)  #正态分布赋值
                nn.init.constant_(m.bias, 0)

3 训练

在文件夹AlexNet右键打开终端

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time


#device : GPU 或 CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)


#数据预处理
data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪为224x224
                                 transforms.RandomHorizontalFlip(), # 水平翻转
                                 transforms.ToTensor(), # 转为张量
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),# 均值和方差为0.5
    "val": transforms.Compose([transforms.Resize((224, 224)), # 重置大小
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}


batch_size = 32 # 批次大小
data_root = os.getcwd() # 获取当前路径
image_path = data_root + "/flower_data/"  # 数据路径

train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"]) # 加载训练数据集并预处理
train_num = len(train_dataset) # 训练数据集大小

train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0) # 训练加载器

validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"]) # 验证数据集
val_num = len(validate_dataset) # 验证数据集大小
validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                              batch_size=batch_size, shuffle=True,
                                              num_workers=0) # 验证加载器

print("训练数据集大小: ",train_num,"\n") # 3306
print("验证数据集大小: ",val_num,"\n") # 364



net = AlexNet(num_classes=5, init_weights=True) # 调用模型

net.to(device)

loss_function = nn.CrossEntropyLoss() # 损失函数:交叉熵
optimizer = optim.Adam(net.parameters(), lr=0.0002) #优化器 Adam
save_path = './AlexNet.pth' # 训练参数保存路径
best_acc = 0.0 # 训练过程中最高准确率

#开始进行训练和测试,训练一轮,测试一轮
for epoch in range(10):
    # 训练部分
    print(">>开始训练: ",epoch+1)
    net.train()    #训练dropout
    running_loss = 0.0
    t1 = time.perf_counter()
    for step, data in enumerate(train_loader, start=0):
        images, labels = data
        optimizer.zero_grad() # 梯度置0
        outputs = net(images.to(device)) 
        loss = loss_function(outputs, labels.to(device))
        loss.backward() # 反向传播
        optimizer.step()

        
        running_loss += loss.item() # 累加损失
        rate = (step + 1) / len(train_loader) # 训练进度
        a = "*" * int(rate * 50) # *数
        b = "." * int((1 - rate) * 50) # .数
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(int(rate * 100), a, b, loss), end="")
    print()
    print(time.perf_counter()-t1) # 一个epoch花费的时间

    # 验证部分
    print(">>开始验证: ",epoch+1)
    net.eval()    #验证不需要dropout
    acc = 0.0  # 一个批次中分类正确个数
    with torch.no_grad():
        for val_data in validate_loader:
            val_images, val_labels = val_data
            outputs = net(val_images.to(device))
            #print("outputs: \n",outputs,"\n")

            predict_y = torch.max(outputs, dim=1)[1]
            #print("predict_y: \n",predict_y,"\n")

            acc += (predict_y == val_labels.to(device)).sum().item() # 预测和标签一致,累加
        val_accurate = acc / val_num # 一个批次的准确率
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path) # 更新准确率最高的网络参数
        print('[epoch %d] train_loss: %.3f  test_accuracy: %.3f' %
              (epoch + 1, running_loss / step, val_accurate))

print('Finished Training')

# {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
flower_list = train_dataset.class_to_idx 

#  {0: 'daisy', 1: 'dandelion', 2: 'roses', 3: 'sunflowers', 4: 'tulips'}
cla_dict = dict((val, key) for key, val in flower_list.items())

# 将字典写入 json 文件
json_str = json.dumps(cla_dict, indent=4) # 字典转json
with open('class_indices.json', 'w') as json_file: # 对class_indices.json写入操作
    json_file.write(json_str) # 写入class_indices.json

4 测试

在网上下载一张jpg格式的图片,改名为sunflower.jpg,并放在文件夹AlexNet。

在文件夹AlexNet右键打开终端

gedit predict.py # 将 predict.py 拷入保存关闭
python predict.py # 运行 predict.py

predict.py

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json


# 加载图片
img = Image.open("./sunflower.jpg")  #验证太阳花

data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

#img = Image.open("./roses.jpg")     #验证玫瑰花
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# 读取 class_indict
try:
    json_file = open('./class_indices.json', 'r')
    class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# 调用模型
model = AlexNet(num_classes=5)
# 加载模型参数
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))
model.eval() # 不用Dropout
with torch.no_grad():
    output = torch.squeeze(model(img)) # 压缩
    predict = torch.softmax(output, dim=0) # 生成概率
    predict_cla = torch.argmax(predict).numpy() # 最大值的索引
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()

相关推荐

sharding-jdbc实现`分库分表`与`读写分离`

一、前言本文将基于以下环境整合...

三分钟了解mysql中主键、外键、非空、唯一、默认约束是什么

在数据库中,数据表是数据库中最重要、最基本的操作对象,是数据存储的基本单位。数据表被定义为列的集合,数据在表中是按照行和列的格式来存储的。每一行代表一条唯一的记录,每一列代表记录中的一个域。...

MySQL8行级锁_mysql如何加行级锁

MySQL8行级锁版本:8.0.34基本概念...

mysql使用小技巧_mysql使用入门

1、MySQL中有许多很实用的函数,好好利用它们可以省去很多时间:group_concat()将取到的值用逗号连接,可以这么用:selectgroup_concat(distinctid)fr...

MySQL/MariaDB中如何支持全部的Unicode?

永远不要在MySQL中使用utf8,并且始终使用utf8mb4。utf8mb4介绍MySQL/MariaDB中,utf8字符集并不是对Unicode的真正实现,即不是真正的UTF-8编码,因...

聊聊 MySQL Server 可执行注释,你懂了吗?

前言MySQLServer当前支持如下3种注释风格:...

MySQL系列-源码编译安装(v5.7.34)

一、系统环境要求...

MySQL的锁就锁住我啦!与腾讯大佬的技术交谈,是我小看它了

对酒当歌,人生几何!朝朝暮暮,唯有己脱。苦苦寻觅找工作之间,殊不知今日之事乃我心之痛,难道是我不配拥有工作嘛。自面试后他所谓的等待都过去一段时日,可惜在下京东上的小金库都要见低啦。每每想到不由心中一...

MySQL字符问题_mysql中字符串的位置

中文写入乱码问题:我输入的中文编码是urf8的,建的库是urf8的,但是插入mysql总是乱码,一堆"???????????????????????"我用的是ibatis,终于找到原因了,我是这么解决...

深圳尚学堂:mysql基本sql语句大全(三)

数据开发-经典1.按姓氏笔画排序:Select*FromTableNameOrderByCustomerNameCollateChinese_PRC_Stroke_ci_as//从少...

MySQL进行行级锁的?一会next-key锁,一会间隙锁,一会记录锁?

大家好,是不是很多人都对MySQL加行级锁的规则搞的迷迷糊糊,一会是next-key锁,一会是间隙锁,一会又是记录锁。坦白说,确实还挺复杂的,但是好在我找点了点规律,也知道如何如何用命令分析加...

一文讲清怎么利用Python Django实现Excel数据表的导入导出功能

摘要:Python作为一门简单易学且功能强大的编程语言,广受程序员、数据分析师和AI工程师的青睐。本文系统讲解了如何使用Python的Django框架结合openpyxl库实现Excel...

用DataX实现两个MySQL实例间的数据同步

DataXDataX使用Java实现。如果可以实现数据库实例之间准实时的...

MySQL数据库知识_mysql数据库基础知识

MySQL是一种关系型数据库管理系统;那废话不多说,直接上自己以前学习整理文档:查看数据库命令:(1).查看存储过程状态:showprocedurestatus;(2).显示系统变量:show...

如何为MySQL中的JSON字段设置索引

背景MySQL在2015年中发布的5.7.8版本中首次引入了JSON数据类型。自此,它成了一种逃离严格列定义的方式,可以存储各种形状和大小的JSON文档,例如审计日志、配置信息、第三方数据包、用户自定...

取消回复欢迎 发表评论: