Linear function approximation, as we've seen, provides a way to generalize value estimates across states using features. However, the real world (and many interesting simulated environments) often exhibits complex, non-linear relationships between a situation (state) and its long-term value. A simple linear combination of features might not be expressive enough to capture these intricate patterns. For instance, the value of a state might depend on subtle interactions between features that linear models struggle to represent.
This is where Neural Networks (NNs) come into play. NNs are powerful function approximators, renowned for their ability to learn complex, non-linear mappings directly from data. Their success in fields like computer vision and natural language processing stems from this capability. In Reinforcement Learning, we can leverage NNs to approximate value functions, potentially leading to much better performance in complex environments.
Instead of using a linear function $\hat{v}(s, \mathbf{w}) = \mathbf{w}^\top \mathbf{x}(s)$, we can use a neural network. The network takes the state representation as input and outputs the estimated value. Let $\mathbf{w}$ now denote the entire set of weights and biases within the neural network.
*Figure: A neural network takes a state representation as input and outputs estimated Q-values for multiple actions. The weights $\mathbf{w}$ parameterize the connections within the network.*
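To make this architecture concrete, here is a minimal sketch of such a network in PyTorch. The `QNetwork` name, the state dimension, the number of actions, and the hidden layer sizes are illustrative assumptions, not requirements of the method.

```python
import torch
import torch.nn as nn

class QNetwork(nn.Module):
    """Maps a state vector to one estimated Q-value per action."""
    def __init__(self, state_dim: int, num_actions: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),  # one output per action
        )

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        return self.net(state)

# Example usage with made-up dimensions: a 4-dimensional state and 2 actions
q_net = QNetwork(state_dim=4, num_actions=2)
state = torch.rand(1, 4)     # batch containing a single state
q_values = q_net(state)      # shape (1, 2): one q_hat(s, a, w) per action
```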
The fundamental goal remains the same as with linear VFA: adjust the parameters $\mathbf{w}$ (now the network's weights) to minimize the difference between the predicted value and a target value. We typically use variants of Temporal Difference (TD) learning.
For example, in a Q-learning context using an NN, the target value for an experience tuple $(S_t, A_t, R_{t+1}, S_{t+1})$ is often:

$$Y_t = R_{t+1} + \gamma \max_{a'} \hat{q}(S_{t+1}, a', \mathbf{w})$$

The network predicts $\hat{q}(S_t, A_t, \mathbf{w})$. The objective is to minimize the squared error between the target and the prediction; their difference is the TD error: $\delta_t = Y_t - \hat{q}(S_t, A_t, \mathbf{w})$.
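As a quick numeric illustration with hypothetical values, suppose a transition yields $R_{t+1} = 1$, the discount factor is $\gamma = 0.99$, the network's largest Q-value in $S_{t+1}$ is $2.5$, and its prediction for $(S_t, A_t)$ is $2.0$. Then:

$$Y_t = 1 + 0.99 \times 2.5 = 3.475, \qquad \delta_t = 3.475 - 2.0 = 1.475$$

A positive $\delta_t$ means the prediction was too low, so the update should increase $\hat{q}(S_t, A_t, \mathbf{w})$.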
We update the weights $\mathbf{w}$ using stochastic gradient descent (SGD) or its variants (like Adam). The update aims to move the prediction closer to the target:

$$\mathbf{w} \leftarrow \mathbf{w} + \alpha \, \delta_t \, \nabla_{\mathbf{w}} \hat{q}(S_t, A_t, \mathbf{w})$$

Here, $\nabla_{\mathbf{w}} \hat{q}(S_t, A_t, \mathbf{w})$ is the gradient of the network's output (for the specific action $A_t$) with respect to its weights $\mathbf{w}$. This gradient is calculated efficiently using the backpropagation algorithm, a standard technique in deep learning. Thankfully, modern deep learning libraries like TensorFlow or PyTorch handle the automatic differentiation and backpropagation for us. We only need to define the network architecture and the loss function (typically a mean squared error based on the TD error).
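The sketch below shows one such update step in PyTorch for a single transition, reusing the `QNetwork` defined earlier. The tensor values, learning rate, and discount factor are placeholder assumptions. Wrapping the target computation in `torch.no_grad()` prevents differentiation through the target, which gives exactly the semi-gradient behaviour described next.

```python
import torch
import torch.nn.functional as F

# Assumes q_net is the QNetwork instance from the earlier sketch;
# the transition values below are illustrative placeholders.
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)
gamma = 0.99

state      = torch.rand(1, 4)       # S_t
action     = torch.tensor([0])      # A_t
reward     = torch.tensor([1.0])    # R_{t+1}
next_state = torch.rand(1, 4)       # S_{t+1}

# Target Y_t: no gradient flows through it (semi-gradient update)
with torch.no_grad():
    target = reward + gamma * q_net(next_state).max(dim=1).values

# Prediction q_hat(S_t, A_t, w) for the action actually taken
prediction = q_net(state).gather(1, action.unsqueeze(1)).squeeze(1)

# MSE loss on the TD error; backpropagation computes the gradient,
# and the optimizer applies the SGD-style update to w
loss = F.mse_loss(prediction, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```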
Note that this is still a semi-gradient method because the target $Y_t$ itself depends on the current weights $\mathbf{w}$ (unless using a target network, discussed later), and we do not differentiate through the target calculation when computing the gradient.
While powerful, directly combining NNs with TD learning introduces potential instability during training. Two main issues arise:

- **Correlated samples:** consecutive transitions generated by the agent are highly correlated, violating the independence assumption that SGD-style training relies on.
- **Non-stationary targets:** the target $Y_t$ is computed with the same weights $\mathbf{w}$ being updated, so every update shifts the target the network is trying to match.
Furthermore, NNs introduce more hyperparameters (network architecture, layer sizes, learning rates, activation functions) that require careful selection and tuning.
Using neural networks for value function approximation marks the transition towards Deep Reinforcement Learning (DRL). The next chapter, "Introduction to Deep Q-Networks (DQN)", will directly address the stability challenges mentioned above by introducing techniques like Experience Replay and Fixed Q-Targets, which were instrumental in the success of early DRL algorithms. These techniques allow us to effectively train deep neural networks for complex RL tasks.