1. 理解注意力¶
- 请在“#TODO”注释后直接填写您的答案。
- 要提交作业,请将此笔记本下载为 Python 文件,名为“A2S1.py”。
Imports and Setup¶
In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
In [ ]:
torch.manual_seed(447)

def _rand_unit_rows(rows, cols):
    # Gaussian matrix whose rows are scaled to unit L2 norm (only
    # approximately unit after rounding); entries rounded to 2 decimals
    # so the printed tensors are easy to read.
    m = torch.randn(rows, cols)
    m /= m.norm(dim=1, keepdim=True)
    return m.round_(decimals=2)

# Draw order matters for seed reproducibility: key is sampled first,
# then value, exactly as in the captured output below.
key = _rand_unit_rows(4, 3)
value = _rand_unit_rows(4, 3)
print(f'key:\n{key}')
print(f'value:\n{value}')
key:
tensor([[ 0.4700, 0.6500, 0.6000],
[ 0.6400, 0.5000, -0.5900],
[-0.0300, -0.4800, -0.8800],
[ 0.4300, -0.8300, 0.3500]])
value:
tensor([[-0.0700, -0.8800, 0.4700],
[ 0.3700, -0.9300, -0.0700],
[-0.2500, -0.7500, 0.6100],
[ 0.9400, 0.2000, 0.2800]])
In [ ]:
def attention(query, key, value):
"""
Note that we remove scaling for simplicity.
"""
return F.scaled_dot_product_attention(query, key, value, scale=1)
def check_query(query, target, key=key, value=value):
    """
    Print the maximum absolute element-wise difference between
    attention(query, key, value) and the required target matrix,
    so you can check whether your query is close enough.
    """
    result = attention(query, key, value)
    max_diff = (target - result).abs().max()
    print("maximum absolute element-wise difference:", max_diff)
1.2. Selection via Attention¶
In [ ]:
# Define a query vector to "select" the first value vector.
# The keys are (approximately) unit-norm, so q·key[0] is the strict
# maximum of the four logits when q points along key[0]; scaling by 100
# drives the softmax weight on index 0 to ~1, so attention ≈ value[0].
query121 = 100 * key[0]
# compare output of attention with desired output
print(query121)
check_query(query121, value[0])
In [ ]:
# Define a query matrix which results in an identity mapping - select all the value vectors.
# Row i of 100*key has its largest dot product with key[i] (keys are
# ~unit-norm and mutually non-parallel), so each softmax row is ~one-hot
# on its own index and attention reproduces `value` row by row.
query122 = 100 * key
# compare output of attention with desired output
print(query122)
check_query(query122, value)
1.3. Averaging via Attention¶
In [ ]:
# define a query vector which averages all the value vectors
# A zero query gives identical (zero) logits against every key, so the
# softmax is uniform (1/4 each) and attention returns the mean of the
# value vectors.
query131 = torch.zeros(3)
# compare output of attention with desired output
print(query131)
target = torch.reshape(value.mean(0, keepdims=True), (3,)) # reshape to a vector
check_query(query131, target)
In [ ]:
# define a query vector which averages the first two value vectors
# key[0] + key[1] has equal dot products with key[0] and key[1] (both
# ≈ 1 + key[0]·key[1], since keys are ~unit-norm) and a much smaller
# dot product with key[2] and key[3]. Scaling by 100 makes the softmax
# split ~50/50 between indices 0 and 1.
query132 = 100 * (key[0] + key[1])
# compare output of attention with desired output
print(query132)
target = torch.reshape(value[(0, 1),].mean(0, keepdims=True), (3,)) # reshape to a vector
check_query(query132, target)
1.4. Interactions within Attention¶
In [ ]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) averages the first three value vectors.
m_key = key.clone()
# With q = 100*(key[0] + key[1]) we have q·key[0] = q·key[1]
# ≈ 100*(1 + key[0]·key[1]). Setting k[2] = (key[0] + key[1]) / 2 gives
# q·k[2] = 50*(key[0]+key[1])·(key[0]+key[1]) ≈ 100*(1 + key[0]·key[1]),
# i.e. the same logit, so the softmax splits ~1/3 each over indices 0-2
# (index 3's logit stays far lower).
m_key[2] = (key[0] + key[1]) / 2
# compare output of attention with desired output
check_query(query132, value[(0, 1, 2),].mean(0, keepdims=True), key=m_key)
In [ ]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) returns the third value vector v[2].
m_key = key.clone()
# A unit vector maximizes q·k exactly when it points along q. After the
# normalization below, k[2] ∝ key[0] + key[1] ∝ q, so
# q·k[2] = ||q|| ≈ 159.5, well above q·key[0] = q·key[1] ≈ 127.2; the
# ~32-logit gap drives the softmax weight on index 2 to ~1, so
# attention returns ~value[2].
m_key[2] = key[0] + key[1]
m_key[2] /= m_key[2].norm()
# compare output of attention with desired output
check_query(query132, value[2], key=m_key)