Transformer: Multi-Head Attention (PyTorch)

2024-07-13

1. Schematic diagram

[Figure: Transformer multi-head attention (PyTorch)]
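Assuming the diagram shows the standard multi-head attention block of the Transformer, the computation can be written as follows. The symbols are not taken from the original post: h stands for `heads`, d_k for `head_dim`, and W_i^Q, W_i^K, W_i^V, W^O for the weights of the `queries`, `keys`, `values` and `fc_out` linear layers in the code of section 2 (each per-head matrix being one slice of the corresponding full-width layer).

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V

\mathrm{head}_i = \mathrm{Attention}\!\left(Q W_i^{Q},\; K W_i^{K},\; V W_i^{V}\right), \quad i = 1, \dots, h

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O}
```

With embed_size = 512 and h = 8, as in the example at the end of section 2, each head works on d_k = 512 / 8 = 64 dimensions.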

2. Code

import torch
import torch.nn as nn


class Multi_Head_Self_Attention(nn.Module):
    """Multi-head scaled dot-product attention with separate query, key and value inputs."""

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
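        # embed_size must split evenly across the heads, otherwise the reshape in forward() fails
        assert embed_size % heads == 0, "embed_size must be divisible by heads"
        self.head_dim = embed_size // heads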
        self.queries = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.keys = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.values = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(self.embed_size, self.embed_size, bias=False)
    def forward(self, queries, keys, values, mask=None):
        N = queries.shape[0]  # batch_size
        query_len = queries.shape[1]  # query sequence length
        key_len = keys.shape[1]  # key sequence length
        value_len = values.shape[1]  # value sequence length (must equal key_len)
        queries = self.queries(queries)
        keys = self.keys(keys)
        values = self.values(values)
        # Split the embedding into self.heads pieces
        # batch_size, sequence_length, embed_size(512) --> 
        # batch_size, sequence_length, heads(8), head_dim(64)
        queries = queries.reshape(N, query_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        # batch_size, sequence_length, heads(8), head_dim(64) --> 
        # batch_size, heads(8), sequence_length, head_dim(64)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
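        # Scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V
        score = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # mask must broadcast to (batch_size, heads, query_len, key_len);
        # positions where mask == 0 are excluded from attention.
        # Note: a query row that is masked at every key becomes NaN after softmax.
        if mask is not None:
            score = score.masked_fill(mask == 0, float("-inf"))
        # batch_size, heads(8), query_len, key_len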
        attention = torch.softmax(score, dim=-1)
        out = torch.matmul(attention, values)
        # batch_size, heads(8), sequence_length, head_dim(64) -->
        # batch_size, sequence_length, heads(8), head_dim(64) -->
        # batch_size, sequence_length, embed_size(512)
        # Merge the heads back so the output can be fed to the following layers
        out = out.transpose(1, 2).contiguous().reshape(N, query_len, self.embed_size)
        out = self.fc_out(out)
        return out
    
batch_size = 64
sequence_length = 10
embed_size = 512
heads = 8
mask = None  # no masking: every position attends to every other position
Q = torch.randn(batch_size, sequence_length, embed_size)
K = torch.randn(batch_size, sequence_length, embed_size)
V = torch.randn(batch_size, sequence_length, embed_size)
model = Multi_Head_Self_Attention(embed_size, heads)
output = model(Q, K, V, mask)
print(output.shape)  # torch.Size([64, 10, 512])
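The example above only runs with `mask = None`, so here is a minimal sketch of what a mask could look like. It is an illustration, not part of the original post: `masked_attention` and `causal_mask` are hypothetical names, and the function repeats only the scoring step of the module on tensors that are already split into heads. It assumes the same convention as the module, where a 0 in the mask means "do not attend".

```python
import torch


def masked_attention(q, k, v, mask=None):
    """Scaled dot-product attention on (batch, heads, seq_len, head_dim) tensors."""
    head_dim = q.shape[-1]
    score = torch.matmul(q, k.transpose(-2, -1)) / (head_dim ** 0.5)
    if mask is not None:
        # Same convention as the module above: 0 means "do not attend"
        score = score.masked_fill(mask == 0, float("-inf"))
    weights = torch.softmax(score, dim=-1)
    return torch.matmul(weights, v), weights


torch.manual_seed(0)
batch_size, heads, seq_len, head_dim = 2, 8, 10, 64
q = torch.randn(batch_size, heads, seq_len, head_dim)
k = torch.randn(batch_size, heads, seq_len, head_dim)
v = torch.randn(batch_size, heads, seq_len, head_dim)

# Lower-triangular causal mask of shape (seq_len, seq_len); it broadcasts over
# the batch and head dimensions of the (batch, heads, seq_len, seq_len) scores.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

out, weights = masked_attention(q, k, v, causal_mask)
print(out.shape)      # torch.Size([2, 8, 10, 64])
print(weights.shape)  # torch.Size([2, 8, 10, 10])

# Weights strictly above the diagonal correspond to future positions;
# exp(-inf) is 0, so all of them are exactly zero.
future_weights = torch.triu(weights, diagonal=1)
print(bool(torch.all(future_weights == 0)))  # True
```

Because a (seq_len, seq_len) mask broadcasts to (batch_size, heads, query_len, key_len), the same tensor should also work with the module above, for example `model(Q, K, V, torch.tril(torch.ones(sequence_length, sequence_length)))`. A causal mask always leaves the diagonal unmasked, so no row is fully masked and the NaN case does not arise.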

 
