This practical exercise focuses on implementing the core attention calculation, the Scaled Dot-Product Attention mechanism. This mechanism is fundamental to how Transformers process information, allowing the model to weigh the importance of different elements in the input sequence relative to each other.

We will implement the following formula:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V $$

This function takes Query ($Q$), Key ($K$), and Value ($V$) matrices as input and scales the scores by the square root of the key dimension ($d_k$), which our implementation infers from the input tensors. Optionally, it can also handle a mask to prevent attention to certain positions (such as padding tokens or future tokens in a decoder).

### Setting Up

We'll use PyTorch for this implementation, but the concepts translate directly to other deep learning frameworks like TensorFlow. Ensure you have PyTorch installed. We'll also need the standard `math` module for the square root calculation.

```python
import torch
import torch.nn.functional as F
import math
```

### Implementing the Attention Function

Let's define a Python function `scaled_dot_product_attention` that performs the calculation. It accepts tensors for $Q$, $K$, $V$, and an optional mask.

```python
def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Calculates the Scaled Dot-Product Attention.

    Args:
        query (torch.Tensor): Query tensor; shape (batch_size, ..., seq_len_q, d_k)
        key (torch.Tensor): Key tensor; shape (batch_size, ..., seq_len_k, d_k)
        value (torch.Tensor): Value tensor; shape (batch_size, ..., seq_len_v, d_v)
            Note: seq_len_k and seq_len_v must be the same.
        mask (torch.Tensor, optional): Mask tensor; shape must be broadcastable to
            (batch_size, ..., seq_len_q, seq_len_k). Defaults to None.

    Returns:
        torch.Tensor: Output tensor; shape (batch_size, ..., seq_len_q, d_v)
        torch.Tensor: Attention weights; shape (batch_size, ..., seq_len_q, seq_len_k)
    """
    # Get the dimension of the key/query vectors
    d_k = query.size(-1)

    # 1. Calculate dot products: Q * K^T
    # Result shape: (batch_size, ..., seq_len_q, seq_len_k)
    attention_scores = torch.matmul(query, key.transpose(-2, -1))

    # 2. Scale the scores
    attention_scores = attention_scores / math.sqrt(d_k)

    # 3. Apply the mask (if provided)
    # The mask indicates positions to ignore (e.g., padding).
    # We replace the scores at these positions with a large negative
    # number (-1e9) before the softmax.
    if mask is not None:
        attention_scores = attention_scores.masked_fill(mask == 0, -1e9)

    # 4. Apply softmax to get attention weights
    # Softmax is applied on the last dimension (seq_len_k)
    # Result shape: (batch_size, ..., seq_len_q, seq_len_k)
    attention_weights = F.softmax(attention_scores, dim=-1)

    # 5. Multiply weights by the Value vectors V
    # Result shape: (batch_size, ..., seq_len_q, d_v)
    output = torch.matmul(attention_weights, value)

    return output, attention_weights
```

Let's break down the steps within the function:

1. **Matrix Multiplication ($QK^T$)**: We compute the dot product between each query vector and all key vectors. `torch.matmul` handles the batching and matrix multiplication, and `key.transpose(-2, -1)` swaps the last two dimensions of the key tensor, effectively transposing the key matrices for the multiplication. This step produces the raw alignment scores between queries and keys.
2. **Scaling**: The scores are divided by $\sqrt{d_k}$. As discussed previously, this scaling prevents the dot products from becoming too large, which could push the softmax function into regions with very small gradients and hinder learning.
3. **Masking (Optional)**: If a `mask` is provided, we apply it here. The mask typically has `0`s where attention should be prevented (e.g., padding tokens or future positions in a sequence) and `1`s elsewhere. We use `masked_fill` to replace the scores at masked positions (`mask == 0`) with a very large negative number (`-1e9`), so these positions receive near-zero probability after the softmax. A short sketch of building such a mask appears just below.
4. **Softmax**: `F.softmax` is applied along the last dimension (the key sequence length, `seq_len_k`). This converts the scaled scores into probability distributions, the attention weights: for each query position, the weights sum to 1 across all key positions.
5. **Matrix Multiplication (Weights * V)**: Finally, the attention weights are multiplied by the Value ($V$) tensor. This computes a weighted sum of the value vectors, where the weights are determined by the attention distribution. The result is the output of the attention mechanism, representing the input sequence with contextually relevant information emphasized for each query position.

The function returns both the final output tensor and the attention weights, which can be useful for analysis and visualization.
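To illustrate the masking step, here is a small, illustrative sketch of one way to build a causal (look-ahead) mask with `torch.tril` and pass it to the function. The tensor sizes are arbitrary and simply mirror the example that follows; the sketch reuses the imports and function defined above.

```python
# Illustrative only: a causal (look-ahead) mask for a sequence of length 4.
# Entries equal to 0 are blocked, matching the `mask == 0` check in the function.
seq_len = 4
causal_mask = torch.tril(torch.ones(1, seq_len, seq_len))  # 1s on and below the diagonal

q = torch.randn(1, seq_len, 8)
k = torch.randn(1, seq_len, 8)
v = torch.randn(1, seq_len, 8)

_, weights = scaled_dot_product_attention(q, k, v, mask=causal_mask)

# Each query position now attends only to itself and earlier positions:
# the upper triangle of the weight matrix is (near) zero.
print(weights[0])
```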
### Example Usage

Let's create some sample tensors and see the function in action. We'll assume a batch size of 1, a sequence length of 4, and embedding dimensions ($d_k$, $d_v$) of 8. In self-attention, Q, K, and V often derive from the same input sequence, so `seq_len_q`, `seq_len_k`, and `seq_len_v` are typically the same.

```python
# Example parameters
batch_size = 1
seq_len = 4
d_k = 8  # Dimension of Key/Query
d_v = 8  # Dimension of Value

# Create random Query, Key, and Value tensors
# In a real model, these would come from input embeddings projected by linear layers
query = torch.randn(batch_size, seq_len, d_k)
key = torch.randn(batch_size, seq_len, d_k)
value = torch.randn(batch_size, seq_len, d_v)

# Calculate attention
output, attention_weights = scaled_dot_product_attention(query, key, value)

print("Input Query Shape:", query.shape)
print("Input Key Shape:", key.shape)
print("Input Value Shape:", value.shape)
print("\nOutput Shape:", output.shape)
print("Attention Weights Shape:", attention_weights.shape)
print("\nSample Attention Weights (first batch element):\n", attention_weights[0])
```

You should see output similar to this (values will differ due to randomness):

```
Input Query Shape: torch.Size([1, 4, 8])
Input Key Shape: torch.Size([1, 4, 8])
Input Value Shape: torch.Size([1, 4, 8])

Output Shape: torch.Size([1, 4, 8])
Attention Weights Shape: torch.Size([1, 4, 4])

Sample Attention Weights (first batch element):
 tensor([[0.1813, 0.3056, 0.3317, 0.1814],
        [0.2477, 0.2080, 0.3401, 0.2042],
        [0.2880, 0.1807, 0.2523, 0.2790],
        [0.3139, 0.1774, 0.2614, 0.2473]])
```

Notice that the output shape `(1, 4, 8)` matches the query sequence length and the value dimension ($d_v$). The attention weights shape `(1, 4, 4)` holds the attention scores from each of the 4 query positions to each of the 4 key positions, and each row of the sample attention weights sums to approximately 1.

### Visualizing Attention Weights

Visualizing the attention weights can provide insight into which parts of the input sequence the model focuses on when processing a specific element. Let's plot the `attention_weights` we just calculated as a simple heatmap.
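The original lesson renders this as an interactive chart. As a stand-in, here is a minimal matplotlib sketch (matplotlib is assumed to be installed; it is not needed elsewhere in this exercise) that produces an equivalent heatmap from the `attention_weights` tensor computed above:

```python
import matplotlib.pyplot as plt

# First batch element as a NumPy array of shape (seq_len_q, seq_len_k)
weights = attention_weights[0].detach().cpu().numpy()

fig, ax = plt.subplots(figsize=(5, 4.5))
im = ax.imshow(weights, cmap="Blues")  # darker blue = higher attention weight
ax.set_xticks(range(weights.shape[1]))
ax.set_xticklabels([f"Key Pos {j + 1}" for j in range(weights.shape[1])])
ax.set_yticks(range(weights.shape[0]))
ax.set_yticklabels([f"Query Pos {i + 1}" for i in range(weights.shape[0])])
ax.set_title("Example Attention Weights")
fig.colorbar(im, ax=ax)
plt.show()
```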
*Attention weights visualized as a heatmap. Each cell (i, j) shows the attention weight from query position i to key position j. Darker blue indicates higher attention.*

This visualization shows how much each query position (row) attends to each key position (column). In a real application, such as translating "hello world" into French, you might see that when generating the French word for "world", the attention mechanism focuses heavily on the input word "world".

### Summary

In this section, you implemented Scaled Dot-Product Attention, the core computational block for attention in Transformers. You saw how to compute scores between queries and keys, scale them, optionally apply masking, normalize with softmax to obtain the attention weights, and finally compute a weighted sum of the values. This function is the building block used within the Multi-Head Attention mechanism discussed earlier, allowing Transformers to process sequence information effectively.
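To make that last connection concrete, here is a minimal, illustrative sketch (not the reference implementation from the earlier Multi-Head Attention discussion) of how `scaled_dot_product_attention` can be reused inside a multi-head wrapper; the class name, layer names, and hyperparameters are arbitrary choices for this example.

```python
import torch
import torch.nn as nn

class SimpleMultiHeadAttention(nn.Module):
    """Illustrative only: projects inputs, splits them into heads, and reuses
    scaled_dot_product_attention (defined above) for all heads in parallel."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)

        def split_heads(x):
            # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)
            return x.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.w_q(query))
        k = split_heads(self.w_k(key))
        v = split_heads(self.w_v(value))

        # scaled_dot_product_attention broadcasts over the extra head dimension
        out, _ = scaled_dot_product_attention(q, k, v, mask=mask)

        # (batch, num_heads, seq_len, d_k) -> (batch, seq_len, d_model)
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.w_o(out)

mha = SimpleMultiHeadAttention(d_model=8, num_heads=2)
x = torch.randn(1, 4, 8)
print(mha(x, x, x).shape)  # torch.Size([1, 4, 8])
```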