矩阵相乘torch.einsum()

06-09 1137阅读

einsum 是 Einstein summation 的缩写,来源于爱因斯坦求和约定(Einstein summation convention)。这是物理学家阿尔伯特·爱因斯坦引入的一种简便记号,用于描述张量运算,特别是涉及多维数组的运算。

示例1:矩阵乘法

矩阵乘法 C=AB

矩阵相乘torch.einsum()
(图片来源网络,侵删)
A = torch.randn(2, 3)
B = torch.randn(3, 4)
C = torch.einsum('ik,kj->ij', A, B)
print(C.size())  # 输出: torch.Size([2, 4])

 这里,'ik,kj->ij' 的含义是:

  • A 的形状为 (2, 3),对应 ik,i 和 k 分别表示第一个和第二个维度。
  • B 的形状为 (3, 4),对应 kj,k 和 j 分别表示第一个和第二个维度。
  • ->ij 表示输出张量的模式,结果为 (2, 4)。
    示例2:向量点积

    向量点积 c=a⋅b

    a = torch.randn(3)
    b = torch.randn(3)
    c = torch.einsum('i,i->', a, b)
    print(c.size())  # 输出: torch.Size([])
    

    这里,'i,i->' 的含义是:

    • a 和 b 都是向量,对应模式 i。
    • -> 后面为空,表示结果是一个标量。
      示例3:批量矩阵乘法

      批量矩阵乘法

      A = torch.randn(10, 2, 3)
      B = torch.randn(10, 3, 4)
      C = torch.einsum('bij,bjk->bik', A, B)
      print(C.size())  # 输出: torch.Size([10, 2, 4])
      

      这里,'bij,bjk->bik' 的含义是:

      • A 的形状为 (10, 2, 3),对应 bij,b 表示批次维度,i 和 j 分别表示矩阵的行和列。
      • B 的形状为 (10, 3, 4),对应 bjk,b 表示批次维度,j 和 k 分别表示矩阵的行和列。
      • ->bik 表示输出张量的模式,结果为 (10, 2, 4)。

        示例4:逐元素相乘(哈达玛积)A.B或A × B

        A = torch.randn(3, 4)
        B = torch.randn(3, 4)
        C = torch.einsum('ij,ij->ij', A, B)
        print(C.size())  # 输出: torch.Size([3, 4])
        

        'ij,ij->ij' 表示:

        • A 和 B 都是形状为 [3, 4] 的矩阵,用 ij 表示。
        • 结果 C 也是形状为 [3, 4] 的矩阵。
        • 没有重复索引,所以不进行求和。

           

VPS购买请点击我

文章版权声明:除非注明,否则均为主机测评原创文章,转载或复制请以超链接形式并注明出处。

目录[+]