矩阵相乘torch.einsum()
einsum 是 Einstein summation 的缩写,来源于爱因斯坦求和约定(Einstein summation convention)。这是物理学家阿尔伯特·爱因斯坦引入的一种简便记号,用于描述张量运算,特别是涉及多维数组的运算。
示例1:矩阵乘法
矩阵乘法 C=AB
(图片来源网络,侵删)
A = torch.randn(2, 3) B = torch.randn(3, 4) C = torch.einsum('ik,kj->ij', A, B) print(C.size()) # 输出: torch.Size([2, 4])
这里,'ik,kj->ij' 的含义是:
- A 的形状为 (2, 3),对应 ik,i 和 k 分别表示第一个和第二个维度。
- B 的形状为 (3, 4),对应 kj,k 和 j 分别表示第一个和第二个维度。
- ->ij 表示输出张量的模式,结果为 (2, 4)。
示例2:向量点积
向量点积 c=a⋅b
a = torch.randn(3) b = torch.randn(3) c = torch.einsum('i,i->', a, b) print(c.size()) # 输出: torch.Size([])
这里,'i,i->' 的含义是:
- a 和 b 都是向量,对应模式 i。
- -> 后面为空,表示结果是一个标量。
示例3:批量矩阵乘法
批量矩阵乘法
A = torch.randn(10, 2, 3) B = torch.randn(10, 3, 4) C = torch.einsum('bij,bjk->bik', A, B) print(C.size()) # 输出: torch.Size([10, 2, 4])
这里,'bij,bjk->bik' 的含义是:
- A 的形状为 (10, 2, 3),对应 bij,b 表示批次维度,i 和 j 分别表示矩阵的行和列。
- B 的形状为 (10, 3, 4),对应 bjk,b 表示批次维度,j 和 k 分别表示矩阵的行和列。
- ->bik 表示输出张量的模式,结果为 (10, 2, 4)。
示例4:逐元素相乘(哈达玛积)A.B或A × B
A = torch.randn(3, 4) B = torch.randn(3, 4) C = torch.einsum('ij,ij->ij', A, B) print(C.size()) # 输出: torch.Size([3, 4])
'ij,ij->ij' 表示:
- A 和 B 都是形状为 [3, 4] 的矩阵,用 ij 表示。
- 结果 C 也是形状为 [3, 4] 的矩阵。
- 没有重复索引,所以不进行求和。
文章版权声明:除非注明,否则均为主机测评原创文章,转载或复制请以超链接形式并注明出处。