As outlined in the chapter introduction, pre-trained language models, despite their impressive capabilities, often lack the specific conditioning needed to reliably follow user instructions or adhere to desired behavioral guidelines. They are trained to predict the next token in a sequence based on massive amounts of text, but this objective doesn't directly translate into helpfulness, honesty, or harmlessness as defined by human expectations. Supervised Fine-Tuning (SFT) is a technique designed to bridge this gap by explicitly teaching the model how to respond to prompts in a preferred manner.

SFT adapts a pre-trained LLM using a dataset composed of curated input prompts and their corresponding desired outputs. Think of it as providing the model with direct examples of how it should behave. Instead of learning from the implicit patterns in web-scale text, the model learns from explicit demonstrations of good responses. The process involves further training the pre-trained model on these supervised examples, typically using a standard sequence-to-sequence loss function like cross-entropy.

## The SFT Mechanism

At its core, SFT refines the model's parameters by minimizing the difference between the model's generated output and the target output provided in the fine-tuning dataset. The process generally follows these steps:

1. **Start with a Pre-trained LLM:** Select a base LLM that has already undergone extensive pre-training. This model provides the foundational knowledge and language understanding capabilities.
2. **Prepare an Instruction Dataset:** Curate or generate a dataset containing pairs of `(prompt, desired_response)`. The quality and diversity of this dataset significantly influence the outcome of SFT. Examples might include questions paired with helpful answers, instructions paired with correctly executed tasks, or dialogue turns paired with appropriate continuations.
3. **Format the Data:** Each data pair is typically formatted into a single sequence, often using special tokens to demarcate the prompt and response sections. For example: `<|prompt|> What is the capital of Malaysia? <|response|> The capital of Malaysia is Kuala Lumpur. <|endoftext|>`.
4. **Fine-tune the Model:** Train the pre-trained model on this formatted dataset. The standard training objective is to predict the next token, but critically, the loss is usually calculated only on the tokens belonging to the `desired_response` part of the sequence. The prompt tokens serve as context but do not contribute directly to the loss calculation or gradient updates.

This targeted loss calculation is important. We want the model to learn how to generate the response given the prompt, not simply how to predict the prompt tokens themselves (which it already learned during pre-training).

Consider the objective function. During pre-training, the model maximizes the likelihood of the entire text corpus, $P(\text{text})$. In SFT, the model learns a conditional probability: given a specific prompt, it maximizes the likelihood of the desired response, $P(\text{response} \mid \text{prompt})$.
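One common way to write this per token (using $x$ for the prompt and $y_1, \dots, y_T$ for the response tokens, notation introduced here for illustration) is as the negative log-likelihood of the response alone:

$$
\mathcal{L}_{\text{SFT}}(\theta) = -\sum_{t=1}^{T} \log P_\theta\left(y_t \mid x, y_{<t}\right)
$$

The prompt tokens appear only in the conditioning context; they contribute no terms to the sum, which is exactly what the loss masking described later in this section implements.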
This shift focuses the model on generating appropriate outputs conditioned on instructional inputs.

## Visualizing the SFT Flow

We can visualize the basic flow of information during a single SFT training step:

```dot
digraph SFT_Flow {
    rankdir=TB;
    node [shape=box, style=rounded, fontname="Arial", color="#adb5bd", fontcolor="#495057", fontsize=12];
    edge [color="#495057", fontsize=12];
    splines=true;

    Prompt [label="Input Prompt", shape=cds, color="#74c0fc", fontcolor="#1c7ed6"];
    Dataset [label="Instruction\nDataset", shape=cylinder, color="#ffec99", fontcolor="#f59f00"];
    PreTrainedLLM [label="Pre-trained\nLLM", color="#b2f2bb", fontcolor="#37b24d"];
    DesiredResponse [label="Desired Response", shape=cds, color="#74c0fc", fontcolor="#1c7ed6"];
    GeneratedResponse [label="Generated Response", shape=cds, color="#ffc9c9", fontcolor="#f03e3e"];
    LossCalculation [label="Calculate Loss\n(on Response only)", shape=invhouse, color="#eebefa", fontcolor="#ae3ec9"];
    GradientUpdate [label="Update\nModel Weights", shape=octagon, color="#fd7e14", fontcolor="#495057"];

    Dataset -> Prompt;
    Dataset -> DesiredResponse;
    Prompt -> PreTrainedLLM;
    PreTrainedLLM -> GeneratedResponse;
    {GeneratedResponse DesiredResponse} -> LossCalculation [arrowhead=none];
    LossCalculation -> GradientUpdate;
    GradientUpdate -> PreTrainedLLM [style=dashed, label="Refine"];
}
```

A simplified representation of the SFT process, showing how prompts and desired responses from the dataset are used to calculate loss and update the pre-trained LLM's weights.

## Masking the Loss

To implement the targeted loss calculation in practice using frameworks like PyTorch, we typically create a loss mask. This mask ensures that only the tokens corresponding to the desired response contribute to the loss computation.

Here's a PyTorch snippet illustrating this:

```python
import torch
import torch.nn.functional as F

# Assume:
# - logits: Model output logits [batch_size, sequence_length, vocab_size]
# - labels: Target token IDs [batch_size, sequence_length]
# - prompt_lengths: Length of the prompt part for each item in the batch [batch_size]
# - IGNORE_INDEX: A special index ignored by the loss function (e.g., -100)
IGNORE_INDEX = -100


def calculate_sft_loss(logits, labels, prompt_lengths):
    """Calculates cross-entropy loss only on response tokens."""
    batch_size, sequence_length, vocab_size = logits.shape

    # Shift logits and labels for next-token prediction:
    # the logits at position i predict the token at position i + 1.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Create a loss mask, initialized to True (calculate loss)
    loss_mask = torch.ones_like(shift_labels, dtype=torch.bool)

    # Set mask to False (ignore loss) for prompt tokens
    for i in range(batch_size):
        # The prompt includes the initial token, so in the shifted
        # sequence the prompt targets end at prompt_length - 1.
        prompt_end_index = prompt_lengths[i] - 1
        if prompt_end_index > 0:  # Ensure there's a prompt part to mask
            loss_mask[i, :prompt_end_index] = False

    # Where the mask is False, replace the label with IGNORE_INDEX
    masked_labels = shift_labels.masked_fill(~loss_mask, IGNORE_INDEX)

    # Flatten the sequence dimension for loss calculation
    shift_logits = shift_logits.view(-1, vocab_size)
    masked_labels = masked_labels.view(-1)

    # Calculate cross-entropy loss, ignoring IGNORE_INDEX
    loss = F.cross_entropy(shift_logits, masked_labels, ignore_index=IGNORE_INDEX)
    return loss


# --- Example Usage ---
# batch_size = 2
# sequence_length = 10
# vocab_size = 1000
# prompt_lengths = torch.tensor([3, 5])  # Prompt lengths for each item
# dummy_logits = torch.randn(batch_size, sequence_length, vocab_size,
#                            requires_grad=True)
# dummy_labels = torch.randint(0, vocab_size, (batch_size, sequence_length))
# sft_loss = calculate_sft_loss(dummy_logits, dummy_labels, prompt_lengths)
# print(f"Calculated SFT Loss: {sft_loss.item()}")
# sft_loss.backward()  # Compute gradients
```

This code snippet outlines how to mask the loss computation, ensuring only the response tokens influence the model updates during SFT.
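To connect this back to the data-formatting step, the sketch below shows one way to build `labels` and `prompt_lengths` for a single formatted example and run one gradient update. It is a minimal illustration, not a production training loop: it assumes the Hugging Face `transformers` library with `"gpt2"` as a stand-in base model, that `calculate_sft_loss` from the snippet above is in scope, that the `<|prompt|>` / `<|response|>` markers are used as plain text rather than registered special tokens, and an arbitrary learning rate.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Stand-in base model for illustration; any causal LM checkpoint would do.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# One formatted (prompt, desired_response) pair, using the markers from above.
# In practice these markers would be added to the tokenizer as special tokens.
prompt_text = "<|prompt|> What is the capital of Malaysia? <|response|>"
response_text = " The capital of Malaysia is Kuala Lumpur."

# Tokenize the prompt alone to find how many leading tokens to mask out,
# then tokenize the full prompt + response sequence as the training input.
prompt_ids = tokenizer(prompt_text, return_tensors="pt").input_ids
input_ids = tokenizer(prompt_text + response_text, return_tensors="pt").input_ids
prompt_lengths = torch.tensor([prompt_ids.shape[1]])

# One SFT step: forward pass, response-only loss, backward pass, weight update.
logits = model(input_ids).logits
loss = calculate_sft_loss(logits, input_ids, prompt_lengths)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

Tokenizing the prompt separately like this assumes the tokenizer splits the prompt identically with and without the appended response; registering the markers as special tokens is the usual way to make that boundary exact.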
## Purpose and Goals of SFT

Supervised Fine-Tuning serves several important alignment objectives:

- **Instruction Following:** Teaching the model to understand and execute commands or answer questions posed in natural language prompts.
- **Format Adherence:** Training the model to generate outputs in specific formats (e.g., JSON, Markdown, code blocks, specific conversational styles).
- **Improved Controllability:** Making the model's behavior more predictable and aligned with user intentions for specific tasks.
- **Initial Safety & Helpfulness:** Introducing basic safety constraints and helpful conversational patterns by providing examples of desired interactions (e.g., refusing harmful requests, providing polite responses).

While SFT is effective for teaching the model what kind of response is desired based on examples, it doesn't inherently capture human preferences perfectly. It teaches the model to imitate the style and content of the provided responses. For more complex alignment goals, such as judging the relative quality between multiple plausible responses or optimizing for qualities like "helpfulness," SFT is often followed by techniques like Reinforcement Learning from Human Feedback (RLHF), which we will discuss in the next chapter. SFT provides a foundation, equipping the model with the basic ability to follow instructions before it's further refined using preference-based methods.