0%

关于Numpy和Pytorch中 dot,multiply,@, * 的运算区别

Numpy 中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np

a = np.matrix([[1,2],
[4,5]])
v = np.matrix([[7],
[8]])

c = np.dot(a,v) # 向量点积或矩阵乘法,shape不合法时报错
print(c)

d = np.multiply(a,v) # 先广播后对位相乘,shape无法广播时报错
print(d)

e = a @ v # 向量点积或矩阵乘法,shape不合法时报错
print(e)

f = a * v # array:先广播后对位相乘;matrix:矩阵乘法,shape不合法时报错
print(f)

output:
[[23]
[68]]
[[ 7 14]
[32 40]]
[[23]
[68]]
[[23]
[68]]

Pytorch 中:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
a = torch.Tensor([[1,2],
[4,5]])

v = torch.Tensor([[7],
[8]])

# c = torch.dot(a,v) # 会报错,torch.dot 有意仅支持计算两个具有相同数量元素的 1D 张量的点积
# print(c)

d = torch.multiply(a,v) # 先广播后对位相乘,shape无法广播时报错
print(d)

e = a @ v # 向量点积或矩阵乘法,shape不合法时报错
print(e)

f = a * v # 先广播后对位相乘
print(f)

output:
tensor([[ 7., 14.],
[32., 40.]])
tensor([[23.],
[68.]])
tensor([[ 7., 14.],
[32., 40.]])