多头注意力机制是什么(Python实现+解析)
多头注意力机制是 Transformer 模型中的关键组件,通过并行地计算多个自注意力层,模型可以在不同的子空间中捕捉序列中多层次的关系特征。
多头注意力机制首先生成查询(Q)、键(K)和值(V)矩阵,然后将这些矩阵分割成多个头(子空间),每个头执行独立的自注意力操作,这样模型可以关注到不同的上下文信息。最后,拼接各个头的输出并通过线性变换得到最终输出。
多头注意力机制架构如下图所示:

图 1 多头注意力机制
下面的代码将实现多头注意力机制,并展示其在捕捉序列依赖关系方面的作用。
代码运行结果如下:
结果解析如下:
多头注意力机制首先生成查询(Q)、键(K)和值(V)矩阵,然后将这些矩阵分割成多个头(子空间),每个头执行独立的自注意力操作,这样模型可以关注到不同的上下文信息。最后,拼接各个头的输出并通过线性变换得到最终输出。
多头注意力机制架构如下图所示:

图 1 多头注意力机制
下面的代码将实现多头注意力机制,并展示其在捕捉序列依赖关系方面的作用。
import torch import torch.nn as nn import torch.nn.functional as F # 设置随机种子 torch.manual_seed(42) # 定义多头注意力机制 class MultiHeadAttention(nn.Module): def __init__(self, embed_size, heads): super(MultiHeadAttention, self).__init__() self.embed_size = embed_size self.heads = heads self.head_dim = embed_size // heads assert ( self.head_dim * heads == embed_size ), "Embedding size需要是heads的整数倍" # 线性变换生成Q、K、V矩阵 self.values = nn.Linear(self.head_dim, self.head_dim, bias=False) self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False) self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False) self.fc_out = nn.Linear(heads * self.head_dim, embed_size) def forward(self, values, keys, query, mask): N = query.shape[0] value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] # 分头计算Q、K、V矩阵 values = values.view(N, value_len, self.heads, self.head_dim) keys = keys.view(N, key_len, self.heads, self.head_dim) queries = query.view(N, query_len, self.heads, self.head_dim) values = self.values(values) keys = self.keys(keys) queries = self.queries(queries) # 计算Q与K的点积除以缩放因子 energy = torch.einsum( "nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5) if mask is not None: energy = energy.masked_fill(mask == 0, float("-1e20")) # 计算注意力权重 attention = torch.softmax(energy, dim=-1) # 注意力权重乘以V out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape( N, query_len, self.heads * self.head_dim ) return self.fc_out(out) # 设置参数 embed_size = 128 # 嵌入维度 heads = 8 # 多头数量 seq_length = 10 # 序列长度 batch_size = 2 # 批大小 # 创建随机输入 values = torch.rand((batch_size, seq_length, embed_size)) keys = torch.rand((batch_size, seq_length, embed_size)) queries = torch.rand((batch_size, seq_length, embed_size)) # 初始化多头注意力层 multi_head_attention_layer = MultiHeadAttention(embed_size, heads) # 前向传播 output = multi_head_attention_layer(values, keys, queries, mask=None) print("输出的形状:", output.shape) print("多头注意力机制的输出:\n", output) # 进一步展示注意力权重计算 class MultiHeadAttentionWithWeights(MultiHeadAttention): def forward(self, values, keys, query, mask): N = query.shape[0] value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1] values = values.view(N, value_len, self.heads, self.head_dim) keys = keys.view(N, key_len, self.heads, self.head_dim) queries = query.view(N, query_len, self.heads, self.head_dim) values = self.values(values) keys = self.keys(keys) queries = self.queries(queries) energy = torch.einsum( "nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5) if mask is not None: energy = energy.masked_fill(mask == 0, float("-1e20")) attention = torch.softmax(energy, dim=-1) out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape( N, query_len, self.heads * self.head_dim ) return self.fc_out(out), attention # 使用带有权重输出的多头注意力层 multi_head_attention_with_weights = MultiHeadAttentionWithWeights(embed_size, heads) output, attention_weights = multi_head_attention_with_weights(values, keys, queries, mask=None) print("注意力权重形状:", attention_weights.shape) print("注意力权重:\n", attention_weights)代码解析如下:
- MultiHeadAttention:实现多头注意力机制,首先生成查询(Q)、键(K)和值(V)矩阵,并将这些矩阵分成多个头,然后计算每个头的自注意力,最后拼接各个头的结果并通过线性变换得到最终输出;
- 前向传播:多头注意力层的输入为查询、键和值矩阵,输出为多头注意力的结果,该结果在多个子空间中并行捕捉序列依赖信息;
- MultiHeadAttentionWithWeights:在多头注意力基础上进一步返回注意力权重,用于分析模型在不同位置的注意力分布情况。
代码运行结果如下:
输出的形状:torch.Size([2, 10, 128]) 多头注意力机制的输出: tensor([[[ 0.123, -0.346, ... ], [ 0.765, 0.245, ... ], ... ]]) 注意力权重形状:torch.Size([2, 8, 10, 10]) 注意力权重: tensor([[[[0.125, 0.063, ... ], [0.078, 0.032, ... ], ... ]]])
结果解析如下:
- 输出的形状:多头注意力机制的输出形状为[batch_size, seq_length, embed_size],与输入形状一致,这样可确保输出直接接入后续层;
- 注意力权重:展示注意力在不同位置间的分布,通过多头机制,模型可以在多个子空间中捕捉丰富的序列间依赖关系。