1. 理解注意力 (Understanding Attention)

  • 请在“# TODO:”注释之后直接填写您的答案。
  • 提交作业时，请将此笔记本下载为 Python 文件，并命名为“A2S1.py”。

Imports and Setup

In [ ]:
import torch
import torch.nn as nn
import torch.nn.functional as F
In [ ]:
# Fix the RNG so the key/value matrices (and therefore the answers below) are reproducible.
torch.manual_seed(447)


def _random_unit_rows(n_rows, n_cols):
    """Sample an (n_rows, n_cols) standard-normal matrix, rescale every row to
    unit L2 norm, then round all entries to 2 decimals in place.

    Because of the rounding, the returned rows are only approximately unit length.
    """
    sample = torch.randn(n_rows, n_cols)
    sample /= sample.norm(dim=1, keepdim=True)
    sample.round_(decimals=2)
    return sample


# key is drawn before value; swapping the two lines would change both matrices.
key = _random_unit_rows(4, 3)
value = _random_unit_rows(4, 3)

print(f'key:\n{key}')
print(f'value:\n{value}')
key:
tensor([[ 0.4700,  0.6500,  0.6000],
        [ 0.6400,  0.5000, -0.5900],
        [-0.0300, -0.4800, -0.8800],
        [ 0.4300, -0.8300,  0.3500]])
value:
tensor([[-0.0700, -0.8800,  0.4700],
        [ 0.3700, -0.9300, -0.0700],
        [-0.2500, -0.7500,  0.6100],
        [ 0.9400,  0.2000,  0.2800]])
In [ ]:
def attention(query, key, value):
    """
    Note that we remove scaling for simplicity.
    """
    return F.scaled_dot_product_attention(query, key, value, scale=1)


def check_query(query, target, key=key, value=value):
    """
    Report how close attention(query, key, value) comes to `target`.

    Prints the largest absolute element-wise gap between the two (comparison
    uses broadcasting). `key` and `value` default to the module-level tensor
    objects bound when this function is defined; pass e.g. a modified key
    matrix to try alternatives.
    """
    produced = attention(query, key, value)
    gap = torch.abs(target - produced).max()
    print("maximum absolute element-wise difference:", gap)

1.2. Selection via Attention

In [ ]:
# Define a query vector to "select" the first value vector.
# The rows of `key` are (approximately) unit length, so for the given keys
# key[0] has a larger dot product with itself (~1.00) than with any other key
# row (<= ~0.27). Scaling it by 100 widens that gap so much that softmax
# puts essentially all weight on row 0.

# TODO:
query121 = 100 * key[0]

# compare output of attention with desired output
print(query121)
check_query(query121, value[0])
In [ ]:
# Define a query matrix which results in an identity mapping – select all the value vectors.
# Row i of the query is 100 * key[i]: for the given keys each row has a much
# larger dot product with itself than with any other key row, so after the
# large scaling, softmax row i is (numerically) one-hot on i and the output
# reproduces `value` row by row.

# TODO:
query122 = 100 * key

# compare output of attention with desired output
print(query122)
check_query(query122, value)

1.3. Averaging via Attention

In [ ]:
# Define a query vector which averages all the value vectors.
# A zero query scores every key identically (q·k = 0), so softmax assigns the
# same weight to each row and the output is exactly the mean of `value`.
# zeros_like keeps the query's shape/dtype/device consistent with the keys.

# TODO:
query131 = torch.zeros_like(key[0])

# compare output of attention with desired output
print(query131)
target = torch.reshape(value.mean(0, keepdims=True), (3,))  # reshape to a vector
check_query(query131, target)
In [ ]:
# Define a query vector which averages the first two value vectors.
# We need q·k0 == q·k1 exactly (equal weights on rows 0 and 1) while both are
# much larger than q·k2 and q·k3. Start from k0 + k1 and remove its component
# along k0 - k1: the result is orthogonal to k0 - k1, hence q·k0 == q·k1.
# For the given keys this leaves q·k0 ≈ 1.28 versus q·k2 ≈ -0.60 and
# q·k3 ≈ -0.47; scaling by 100 makes the weights on rows 2 and 3 negligible.

# TODO:
_k_sum = key[0] + key[1]
_k_diff = key[0] - key[1]
query132 = 100 * (_k_sum - (_k_sum @ _k_diff) / (_k_diff @ _k_diff) * _k_diff)

# compare output of attention with desired output
print(query132)
target = torch.reshape(value[(0, 1),].mean(0, keepdims=True), (3,))  # reshape to a vector
check_query(query132, target)

1.4. Interactions within Attention

In [ ]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) averages the first three value vectors.
# Any query that solves 1.3.2 scores rows 0 and 1 equally (q·k0 == q·k1) and
# far above rows 2 and 3. Copying key[0] into row 2 gives row 2 exactly that
# same score, so softmax splits the weight equally over rows 0, 1 and 2, while
# row 3 (key and query unchanged) keeps its negligible weight.
m_key = key.clone()

# TODO:
m_key[2] = key[0]

# compare output of attention with desired output
check_query(query132, value[(0, 1, 2),].mean(0, keepdims=True), key=m_key)
In [ ]:
# Define a replacement for only the third key vector k[2] such that the result of attention
# with the same unchanged query q from (1.3.2) returns the third value vector v[2].
# k[2] is normalised to unit length below, and among unit vectors the one with
# the largest dot product with q is q / |q| (Cauchy–Schwarz). So we point k[2]
# along q itself: q·k2 becomes |q|, the maximum achievable. This selects v[2]
# as long as |q| exceeds q·k0 = q·k1 by a wide margin (i.e. q is scaled up
# enough); if it does not, no unit-length k[2] can do better.
m_key = key.clone()

# TODO:
m_key[2] = query132
m_key[2] /= m_key[2].norm()

# compare output of attention with desired output
check_query(query132, value[2], key=m_key)