Thursday, September 1, 2016

Breaking the Caesar cipher with Deep Reinforcement Learning


You can teach your machine to break an arbitrary Caesar cipher by letting it observe enough training examples, using Trust Region Policy Optimization (TRPO) for Policy Gradients:


Imagine a world where the hammer was introduced to the public just a couple of years ago. Everyone is running around trying to apply the hammer to anything that even resembles a nail. This is the world we are living in, and the hammer is deep learning.
Today I will apply it to a task that can be solved much more easily by other means, but hey, it's the Deep Learning Age! Specifically, I will teach my machine to break a simple cipher like the Caesar cipher just by looking at several (actually, a lot of) examples of English text and the corresponding encoded strings.
You may have heard that machines are getting pretty good at playing games, so I decided to formulate this code-breaking challenge as a game. Fortunately, there is the OpenAI Gym toolkit that can be used "for developing and comparing reinforcement learning algorithms". It provides some great abstractions that help us define games in terms a computer can understand. For instance, it has a game (or environment) called "Copy-v0" with the following setup and rules:
  • There is an input tape with some characters.
  • You can move the cursor one step left or right along this tape.
  • You can read the symbol under the cursor and output characters one at a time to the output tape.
  • You need to copy the input tape characters to the output tape to win.
But this is almost exactly what we need! Let's just change the win condition: instead of just copying the input tape characters, you need to decode them first to win.
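To make the modified game concrete, here is a minimal, self-contained sketch of such an environment in plain Python. It does not use the Gym API, and it is not the code from the repository: the class `CaesarTapeEnv`, the `(move, write, char)` action layout, the uppercase-only alphabet and the reward values (+1 for a correct symbol, -0.5 and game over for a wrong one) are all assumptions loosely modeled on Copy-v0.

```python
import string

ALPHABET = string.ascii_uppercase  # assumption: plaintext uses only A-Z


def caesar_encode(text, shift):
    """Shift every letter of `text` forward by `shift` positions (mod 26)."""
    return "".join(ALPHABET[(ALPHABET.index(c) + shift) % 26] for c in text)


class CaesarTapeEnv:
    """Hypothetical Copy-v0-style tape game whose win condition is decoding.

    An action is a tuple (move, write, char):
      move:  0 moves the cursor left, 1 moves it right
      write: 1 writes ALPHABET[char] to the output tape, 0 writes nothing
      char:  index into ALPHABET
    """

    def __init__(self, plaintext, shift):
        self.target = plaintext                      # what the agent must output
        self.tape = caesar_encode(plaintext, shift)  # what the agent can read
        self.reset()

    def reset(self):
        self.cursor = 0
        self.output = []
        return self.observe()

    def observe(self):
        # The agent only sees the encoded symbol under the cursor.
        return ALPHABET.index(self.tape[self.cursor])

    def step(self, action):
        move, write, char = action
        reward, done = 0.0, False
        if write:
            expected = self.target[len(self.output)]
            if ALPHABET[char] == expected:
                self.output.append(expected)
                reward = 1.0
                done = len(self.output) == len(self.target)
            else:
                reward, done = -0.5, True  # assumed: a wrong symbol ends the game
        # Move the cursor, clamped to the ends of the tape.
        delta = 1 if move == 1 else -1
        self.cursor = min(max(self.cursor + delta, 0), len(self.tape) - 1)
        return self.observe(), reward, done


# A hand-written "perfect" policy for shift 13: read, shift back, write, move right.
env = CaesarTapeEnv("HELLO", shift=13)
obs, total, done = env.reset(), 0.0, False
while not done:
    obs, r, done = env.step((1, 1, (obs - 13) % 26))
    total += r
print(total, "".join(env.output))  # 5.0 HELLO
```

The learning agent never sees the shift; it has to discover the mapping from encoded symbol to decoded output purely from the rewards.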
Now let's talk a bit about the hammer itself. The hottest thing on the Reinforcement Learning market right now is Policy Gradients, and specifically a flavor called Trust Region Policy Optimization. There is an amazing article from Andrej Karpathy on Policy Gradients, so I will not give an introduction here. If you are new to Reinforcement Learning, just stop reading this post and go read that one. Seriously, it's so much better!

Still here? Ok, I will tell you about TRPO then. TRPO is a technique for Policy Gradients optimization that produces much better results than vanilla gradient descent and even guarantees (theoretically, of course) that you can get an improved policy network on every iteration.
With vanilla PG you start by defining a policy network that produces scores for the actions given the current state. You then simulate hundreds or thousands of games, taking the actions suggested by the network, and note which actions produced better results. With this data you can use backpropagation to update your policy network and start all over again. The only thing that TRPO adds is that you solve a constrained optimization problem instead of an unconstrained one: $$ \textrm{maximize } L(\theta) \textrm{ subject to } \bar{D}_{KL}(\theta_{\textrm{old}},\theta)<\delta$$
Here \(L(\theta)\) is the surrogate objective that we are trying to maximize. It is defined as $$E_{a \sim q}\left[\frac{\pi_\theta(a|s_n)}{q(a|s_n)} A_{\theta_{\textrm{old}}}(s_n,a)\right],$$ where \(\theta\) is our weight vector, \(\pi_\theta(a|s_n)\) is the probability (score) of the selected action \(a\) in state \(s_n\) according to the policy network, \(q(a|s_n)\) is the corresponding probability under the policy network from the previous iteration, and \(A_{\theta_{\textrm{old}}}(s_n,a)\) is the advantage (more on it later). Running simple gradient ascent on this objective is the vanilla Policy Gradients approach.
The TRPO approach doesn't blindly follow the gradient but takes the \(\bar{D}_{KL}(\theta_{\textrm{old}},\theta)<\delta\) constraint into account. To make sure the constraint is satisfied, we do the following. First, we approximately solve the linear system $$Ax = g$$ to find a search direction, where \(A\) is the Fisher information matrix, \(A_{ij} = \frac{\partial^2}{\partial \theta_i \partial \theta_j}\bar{D}_{KL}(\theta_{\textrm{old}},\theta)\) evaluated at \(\theta = \theta_{\textrm{old}}\), and \(g\) is the gradient of the objective, which you get using backpropagation. This is done with the conjugate gradient algorithm, which only needs matrix-vector products \(Av\) and never has to build \(A\) explicitly. Once we have a search direction \(s\), we can easily find the maximum step along it that still satisfies the constraint: approximating the KL divergence by the quadratic \(\frac{1}{2}s^T A s\) gives the step \(\beta s\) with \(\beta = \sqrt{2\delta / s^T A s}\), which TRPO then shrinks with a backtracking line search until the objective improves and the constraint holds.
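The pieces of a single TRPO update described above can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the repository's code: `surrogate_loss`, `conjugate_gradient` and `max_step` are hypothetical helper names, `Avp` stands for any function that returns \(Av\), and the 2x2 matrix is a made-up stand-in for the Fisher matrix. The networks themselves and the final line search are left out.

```python
import numpy as np


def surrogate_loss(new_probs, old_probs, advantages):
    """Sample estimate of E_{a~q}[pi_theta(a|s) / q(a|s) * A(s, a)].

    Each array holds one entry per sampled (state, action) pair: the new policy's
    probability of the taken action, the old policy's probability, and the advantage.
    """
    return np.mean(new_probs / old_probs * advantages)


def conjugate_gradient(Avp, g, iters=10, tol=1e-10):
    """Approximately solve A x = g using only matrix-vector products Avp(v) = A @ v."""
    x = np.zeros_like(g)
    r = g.copy()  # residual g - A x, with x = 0 initially
    p = r.copy()  # current search direction
    rr = r @ r
    for _ in range(iters):
        Ap = Avp(p)
        alpha = rr / (p @ Ap)
        x += alpha * p
        r -= alpha * Ap
        rr_new = r @ r
        if rr_new < tol:
            break
        p = r + (rr_new / rr) * p
        rr = rr_new
    return x


def max_step(Avp, g, delta):
    """Scale the CG direction s so that the quadratic KL model 0.5 * step^T A step equals delta."""
    s = conjugate_gradient(Avp, g)
    beta = np.sqrt(2.0 * delta / (s @ Avp(s)))
    return beta * s


# Toy check with a small symmetric positive definite stand-in for the Fisher matrix.
A = np.array([[4.0, 1.0], [1.0, 3.0]])
g = np.array([1.0, 2.0])
step = max_step(lambda v: A @ v, g, delta=0.01)
print(0.5 * step @ A @ step)  # about 0.01: the step sits on the constraint boundary
```

Passing `Avp` as a function rather than a matrix mirrors the point of the method: with millions of weights, \(A\) is never materialized, only its products with vectors.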
One thing that I promised to get back to is the advantage. It is defined as $$A_\pi(s,a)= Q_\pi(s,a)-V_\pi(s),$$ where \(Q_\pi(s,a)\) is the state-action value function (the return of taking action \(a\) in state \(s\), which usually includes the discounted rewards of all subsequent states) and \(V_\pi(s)\) is the value function (in our case, just a separate network that we train to predict the value of a state).
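Estimating the advantage for one finished game might look like the minimal sketch below. It assumes \(Q\) is approximated by the empirical discounted return and that the value network's predictions are already available as plain numbers; `gamma` and the helper names are illustrative assumptions, not taken from the repository.

```python
import numpy as np


def discounted_returns(rewards, gamma=0.99):
    """Q estimate for each step t: r_t + gamma * r_{t+1} + gamma^2 * r_{t+2} + ..."""
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def advantages(rewards, values, gamma=0.99):
    """A(s_t, a_t) = Q(s_t, a_t) - V(s_t), with Q estimated by the discounted return."""
    return discounted_returns(rewards, gamma) - np.asarray(values)


rewards = [1.0, 1.0, 1.0]
values = [1.0, 1.0, 1.0]  # pretend predictions from the value network
# returns: 1.75, 1.5, 1.0  ->  advantages: 0.75, 0.5, 0.0
print(advantages(rewards, values, gamma=0.5))
```

A positive advantage means the action did better than the value network expected from that state, so the policy update pushes its probability up.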

Bored enough already? I promise, it's not that scary in code. You can find the full implementation here: tilarids/reinforcement_learning_playground. Specifically, look at You can reproduce the Caesar cipher breaking by running
For those of you who think the code resembles wojzaremba's implementation a lot: you are right. I copied some TRPO code from there and then rewrote it to make it more readable and to make sure it follows the paper closely.


As you can see, I've used The Zen of Python (which is already conveniently encoded with ROT13) to generate my training data. There are sample runs after 45000, 100000 and 164000 episodes (each episode == 1 game). After 164000 episodes the agent is smart enough to decode long sentences:

You may say (and you will be right) that it's overkill to use such a complex approach to solve such a simple problem, and that there are much simpler ways to break a Caesar cipher when you have large sets of encoded and decoded text pairs. But the beauty of the approach described above is that it can be trained to solve other tasks without the slightest change to the agent code and network configuration. For instance, here the same agent balances a pole on a cart and here it learns how to copy symbols. So this whole article is only half a joke. Deep Reinforcement Learning using TRPO is a powerful technique, and I look forward to a future with fully self-trained robots strolling the streets. Scary, huh?
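The Zen of Python ships with CPython as the `this` module, and its source string `this.s` holds the ROT13-encoded text. One possible way to turn it into aligned (encoded, decoded) training snippets is sketched below; the sampler `training_pairs` and its snippet lengths are hypothetical, not the exact data pipeline behind the runs described above.

```python
import codecs
import contextlib
import io
import random

# Importing `this` prints the Zen of Python; silence that side effect.
with contextlib.redirect_stdout(io.StringIO()):
    import this

encoded_zen = this.s                               # ROT13-encoded, as stored in CPython
decoded_zen = codecs.decode(encoded_zen, "rot13")  # plain English text


def training_pairs(n, max_len=20, seed=0):
    """Hypothetical sampler: random (encoded, decoded) snippets aligned character by character.

    ROT13 maps each letter to exactly one letter and leaves other characters alone,
    so slicing both lines at the same positions keeps the pair aligned.
    """
    rng = random.Random(seed)
    lines = [(e, d) for e, d in zip(encoded_zen.splitlines(), decoded_zen.splitlines()) if e]
    pairs = []
    for _ in range(n):
        enc, dec = rng.choice(lines)
        start = rng.randrange(len(enc))
        end = min(len(enc), start + rng.randint(1, max_len))
        pairs.append((enc[start:end], dec[start:end]))
    return pairs


for enc, dec in training_pairs(3):
    print(repr(enc), "->", repr(dec))
```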


  1. Thanks for a nice blog post and the code. I have followed the code and have two questions:

    1) Why are you using the conj_grads_damping (=0.1)? Is this related somehow to conjugate gradients in general, or something in TRPO? I've seen something like it in general conjugate gradients, but then they have a convex combination like (1-conj_grads_damping) * A + conj_grads_damping * B.

    2) Do you understand what's going on in the TRPO appendix C.1 "Computing the Fisher-Vector Product", when they introduce the mean-vector mu? I don't understand the weird kl-divergence discussion there (using some small kl, rather than D_kl etc.). And if you did understand that part, is it somehow apparent in the code too?


    1. I haven't touched this code for almost a year so some of my understanding of TRPO may have faded. But here you go:
      1) This is not directly related to TRPO but to the conjugate gradients method. I've seen several damping methods used with conjugate gradients, and this is one of them. Removing this part should not break the algorithm in theory, but in practice we start hitting the safeguard, which breaks the line_search and slows down learning (compare that uses damping and that doesn't). There is a short "discussion" about introducing this kind of damping in the repo I picked up the damping (and most of the other ideas) from:
      2) As far as I remember, \(\mu_\theta(x)\) is `self.policy_network` in the code and \(\mu_{old}\) is `self.prev_policy`. `self.kl_divergence_op` is how the KL divergence is computed. `M` from the article is the second derivative of the KL divergence; that's what is computed from `self.kl_divergence_op` and is finally used in `fisher_vector_product`.
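The damping in answer 1) can be sketched as follows, assuming the common additive form in which the conjugate gradient solver sees \((A + \lambda I)v\) instead of \(Av\), with \(\lambda\) = `conj_grads_damping`; this differs from the convex combination mentioned in the question. Adding \(\lambda I\) shifts every eigenvalue of \(A\) up by \(\lambda\), so the system stays positive definite even when the sampled Fisher matrix is singular or nearly so. The `damped` helper and the toy matrix are illustrative assumptions.

```python
import numpy as np


def damped(fvp, damping=0.1):
    """Turn v -> A v into v -> (A + damping * I) v (0.1 as in conj_grads_damping)."""
    return lambda v: fvp(v) + damping * v


# A rank-deficient stand-in "Fisher matrix" (eigenvalues 0 and 2):
# A x = g has no solution for this g, because g is not in the range of A.
A = np.array([[1.0, 1.0], [1.0, 1.0]])
g = np.array([1.0, 0.0])

Avp = damped(lambda v: A @ v)
A_damped = np.column_stack([Avp(e) for e in np.eye(2)])  # materialize A + 0.1 I to inspect it
x = np.linalg.solve(A_damped, g)                         # now a well-posed system
print(np.linalg.eigvalsh(A_damped))                      # every eigenvalue shifted up by 0.1
```

Besides making the solve well-posed, the shift also caps how far the step can go in directions where the KL barely changes, which plausibly explains why the line search behaves better with damping turned on.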

    2. MathJAX is very broken, reposting the second part of the answer.
      As far as I remember, `\mu_\theta(x)` is "self.policy_network" in the code and `\mu_{old}` is "self.prev_policy". "self.kl_divergence_op" is how the KL divergence is computed. `M` from the article is the second derivative of the KL divergence, and that's what is computed from "self.kl_divergence_op" and is finally used in the "fisher_vector_product" function.
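The structure from appendix C.1, \(A = J^T M J\) with \(J = \partial\mu_\theta/\partial\theta\) and \(M\) the second derivative of the KL divergence with respect to the distribution parameters \(\mu\) at \(\mu_{old}\), can be checked numerically on a toy example. The sketch below assumes a hypothetical single-state linear softmax policy whose logits are \(\mu_\theta = J\theta\), so \(M = \textrm{diag}(p) - pp^T\); it compares \(J^T M J v\) with a finite-difference derivative of the KL gradient. A real implementation would let automatic differentiation produce these products instead of writing them out by hand.

```python
import numpy as np


def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def kl_grad(theta, J, p_old):
    """Gradient w.r.t. theta of kl(p_old || softmax(J @ theta)) for the toy policy mu_theta = J @ theta."""
    return J.T @ (softmax(J @ theta) - p_old)


def fisher_vector_product(theta_old, J, v):
    """A v = J^T M J v, with M = diag(p) - p p^T the Hessian of the KL w.r.t. the logits at mu_old."""
    p = softmax(J @ theta_old)
    u = J @ v                    # J v: how the logits move along v
    Mu = p * u - p * (p @ u)     # M (J v), without building M
    return J.T @ Mu              # J^T M J v


rng = np.random.default_rng(0)
J = rng.normal(size=(4, 3))      # 4 actions, 3 parameters
theta_old = rng.normal(size=3)
p_old = softmax(J @ theta_old)
v = rng.normal(size=3)

# Central finite difference of the KL gradient along v approximates the Hessian-vector product.
eps = 1e-5
fd = (kl_grad(theta_old + eps * v, J, p_old) - kl_grad(theta_old - eps * v, J, p_old)) / (2 * eps)
print(np.allclose(fisher_vector_product(theta_old, J, v), fd, atol=1e-6))  # True
```

Factoring the product this way is what makes it cheap: \(M\) lives in the small space of distribution parameters, while \(J\) and \(J^T\) are applied as ordinary forward and backward passes.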

    3. Hi Sergey,

      Thank you so much for taking the time to answer these questions, it helped a lot!