Pytorch - 手写Allreduce分布式训练
ztj100 2025-07-20 00:01 16 浏览 0 评论
1 介绍
近些年随着深度学习的火爆,模型的参数规模也飞速增长,OpenAI数据显示:
- 2012年以前,模型计算耗时每2年增长一倍,和摩尔定律保持一致;
- 2012年后,模型计算耗时每3.4个月翻一倍,远超硬件发展速度;
近一年来,百亿、千亿级的参数模型陆续面世,谷歌、英伟达、阿里、智源研究院更是发布了万亿参数模型。因此,大模型已经成为了未来深度学习的趋势。提到大模型,就不得不提分布式训练,由于模型参数和训练数据的不断增多,只有通过分布式训练才能完成大模型的训练任务。
分布式训练可以分为数据并行、模型并行,流水线并行和混合并行。分布式算法又有典型的parameter server和ring all-reduce。无论是哪一种分布式技术一个核心的关键就是如何进行communication,这是实现分布式训练的基础,因此要想掌握分布式训练或当前流行的大模型训练务必对worker间的通信方式有所了解。
互联网上已经有很多关于分布式训练的通信方面的文章,但是均没有代码层面的例子。我是属于比较愚钝类型的,只有通过自己手动实现一下方能对一些抽象的概念有较深的理解。因此,上一篇Pytorch - 分布式通信原语通过pytorch中的分布式原语库来介绍每个通信原语的行为表现,本篇文章将介绍如何在这些原语上实现分布式训练。
2 整体流程
手动数据并行的分布式训练,整体流程如下:
- 数据处理:将数据按照rank进行分片,每个rank读取对应的partition;
- 模型训练:模型构建、forward、loss和backward均与单机相同,不同的是在进行梯度更新之前调用我们自定义的average_gradients 函数进行所有rank间的梯度同步,同步完成之后再调用optimize的step接口进行梯度的更新;
- 调试执行:启动一个单机2 rank的DDP训练任务;
3 数据处理
3.1 构建数据集
构建通过pytorch中提供的torchvision DataSet来创建MNIST数据集;
- 参数root为数据下载的目录;
- 参数train指明当前创建的DataSet使用的是MNIST的训练集还是测试集;
- 参数download指明是否进行数据的下载;
- 参数transform指明DataSet中数据的变化方式
- toTensor() 将数据转换为tensor表示的形式;
- Normalize是对数据进行归一化;
dataset = datasets.MNIST(
root='./data',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
]))
3.2 数据切分
通过DataPartitioner对象对dataset进行切分,执行逻辑如下:
- 构建阶段:将dataset中的数据随机按照sizes的比例分配到不同的partition中
- 返回阶段:返回参数partition指定的对应数据分片
class DataPartitioner(object):
""" Partitions a dataset into different chuncks. """
# 先对index进行shuffle
# 然后按照size进行partition
def __init__(self, data, sizes=[0.7, 0.2, 0.1], seed=1234):
self.data = data
self.partitions = []
rng = Random()
rng.seed(seed)
data_len = len(data)
indexes = [x for x in range(0, data_len)]
rng.shuffle(indexes)
for frac in sizes:
part_len = int(frac * data_len)
self.partitions.append(indexes[0:part_len])
indexes = indexes[part_len:]
def use(self, partition):
return Partition(self.data, self.partitions[partition])
通过下面的Partition对象来实现sub dataset数据的遍历
class Partition(object):
""" Dataset-like object, but only access a subset of it. """
def __init__(self, data, index):
self.data = data
self.index = index
def __len__(self):
return len(self.index)
def __getitem__(self, index):
data_idx = self.index[index]
return self.data[data_idx]
3.3 完整数据处理
def partition_dataset():
""" Partitioning MNIST """
dataset = datasets.MNIST(
'./data',
train=True,
download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307, ), (0.3081, ))
]))
size = int(dist.get_world_size()) # 获取rank的个数
total_bach_size = 128
bsz = int(total_bach_size / float(size)) # 每个rank对应的batch size
partition_sizes = [1.0 / size for _ in range(size)] # 设置每个rank处理数据量的大小
partition = DataPartitioner(dataset, partition_sizes) # 数据切分
partition = partition.use(dist.get_rank()) # 获取当前rank对应的数据
train_set = torch.utils.data.DataLoader(partition, batch_size=bsz, shuffle=True)
return train_set, bsz
4 模型训练
4.1 模型构建
由于本例是DDP(数据并行),模型被完整加载到一个GPU上,因此模型的构建单卡训练一致。
class Net(nn.Module):
""" Network architecture. """
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
return F.log_softmax(x)
4.2 训练
训练主流程同单机训练基本一致,只是在后向传播和梯度更新之间新添加个average_gradients逻辑来在所有rank之间做梯度的平均
def run(rank, size):
""" Distributed Synchronous SGD Example """
torch.manual_seed(1234)
train_set, bsz = partition_dataset()
model = Net()
model = model
model = model.cuda(rank)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
num_batches = ceil(len(train_set.dataset) / float(bsz))
for epoch in range(10):
epoch_loss = 0.0
for data, target in train_set:
data, target = Variable(data), Variable(target)
data, target = Variable(data.cuda(rank)), Variable(target.cuda(rank))
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
epoch_loss += loss.item()
loss.backward()
average_gradients(model)
optimizer.step()
print('Rank ',
dist.get_rank(), ', epoch ', epoch, ': ',
epoch_loss / num_batches)
4.3 梯度平均
梯度平均逻辑如下:
- 遍历模型中的所有参数;
- 对每个参数调用dist.all_reduce,并求平均;
- pytorch中分布式原语的使用,可以参考上一篇文章:Pytorch - 分布式通信原语(附源码)
def average_gradients(model):
""" Gradient averaging. """
size = float(dist.get_world_size())
for param in model.parameters():
dist.all_reduce(param.grad.data, op=dist.reduce_op.SUM)
param.grad.data /= size
5 调试执行
代码执行环境:
image: pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
gpu: v100
/workspace/communication# python all_reduce_train.py
Rank 0 , epoch 0 : 1.330785048033383
Rank 1 , epoch 0 : 1.3148830299819712
Rank 0 , epoch 1 : 0.549655341136176
Rank 1 , epoch 1 : 0.5391304553317617
Rank 0 , epoch 2 : 0.43256897175871234
Rank 1 , epoch 2 : 0.42089327191238973
Rank 0 , epoch 3 : 0.37275126312714396
Rank 1 , epoch 3 : 0.3543623070409303
Rank 0 , epoch 4 : 0.31136283705801343
Rank 1 , epoch 4 : 0.3075531961868948
Rank 0 , epoch 5 : 0.29167098775982603
Rank 1 , epoch 5 : 0.2841323056836118
Rank 0 , epoch 6 : 0.26905299833556734
Rank 1 , epoch 6 : 0.26066392272520167
Rank 0 , epoch 7 : 0.25440651411885645
Rank 1 , epoch 7 : 0.2499371356706121
Rank 0 , epoch 8 : 0.2421310727260133
Rank 1 , epoch 8 : 0.2329997108149122
Rank 0 , epoch 9 : 0.22838556196199042
Rank 1 , epoch 9 : 0.2229069949927996
相关推荐
- Linux集群自动化监控系统Zabbix集群搭建到实战
-
自动化监控系统...
- systemd是什么如何使用_systemd/system
-
systemd是什么如何使用简介Systemd是一个在现代Linux发行版中广泛使用的系统和服务管理器。它负责启动系统并管理系统中运行的服务和进程。使用管理服务systemd可以用来启动、停止、...
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
-
Linux系统日常巡检脚本,巡检内容包含了,磁盘,...
- 7,MySQL管理员用户管理_mysql 管理员用户
-
一、首次设置密码1.初始化时设置(推荐)mysqld--initialize--user=mysql--datadir=/data/3306/data--basedir=/usr/local...
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
-
1.1数据库的核心概念在开始Python数据库编程之前,我们需要先理解几个核心概念。数据库(Database)是按照数据结构来组织、存储和管理数据的仓库,它就像一个电子化的文件柜,能让我们高效...
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
-
设置WGCloud开机自动启动服务init.d目录下新建脚本在/etc/rc.d/init.d新建启动脚本wgcloudstart.sh,内容如下...
- linux系统启动流程和服务管理,带你进去系统的世界
-
Linux启动流程Rhel6启动过程:开机自检bios-->MBR引导-->GRUB菜单-->加载内核-->init进程初始化Rhel7启动过程:开机自检BIOS-->M...
- CentOS7系统如何修改主机名_centos更改主机名称
-
请关注本头条号,每天坚持更新原创干货技术文章。如需学习视频,请在微信搜索公众号“智传网优”直接开始自助视频学习1.前言本文将讲解CentOS7系统如何修改主机名。...
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
-
在Linux服务器管理中,SSH(SecureShell)是远程操作的核心工具。以下是SSH终端操作的常用命令和技巧,涵盖连接、文件操作、系统管理等场景:一、SSH连接服务器1.基本连接...
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
-
为什么需要配置开机自启?想象一下:电商服务器重启后,MySQL和Nginx没自动启动,整个网站瘫痪!这就是为什么开机自启是Linux运维的必备技能。自启服务能确保核心程序在系统启动时自动运行,避免人工...
- Kubernetes 高可用(HA)集群部署指南
-
Kubernetes高可用(HA)集群部署指南本指南涵盖从概念理解、架构选择,到kubeadm高可用部署、生产优化、监控备份和运维的全流程,适用于希望搭建稳定、生产级Kubernetes集群...
- Linux项目开发,你必须了解Systemd服务!
-
1.Systemd简介...
- Linux系统systemd服务管理工具使用技巧
-
简介:在Linux系统里,systemd就像是所有进程的“源头”,它可是系统中PID值为1的进程哟。systemd其实是一堆工具的组合,它的作用可不止是启动操作系统这么简单,像后台服务...
- Linux下NetworkManager和network的和平共处
-
简介我们在使用CentoOS系统时偶尔会遇到配置都正确但network启动不了的问题,这问题经常是由NetworkManager引起的,关闭NetworkManage并取消开机启动network就能正...
你 发表评论:
欢迎- 一周热门
-
-
MySQL中这14个小玩意,让人眼前一亮!
-
旗舰机新标杆 OPPO Find X2系列正式发布 售价5499元起
-
面试官:使用int类型做加减操作,是线程安全吗
-
C++编程知识:ToString()字符串转换你用正确了吗?
-
【Spring Boot】WebSocket 的 6 种集成方式
-
PyTorch 深度学习实战(26):多目标强化学习Multi-Objective RL
-
pytorch中的 scatter_()函数使用和详解
-
与 Java 17 相比,Java 21 究竟有多快?
-
基于TensorRT_LLM的大模型推理加速与OpenAI兼容服务优化
-
这一次,彻底搞懂Java并发包中的Atomic原子类
-
- 最近发表
-
- Linux集群自动化监控系统Zabbix集群搭建到实战
- systemd是什么如何使用_systemd/system
- Linux服务器日常巡检脚本分享_linux服务器监控脚本
- 7,MySQL管理员用户管理_mysql 管理员用户
- Python数据库编程教程:第 1 章 数据库基础与 Python 连接入门
- Linux自定义开机自启动服务脚本_linux添加开机自启动脚本
- linux系统启动流程和服务管理,带你进去系统的世界
- CentOS7系统如何修改主机名_centos更改主机名称
- 前端工程师需要熟悉的Linux服务器(SSH 终端操作)指令
- Linux开机自启服务完全指南:3步搞定系统服务管理器配置
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)
- vmware17pro最新密钥 (34)
- mysql单表最大数据量 (35)