Philip Ossenkopp
03. December 2018

Reinforcement learning – Part 2: Getting started with Deep Q-Networks

In the last part of this reinforcement learning series, we had an agent learn Gym's taxi environment with the Q-learning algorithm. After training our agent for long enough, it achieved decent scores. But this approach quickly reaches its limits. Without spoiling too much, the observation space of the environment in the next post has a size of 10^174. That is far greater than the number of atoms in the observable universe (roughly 10^80)! If the universe can't handle numbers that large, our Q-table certainly won't. Instead, we will use another big topic in machine learning: deep neural networks. They are exceptionally good at extracting suitable features from complex data. So let's give it a try and build one ourselves.

How DNNs fit in our agent

Today we will construct an agent that can play both the taxi game from last time and the CartPole-v1 environment. These simple challenges give us a smooth introduction to applying deep neural networks to RL. You can see the interaction between agent and environment in the figure below.

You can think of the DNN as a black box. It takes a game state as input and returns an approximate Q-value for every possible action. We then choose the action with the greatest Q-value, just as in Q-learning. Unlike a Q-table, however, the network does not store a separate entry for every state; it learns patterns that generalize across similar states. Storing every state explicitly would be impossible in environments with huge state spaces. To make the network's predictions match the environment, we feed it pairs of inputs and target outputs, and it iteratively updates its parameters to approximate the outputs from the inputs.
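To make the black-box view concrete, here is a minimal sketch in Python with NumPy. It is not the article's Keras code: the network is replaced by a hypothetical fixed random linear map (`q_network`), and the one-hot encoding of Taxi's discrete state index is an assumption about how such a state could be fed to a DNN. Taxi has 500 states and 6 actions.

```python
import numpy as np

N_STATES = 500   # Taxi: 500 discrete states
N_ACTIONS = 6    # Taxi: 6 actions

def one_hot(state, n_states=N_STATES):
    """Encode a discrete state index as a (1, n_states) input row (assumed encoding)."""
    x = np.zeros((1, n_states), dtype=np.float32)
    x[0, state] = 1.0
    return x

# Hypothetical stand-in for the DNN: a fixed random linear map from inputs to Q-values.
rng = np.random.default_rng(0)
W = rng.normal(size=(N_STATES, N_ACTIONS))

def q_network(x):
    """Return one Q-value estimate per action for each input row."""
    return x @ W  # shape (batch, N_ACTIONS)

def greedy_action(state):
    """Pick the action with the greatest predicted Q-value, as in Q-learning."""
    q_values = q_network(one_hot(state))[0]
    return int(np.argmax(q_values))

print(greedy_action(42))
```

A real DQN would learn `W` (and further layers) from data instead of keeping it fixed; the selection step via `argmax` stays the same.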

Implementing the agent

We implement the agent as a class, which gives us a clean interface to interact with. Since the major changes compared to last time only affect the agent's decision process, we can recycle some code from the Q-learning agent and put it into the agent's train() function. Later, we simply call agent.train() to start training.
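A minimal sketch of this class structure, assuming the classic Gym API used at the time, where reset() returns a state and step() returns (state, reward, done, info). `CountdownEnv` and `RandomAgent` are hypothetical names for illustration: the random act() stands in for the decision logic that the DQN replaces, while train() shows the kind of episode loop recycled from the Q-learning agent.

```python
import random


class CountdownEnv:
    """Hypothetical toy environment following the classic Gym API."""

    def __init__(self, length=5):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t                      # initial state

    def step(self, action):
        self.t += 1
        reward = 1.0 if action == 1 else 0.0
        done = self.t >= self.length       # episode ends after `length` steps
        return self.t, reward, done, {}    # (state, reward, done, info)


class RandomAgent:
    """Skeleton of the agent class; act() is where the DQN logic would go."""

    def __init__(self, env, n_actions=2, episodes=3):
        self.env = env
        self.n_actions = n_actions
        self.episodes = episodes

    def act(self, state):
        # Placeholder decision: uniformly random action.
        return random.randrange(self.n_actions)

    def train(self):
        scores = []
        for _ in range(self.episodes):
            state = self.env.reset()
            done, score = False, 0.0
            while not done:
                action = self.act(state)
                state, reward, done, _ = self.env.step(action)
                score += reward
            scores.append(score)
        return scores


agent = RandomAgent(CountdownEnv(length=5), episodes=3)
print(agent.train())  # three scores, each between 0 and 5
```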

We initialize the class by creating the environment and adopting all of its specific parameters. You can ignore memory and batch size for now; they are required for experience replay, which I explain later on. The gamma is our discount factor, as last time, and the learning rate with decay is needed for the optimizer in our DNN. Also new is win_threshold, which determines whether an environment counts as solved. For CartPole, the agent has to score an average of 195 over 100 consecutive trials (this is the criterion Gym uses for CartPole-v0; CartPole-v1 is registered with a threshold of 475). The taxi game counts as solved if this average exceeds 9.7.
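The win_threshold check can be sketched with a fixed-size window of recent scores. `SolvedTracker` is a hypothetical helper, and treating a mean equal to the threshold as solved is an assumption (the text says the taxi average must exceed 9.7, so a strict comparison may be intended there).

```python
from collections import deque


class SolvedTracker:
    """Track the last `window` episode scores and compare their mean to a threshold."""

    def __init__(self, win_threshold, window=100):
        self.win_threshold = win_threshold
        self.scores = deque(maxlen=window)   # keeps only the most recent scores

    def add(self, score):
        self.scores.append(score)

    def solved(self):
        # Only judge once the window is full, i.e. after `window` consecutive trials.
        full = len(self.scores) == self.scores.maxlen
        return full and sum(self.scores) / len(self.scores) >= self.win_threshold


tracker = SolvedTracker(win_threshold=195)
for _ in range(100):
    tracker.add(200)
print(tracker.solved())  # True
```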

Building the model is outsourced into a separate function. To create our neural net, we use the Keras Sequential API, as it's easy to understand and requires only a few lines of code. If you are new to Keras, I recommend reading its Getting Started guide.

Sequential() creates a linear stack of layers to which we add new layers one by one. As our environment is relatively simple, we use just three densely connected layers. The first layer receives the game state, has 24 neurons and uses the tanh activation. It is followed by a Dense layer with 48 neurons and the same activation function. The output layer uses linear activation and yields one Q-value for every action. Before returning the final model, we need to compile it, which requires an optimizer and a loss function. Adam with a standard learning rate is a great choice, as it suits a wide range of problems and its hyperparameters require little to no tuning. Our loss function is the mean squared error.
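For readers without Keras at hand, the following NumPy sketch mirrors the forward pass of the described architecture (24 tanh units, 48 tanh units, linear output) and the mean squared error loss. It is an illustration only: the initialization scheme is an assumption, and there is no Adam optimizer or training step here; in the article, Keras handles all of that.

```python
import numpy as np

def init_params(state_size, action_size, seed=0):
    """Weights and biases for state -> 24 -> 48 -> action_size (init scheme assumed)."""
    rng = np.random.default_rng(seed)
    sizes = [state_size, 24, 48, action_size]
    return [
        (rng.normal(0.0, 1.0 / np.sqrt(n_in), size=(n_in, n_out)), np.zeros(n_out))
        for n_in, n_out in zip(sizes[:-1], sizes[1:])
    ]

def forward(params, states):
    """Forward pass: two tanh hidden layers, then a linear output layer."""
    (W1, b1), (W2, b2), (W3, b3) = params
    h1 = np.tanh(states @ W1 + b1)   # 24 units
    h2 = np.tanh(h1 @ W2 + b2)       # 48 units
    return h2 @ W3 + b3              # one Q-value per action, no activation

def mse(predictions, targets):
    """Mean squared error, the loss the model is compiled with."""
    return float(np.mean((predictions - targets) ** 2))

params = init_params(state_size=4, action_size=2)  # CartPole: 4 state values, 2 actions
q = forward(params, np.zeros((3, 4)))
print(q.shape)  # (3, 2)
```

The linear output layer matters: Q-values can be any real number, so squashing them with tanh would make large returns unrepresentable.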

The act() function takes a state and either uses the model to pick the action with the highest predicted Q-value or simply takes a random action. Which of the two happens depends on the current exploration rate stored in epsilon, which once again implements the exploration-exploitation trade-off.
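A minimal sketch of the epsilon-greedy choice in act(), assuming a hypothetical `q_values_fn` that returns one Q-value per action for a given state (in the article's agent, the Keras model's prediction plays this role).

```python
import random
import numpy as np

def act(q_values_fn, state, epsilon, n_actions, rng=random):
    """Epsilon-greedy: explore with probability epsilon, otherwise exploit."""
    if rng.random() < epsilon:
        return rng.randrange(n_actions)          # exploration: random action
    return int(np.argmax(q_values_fn(state)))    # exploitation: best predicted action

def fake_q(state):
    """Hypothetical model output with a clear best action (index 1)."""
    return np.array([0.1, 0.9, 0.3])

print(act(fake_q, state=None, epsilon=0.0, n_actions=3))  # 1
```

With epsilon at 1.0 every action is random; as epsilon decays during training, the agent relies more and more on its own predictions.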

To improve learning, our DQN agent uses experience replay. It remembers a number of (s, a, r, s’, done) tuples and, after each episode, trains on a batch of them. This is similar to learning vocabulary while regularly revisiting words you already know. The DQN needs experience replay to reduce correlations in the sequence of observations, which might otherwise drive the network into a local minimum. In the vocabulary analogy, it is the difference between repeating words from a fixed list, where your brain memorizes their order, and studying with flash cards, which is much more effective because you can shuffle them. DeepMind's original DQN uses several more tweaks for better learning, but for our purpose this is enough.

The remember() function, as its name suggests, appends these tuples to the agent's memory.
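The memory and the remember() function can be sketched with a bounded deque, which silently drops the oldest tuples once full; random sampling then provides the shuffling from the flash-card analogy. `ReplayMemory` and the default capacity of 2000 are illustrative assumptions, not values from the article.

```python
import random
from collections import deque


class ReplayMemory:
    """Bounded store of (state, action, reward, next_state, done) tuples."""

    def __init__(self, capacity=2000):           # capacity is an illustrative value
        self.buffer = deque(maxlen=capacity)     # oldest tuples are dropped when full

    def remember(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Random batch without replacement; returns fewer tuples if memory is small."""
        return random.sample(list(self.buffer), min(batch_size, len(self.buffer)))

    def __len__(self):
        return len(self.buffer)


memory = ReplayMemory(capacity=3)
for i in range(5):
    memory.remember(i, 0, 1.0, i + 1, False)
print(len(memory))  # 3: only the three newest tuples are kept
```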

The replay() method is more complex. First, we sample a batch of the given size, provided the memory holds enough entries. Then we iterate through all “memories” in the batch and train our model on them. For each one, we start by letting the model predict the Q-values for the state. Then we update the Q-value of the action we took: if the game ended after that action, the new Q-value is just the reward, as there is no further future reward; if the game continues, it is the reward plus the discounted future reward, i.e. gamma times the highest Q-value the model predicts for the next state. All the updated pairs of states and Q-values are passed to the fit() function, which trains the model on them. As we repeat this update process, the approximated Q-values should move closer to the true Q-values: the loss decreases and the agent's score grows. Finally, we decay our exploration rate, as long as it is still above its minimum.
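The core of replay() is building the training targets. Here is a sketch with a hypothetical `predict` function standing in for the model; in a Keras agent, arrays like the returned ones would be passed to model.fit(). The multiplicative decay with a floor in `decay_epsilon` is an assumed schedule.

```python
import numpy as np

def build_replay_targets(batch, predict, gamma):
    """Turn replayed (s, a, r, s', done) tuples into (states, target Q-values)."""
    states, targets = [], []
    for state, action, reward, next_state, done in batch:
        target = np.array(predict(state), dtype=float)  # copy of current Q estimates
        if done:
            target[action] = reward                     # no future reward after the end
        else:
            target[action] = reward + gamma * np.max(predict(next_state))
        states.append(state)
        targets.append(target)
    return np.array(states), np.array(targets)

def decay_epsilon(epsilon, epsilon_min, epsilon_decay):
    """Assumed schedule: multiplicative decay, never below epsilon_min."""
    return max(epsilon_min, epsilon * epsilon_decay)

def predict(state):  # hypothetical model: fixed Q-values per state
    return np.array([1.0, 2.0]) if state == 0 else np.array([3.0, 5.0])

batch = [(0, 0, 1.0, 1, False), (0, 1, -1.0, 1, True)]
states, targets = build_replay_targets(batch, predict, gamma=0.9)
print(targets.tolist())  # [[5.5, 2.0], [1.0, -1.0]]
```

Only the entry for the action actually taken changes; the other entries keep the model's own prediction, so they contribute zero error to the mean squared error loss.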

And that’s everything we need for our DQN-agent. We can start training now and take a look at how efficient this agent is.

Can the agent adapt?

Last time, the table-based agent reached human-level performance within minutes without knowing anything about the environment. This time we used a deep neural network to compete in two different settings, and as you can see in the graphs below, we were reasonably successful. In both cases the agent started off exploring and improved with every game it played. Learning could have been stabilized with a target network, but it also worked without one. The results also show that neural networks aren't always the appropriate solution: in the taxi environment, for example, the table-based agent learned a lot faster.

[Figure: CartPole-v1, deep Q-learning]

[Figure: Taxi-v2, deep Q-learning]
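The target network mentioned in the results discussion can be sketched as a frozen copy of the online network's parameters that is refreshed only every `sync_every` steps, so the Q-targets do not shift with every update. `TargetNetworkSync` and `sync_every` are hypothetical names; in Keras, such a copy is commonly made with `target_model.set_weights(model.get_weights())`.

```python
import copy


class TargetNetworkSync:
    """Keep a frozen copy of the online parameters, refreshed every sync_every steps."""

    def __init__(self, online_params, sync_every=100):
        self.online = online_params                    # updated by training
        self.target = copy.deepcopy(online_params)     # used to compute Q-targets
        self.sync_every = sync_every
        self.steps = 0

    def step(self):
        """Call once per training step; copies online -> target on schedule."""
        self.steps += 1
        if self.steps % self.sync_every == 0:
            self.target = copy.deepcopy(self.online)


online = {"w": [1.0]}
sync = TargetNetworkSync(online, sync_every=2)
online["w"][0] = 5.0        # training changes the online network...
sync.step()
print(sync.target["w"][0])  # 1.0: target is still the old copy
sync.step()
print(sync.target["w"][0])  # 5.0: synced after two steps
```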

Nevertheless, we've just created an agent that could work as both a taxi driver and a pole acrobat. So my work is done here. Let me know if you have any questions or ideas for improving the agent. In the final part of my small trilogy, I will show you how to implement your own environment. If you would rather learn more about DRL, I suggest you take a look at our training!

