Day236:addmm()和addmm_()的用法详解
ztj100 2024-11-03 16:15 12 浏览 0 评论
函数解释
在torch/_C/_VariableFunctions.py的有该定义,意义就是实现一下公式:
换句话说,就是需要传入5个参数,mat里的每个元素乘以beta,mat1和mat2进行矩阵乘法(左行乘右列)后再乘以alpha,最后将这2个结果加在一起。但是这样说可能没啥概念,接下来博主为大家写上一段代码,大家就明白了~
def addmm(self, beta=1, mat, alpha=1, mat1, mat2, out=None): # real signature unknown; restored from __doc__
"""
addmm(beta=1, mat, alpha=1, mat1, mat2, out=None) -> Tensor
Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`.
The matrix :attr:`mat` is added to the final result.
If :attr:`mat1` is a :math:`(n \times m)` tensor, :attr:`mat2` is a
:math:`(m \times p)` tensor, then :attr:`mat` must be
:ref:`broadcastable <broadcasting-semantics>` with a :math:`(n \times p)` tensor
and :attr:`out` will be a :math:`(n \times p)` tensor.
:attr:`alpha` and :attr:`beta` are scaling factors on matrix-vector product between
:attr:`mat1` and :attr`mat2` and the added matrix :attr:`mat` respectively.
.. math::
out = \beta\ mat + \alpha\ (mat1_i \mathbin{@} mat2_i)
For inputs of type `FloatTensor` or `DoubleTensor`, arguments :attr:`beta` and
:attr:`alpha` must be real numbers, otherwise they should be integers.
Args:
beta (Number, optional): multiplier for :attr:`mat` (:math:`\beta`)
mat (Tensor): matrix to be added
alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`)
mat1 (Tensor): the first matrix to be multiplied
mat2 (Tensor): the second matrix to be multiplied
out (Tensor, optional): the output tensor
Example::
>>> M = torch.randn(2, 3)
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 3)
>>> torch.addmm(M, mat1, mat2)
tensor([[-4.8716, 1.4671, -1.3746],
[ 0.7573, -3.9555, -2.8681]])
"""
pass
代码范例
1.先摆出代码,大家可以先复制粘贴运行一下,在之后会一一讲解
"""
@author:nickhuang1996
"""
import torch
rectangle_height = 3
rectangle_width = 3
inputs = torch.randn(rectangle_height, rectangle_width)
for i in range(rectangle_height):
for j in range(rectangle_width):
inputs[i] = i * torch.ones(rectangle_width)
'''
inputs and its transpose
-->inputs = tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
-->inputs_t = tensor([[0., 1., 2.],
[0., 1., 2.],
[0., 1., 2.]])
'''
print("inputs:\n", inputs)
inputs_t = inputs.t()
print("inputs_t:\n", inputs_t)
'''
inputs_t @ inputs_t [[0., 1., 2.], [[0., 1., 2.], [[0., 3., 6.]
= [0., 1., 2.], @ [0., 1., 2.], = [0., 3., 6.]
[0., 1., 2.]] [0., 1., 2.]] [0., 3., 6.]]
'''
'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
print("a:\n", a)
print("b:\n", b)
print("c:\n", c)
print("d:\n", d)
print("e:\n", e)
print("f:\n", f)
print("g:\n", g)
print("g2:\n", g2)
print("h:\n", h)
print("h12:\n", h12)
print("h21:\n", h21)
print("inputs:\n", inputs)
'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
'''
inputs @ inputs_t [[0., 0., 0.], [[0., 1., 2.], [[0., 0., 0.]
= [1., 1., 1.], @ [0., 1., 2.], = [0., 3., 6.]
[2., 2., 2.]] [0., 1., 2.]] [0., 6., 12.]]
'''
inputs.addmm_(1, -2, inputs, inputs_t) # In-place
print("inputs:\n", inputs)
2.其中
inputs是一个3×3的矩阵,为
tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
inputs_t也是一个3×3的矩阵,是inputs的转置矩阵,为
tensor([[0., 1., 2.],
[0., 1., 2.],
[0., 1., 2.]])
* inputs_t @ inputs_t为
'''
inputs_t @ inputs_t [[0., 1., 2.], [[0., 1., 2.], [[0., 3., 6.]
= [0., 1., 2.], @ [0., 1., 2.], = [0., 3., 6.]
[0., 1., 2.]] [0., 1., 2.]] [0., 3., 6.]]
'''
3.代码中a,b,c和d展示的是完全形式,即标明了位置参数和传入参数。可以看到input这个位置参数可以写在函数的前面,即
torch.addmm(input, mat1, mat2) = inputs.addmm(mat1, mat2)
完成的公式为:
1 × inputs + 1 ×(inputs_t @ inputs_t)
'''a, b, c and d = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
a = torch.addmm(input=inputs, mat1=inputs_t, mat2=inputs_t)
b = inputs.addmm(mat1=inputs_t, mat2=inputs_t)
c = torch.addmm(input=inputs, beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
d = inputs.addmm(beta=1, mat1=inputs_t, mat2=inputs_t, alpha=1)
a:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
b:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
c:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
d:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
4.下面的例子更好了说明了input参数的位置可变性,并且beta和alpha都缺省了:
完成的公式为:
1 × inputs + 1 ×(inputs_t @ inputs_t)
'''e and f = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
e = torch.addmm(inputs, inputs_t, inputs_t)
f = inputs.addmm(inputs_t, inputs_t)
e:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
f:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
5.加一个参数,实际上是添加了beta这个参数
完成的公式为:
g = 1 × inputs + 1 ×(inputs_t @ inputs_t)
g2 = 2 × inputs + 1 ×(inputs_t @ inputs_t)
'''1 * inputs + 1 * (inputs_t @ inputs_t)'''
g = inputs.addmm(1, inputs_t, inputs_t)
'''2 * inputs + 1 * (inputs_t @ inputs_t)'''
g2 = inputs.addmm(2, inputs_t, inputs_t)
g:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
g2:
tensor([[ 0., 3., 6.],
[ 2., 5., 8.],
[ 4., 7., 10.]])
6.再加一个参数,实际上是添加了alpha这个参数
完成的公式为:
h = 1 × inputs + 1 ×(inputs_t @ inputs_t)
h12 = 1 × inputs + 2 ×(inputs_t @ inputs_t)
h21 = 2 × inputs + 1 ×(inputs_t @ inputs_t)
'''h = 1 * inputs + 1 * (inputs_t @ inputs_t)'''
h = inputs.addmm(1, 1, inputs_t, inputs_t)
'''h12 = 1 * inputs + 2 * (inputs_t @ inputs_t)'''
h12 = inputs.addmm(1, 2, inputs_t, inputs_t)
'''h21 = 2 * inputs + 1 * (inputs_t @ inputs_t)'''
h21 = inputs.addmm(2, 1, inputs_t, inputs_t)
h:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
h12:
tensor([[ 0., 6., 12.],
[ 1., 7., 13.],
[ 2., 8., 14.]])
h21:
tensor([[ 0., 3., 6.],
[ 2., 5., 8.],
[ 4., 7., 10.]])
7.当然,以上的步骤inputs没有变化,还是为
inputs:
tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
*8.addmm_()的操作和addmm()函数功能相同,区别就是addmm_()有inplace的操作,也就是在原对象基础上进行修改,即把改变之后的变量再赋给原来的变量。例如:
inputs的值变成了改变之后的值,不用再去写 某个变量=addmm_() 了,因为inputs就是改变之后的变量!
*inputs@ inputs_t为
'''
inputs @ inputs_t [[0., 0., 0.], [[0., 1., 2.], [[0., 0., 0.]
= [1., 1., 1.], @ [0., 1., 2.], = [0., 3., 6.]
[2., 2., 2.]] [0., 1., 2.]] [0., 6., 12.]]
'''
完成的公式为:
inputs = 1 × inputs - 2 ×(inputs @ inputs_t)
'''inputs = 1 * inputs - 2 * (inputs @ inputs_t)'''
inputs.addmm_(1, -2, inputs, inputs_t) # In-place
inputs:
tensor([[ 0., 0., 0.],
[ 1., -5., -11.],
[ 2., -10., -22.]])
三、代码运行结果
inputs:
tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
inputs_t:
tensor([[0., 1., 2.],
[0., 1., 2.],
[0., 1., 2.]])
a:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
b:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
c:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
d:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
e:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
f:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
g:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
g2:
tensor([[ 0., 3., 6.],
[ 2., 5., 8.],
[ 4., 7., 10.]])
h:
tensor([[0., 3., 6.],
[1., 4., 7.],
[2., 5., 8.]])
h12:
tensor([[ 0., 6., 12.],
[ 1., 7., 13.],
[ 2., 8., 14.]])
h21:
tensor([[ 0., 3., 6.],
[ 2., 5., 8.],
[ 4., 7., 10.]])
inputs:
tensor([[0., 0., 0.],
[1., 1., 1.],
[2., 2., 2.]])
inputs:
tensor([[ 0., 0., 0.],
[ 1., -5., -11.],
[ 2., -10., -22.]])
原文:https://blog.csdn.net/qq_36556893/article/details/90638449
相关推荐
- Vue 技术栈(全家桶)(vue technology)
-
Vue技术栈(全家桶)尚硅谷前端研究院第1章:Vue核心Vue简介官网英文官网:https://vuejs.org/中文官网:https://cn.vuejs.org/...
- vue 基础- nextTick 的使用场景(vue的nexttick这个方法有什么用)
-
前言《vue基础》系列是再次回炉vue记的笔记,除了官网那部分知识点外,还会加入自己的一些理解。(里面会有部分和官网相同的文案,有经验的同学择感兴趣的阅读)在开发时,是不是遇到过这样的场景,响应...
- vue3 组件初始化流程(vue组件初始化顺序)
-
学习完成响应式系统后,咋们来看看vue3组件的初始化流程既然是看vue组件的初始化流程,咋们先来创建基本的代码,跑跑流程(在app.vue中写入以下内容,来跑流程)...
- vue3优雅的设置element-plus的table自动滚动到底部
-
场景我是需要在table最后添加一行数据,然后把滚动条滚动到最后。查网上的解决方案都是读取html结构,暴力的去获取,虽能解决问题,但是不喜欢这种打补丁的解决方案,我想着官方应该有相关的定义,于是就去...
- Vue3为什么推荐使用ref而不是reactive
-
为什么推荐使用ref而不是reactivereactive本身具有很大局限性导致使用过程需要额外注意,如果忽视这些问题将对开发造成不小的麻烦;ref更像是vue2时代optionapi的data的替...
- 9、echarts 在 vue 中怎么引用?(必会)
-
首先我们初始化一个vue项目,执行vueinitwebpackechart,接着我们进入初始化的项目下。安装echarts,npminstallecharts-S//或...
- 无所不能,将 Vue 渲染到嵌入式液晶屏
-
该文章转载自公众号@前端时刻,https://mp.weixin.qq.com/s/WDHW36zhfNFVFVv4jO2vrA前言...
- vue-element-admin 增删改查(五)(vue-element-admin怎么用)
-
此篇幅比较长,涉及到的小知识点也比较多,一定要耐心看完,记住学东西没有耐心可不行!!!一、添加和修改注:添加和编辑用到了同一个组件,也就是此篇文章你能学会如何封装组件及引用组件;第二能学会async和...
- 最全的 Vue 面试题+详解答案(vue面试题知识点大全)
-
前言本文整理了...
- 基于 vue3.0 桌面端朋友圈/登录验证+60s倒计时
-
今天给大家分享的是Vue3聊天实例中的朋友圈的实现及登录验证和倒计时操作。先上效果图这个是最新开发的vue3.x网页端聊天项目中的朋友圈模块。用到了ElementPlus...
- 不来看看这些 VUE 的生命周期钩子函数?| 原力计划
-
作者|huangfuyk责编|王晓曼出品|CSDN博客VUE的生命周期钩子函数:就是指在一个组件从创建到销毁的过程自动执行的函数,包含组件的变化。可以分为:创建、挂载、更新、销毁四个模块...
- Vue3.5正式上线,父传子props用法更丝滑简洁
-
前言Vue3.5在2024-09-03正式上线,目前在Vue官网显最新版本已经是Vue3.5,其中主要包含了几个小改动,我留意到日常最常用的改动就是props了,肯定是用Vue3的人必用的,所以针对性...
- Vue 3 生命周期完整指南(vue生命周期及使用)
-
Vue2和Vue3中的生命周期钩子的工作方式非常相似,我们仍然可以访问相同的钩子,也希望将它们能用于相同的场景。...
- 救命!这 10 个 Vue3 技巧藏太深了!性能翻倍 + 摸鱼神器全揭秘
-
前端打工人集合!是不是经常遇到这些崩溃瞬间:Vue3项目越写越卡,组件通信像走迷宫,复杂逻辑写得脑壳疼?别慌!作为在一线摸爬滚打多年的老前端,今天直接甩出10个超实用的Vue3实战技巧,手把...
- 怎么在 vue 中使用 form 清除校验状态?
-
在Vue中使用表单验证时,经常需要清除表单的校验状态。下面我将介绍一些方法来清除表单的校验状态。1.使用this.$refs...
你 发表评论:
欢迎- 一周热门
- 最近发表
-
- Vue 技术栈(全家桶)(vue technology)
- vue 基础- nextTick 的使用场景(vue的nexttick这个方法有什么用)
- vue3 组件初始化流程(vue组件初始化顺序)
- vue3优雅的设置element-plus的table自动滚动到底部
- Vue3为什么推荐使用ref而不是reactive
- 9、echarts 在 vue 中怎么引用?(必会)
- 无所不能,将 Vue 渲染到嵌入式液晶屏
- vue-element-admin 增删改查(五)(vue-element-admin怎么用)
- 最全的 Vue 面试题+详解答案(vue面试题知识点大全)
- 基于 vue3.0 桌面端朋友圈/登录验证+60s倒计时
- 标签列表
-
- idea eval reset (50)
- vue dispatch (70)
- update canceled (42)
- order by asc (53)
- spring gateway (67)
- 简单代码编程 贪吃蛇 (40)
- transforms.resize (33)
- redisson trylock (35)
- 卸载node (35)
- np.reshape (33)
- torch.arange (34)
- node卸载 (33)
- npm 源 (35)
- vue3 deep (35)
- win10 ssh (35)
- exceptionininitializererror (33)
- vue foreach (34)
- idea设置编码为utf8 (35)
- vue 数组添加元素 (34)
- std find (34)
- tablefield注解用途 (35)
- python str转json (34)
- java websocket客户端 (34)
- tensor.view (34)
- java jackson (34)