Day234:torch中数据采样方法Sampler源码解析
ztj100 2024-11-03 16:15 15 浏览 0 评论
Sampler采样函数基类
torch.utils.data.Sampler(data_source)
- 所有采样器的基类。
- 每个采样器子类都必须提供一个__iter__()方法,这是一种遍历dataset元素索引的方法;以及一个返回迭代器长度的__len__()方法。
- pytorch中提供的采样方法主要有SequentialSampler, RandomSampler, SubsetRandomSampler, WeightedRandomSampler,关键是__iter__的实现.
下面用一个简单的例子来分析各个采样函数的源码以及
import torch
from torch.utils.data.sampler import *
import numpy as np
t = np.arange(10)
SequentialSampler顺序采样
torch.utils.data.SequentialSampler(data_source)
其中__iter__为:
iter(range(len(self.data_source)))
参数
- data_source为数据集
所以SequentialSampler的功能是顺序逐个采样数据
for i in SequentialSampler(t):
print(i,end=',')
输出:
0,1,2,3,4,5,6,7,8,9,
1
RandomSampler随机采样
torch.utils.data.RandomSampler(data_source, replacement=False, num_samples=None)
其中__iter__为:
n = len(self.data_source)
if self.replacement:
return iter(torch.randint(high=n,
size=(self.num_samples,),
dtype=torch.int64).tolist())
return iter(torch.randperm(n).tolist())
123456
参数
- data_source为数据集
- replacement:是否为有放回取样
RandomSampler当replacement开关关闭时,返回原始数据集长度下标数组随机打乱后采样值, 而当replacment开关打开后,则根据num_samples长度来生成采样序列长度。
具体可见如下代码,在replacement=False时,RandomSampler对数组t下标随机打乱输出,迭代器长度与源数据长度一致。
当replacement=True并设定num_samples=20,这时迭代器长度大于源数据,故会出现重复值。
t = np.arange(10)
for i in RandomSampler(t):
print(i,end=',')
输出:
4,5,6,0,8,1,7,9,2,3,
输入
for i in RandomSampler(t,replacement=True,num_samples=20):
print(i,end=',')
输出:
8,0,4,6,4,0,1,5,3,1,6,8,9,0,4,7,0,8,7,4,
SubsetRandomSampler索引随机采样
torch.utils.data.SubsetRandomSampler(indices)
其中__iter__为:
(self.indices[i] for i in torch.randperm(len(self.indices)))
其中
- torch.randperm对数组随机排序
- indices为给定的下标数组
所以SubsetRandomSampler的功能是在给定一个数据集下标后,对该下标数组随机排序,然后不放回取样
for i in SubsetRandomSampler(t):
print(i,end=',')
输出:
2,6,1,7,4,3,0,5,8,9,
WeightedRandomSampler加权随机采样
torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)
其中__iter__为:
iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
其中
- weights为index权重,权重越大的取到的概率越高
- num_samples: 生成的采样长度
- replacement:是否为有放回取样
- multinomial: 伯努利随机数生成函数,也就是根据概率设定生成{0,1,…,n}
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,1,replacement=False)
输出:
tensor([1])
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,2,replacement=False)
输出:
tensor([2, 1])
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,3,replacement=False)
输出:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-41-c641212fcbc8> in <module>
1 weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
----> 2 torch.multinomial(weights,3,replacement=False)
RuntimeError: invalid argument 2: invalid multinomial distribution (with replacement=False, not enough non-negative category to sample) at /opt/conda/conda-bld/pytorch_1565287148058/work/aten/src/TH/generic/THTensorRandom.cpp:378
weights = torch.tensor([0, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,2,replacement=True)
输出:
tensor([1, 1])
weights = torch.tensor([1, 10, 3, 0], dtype=torch.float)
torch.multinomial(weights,10,replacement=True)
输出:
tensor([1, 1, 0, 0, 1, 0, 2, 1, 2, 1])
通过上面几个例子可以看出,权重值为0的index不会被取到。
当不放回取样时,replacement=False,若num_samplers小于输入数组中权重非零值个数,那么非零权重大小基本不起什么作用,反正所有的值都会取到一次
当放回取样时,权重越大的取到的概率越高。
BatchSampler批采样
torch.utils.data.BatchSampler(sampler, batch_size, drop_last)
其中__iter__为:
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
其中
- drop_last为布尔类型值,当其为真时,如果数据集长度不是batch_size整数倍时,最后一批数据将会丢弃。
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
[[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
>>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
代码中例子很清晰,数据总长度为10,如果drop_last设置为False,那么最后余下的一个作为新的batch.
原文:https://blog.csdn.net/u010137742/article/details/100996937
相关推荐
- Vue 技术栈(全家桶)(vue technology)
-
Vue技术栈(全家桶)尚硅谷前端研究院第1章:Vue核心Vue简介官网英文官网:https://vuejs.org/中文官网:https://cn.vuejs.org/...
- vue 基础- nextTick 的使用场景(vue的nexttick这个方法有什么用)
-
前言《vue基础》系列是再次回炉vue记的笔记,除了官网那部分知识点外,还会加入自己的一些理解。(里面会有部分和官网相同的文案,有经验的同学择感兴趣的阅读)在开发时,是不是遇到过这样的场景,响应...
- vue3 组件初始化流程(vue组件初始化顺序)
-
学习完成响应式系统后,咋们来看看vue3组件的初始化流程既然是看vue组件的初始化流程,咋们先来创建基本的代码,跑跑流程(在app.vue中写入以下内容,来跑流程)...
- vue3优雅的设置element-plus的table自动滚动到底部
-
场景我是需要在table最后添加一行数据,然后把滚动条滚动到最后。查网上的解决方案都是读取html结构,暴力的去获取,虽能解决问题,但是不喜欢这种打补丁的解决方案,我想着官方应该有相关的定义,于是就去...
- Vue3为什么推荐使用ref而不是reactive
-
为什么推荐使用ref而不是reactivereactive本身具有很大局限性导致使用过程需要额外注意,如果忽视这些问题将对开发造成不小的麻烦;ref更像是vue2时代optionapi的data的替...
- 9、echarts 在 vue 中怎么引用?(必会)
-
首先我们初始化一个vue项目,执行vueinitwebpackechart,接着我们进入初始化的项目下。安装echarts,npminstallecharts-S//或...
- 无所不能,将 Vue 渲染到嵌入式液晶屏
-
该文章转载自公众号@前端时刻,https://mp.weixin.qq.com/s/WDHW36zhfNFVFVv4jO2vrA前言...
- vue-element-admin 增删改查(五)(vue-element-admin怎么用)
-
此篇幅比较长,涉及到的小知识点也比较多,一定要耐心看完,记住学东西没有耐心可不行!!!一、添加和修改注:添加和编辑用到了同一个组件,也就是此篇文章你能学会如何封装组件及引用组件;第二能学会async和...
- 最全的 Vue 面试题+详解答案(vue面试题知识点大全)
-
前言本文整理了...
- 基于 vue3.0 桌面端朋友圈/登录验证+60s倒计时
-
今天给大家分享的是Vue3聊天实例中的朋友圈的实现及登录验证和倒计时操作。先上效果图这个是最新开发的vue3.x网页端聊天项目中的朋友圈模块。用到了ElementPlus...
- 不来看看这些 VUE 的生命周期钩子函数?| 原力计划
-
作者|huangfuyk责编|王晓曼出品|CSDN博客VUE的生命周期钩子函数:就是指在一个组件从创建到销毁的过程自动执行的函数,包含组件的变化。可以分为:创建、挂载、更新、销毁四个模块...
- Vue3.5正式上线,父传子props用法更丝滑简洁
-
前言Vue3.5在2024-09-03正式上线,目前在Vue官网显最新版本已经是Vue3.5,其中主要包含了几个小改动,我留意到日常最常用的改动就是props了,肯定是用Vue3的人必用的,所以针对性...
- Vue 3 生命周期完整指南(vue生命周期及使用)
-
Vue2和Vue3中的生命周期钩子的工作方式非常相似,我们仍然可以访问相同的钩子,也希望将它们能用于相同的场景。...
- 救命!这 10 个 Vue3 技巧藏太深了!性能翻倍 + 摸鱼神器全揭秘
-
前端打工人集合!是不是经常遇到这些崩溃瞬间:Vue3项目越写越卡,组件通信像走迷宫,复杂逻辑写得脑壳疼?别慌!作为在一线摸爬滚打多年的老前端,今天直接甩出10个超实用的Vue3实战技巧,手把...
- 怎么在 vue 中使用 form 清除校验状态?
-
在Vue中使用表单验证时,经常需要清除表单的校验状态。下面我将介绍一些方法来清除表单的校验状态。1.使用this.$refs...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- Vue 技术栈(全家桶)(vue technology)
- vue 基础- nextTick 的使用场景(vue的nexttick这个方法有什么用)
- vue3 组件初始化流程(vue组件初始化顺序)
- vue3优雅的设置element-plus的table自动滚动到底部
- Vue3为什么推荐使用ref而不是reactive
- 9、echarts 在 vue 中怎么引用?(必会)
- 无所不能,将 Vue 渲染到嵌入式液晶屏
- vue-element-admin 增删改查(五)(vue-element-admin怎么用)
- 最全的 Vue 面试题+详解答案(vue面试题知识点大全)
- 基于 vue3.0 桌面端朋友圈/登录验证+60s倒计时
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)