Reverse-mode Automatic Differentiation
Refresher on Gradients
This post will assume some knowledge of gradients, but if you’re rusty, hopefully this section will be enough for you to follow the rest of the post.
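The derivative of a real function $f$ (this just means it takes in a real number and spits out a real number) is given by the equation:
$$f'(x) = \frac{df}{dx} = \lim_{h \to 0} \frac{f(x + h) - f(x)}{h}$$
You can think of this as the change in $f(x)$ compared to the change in $x$, which is the slope of the line through $(x, f(x))$ and $(x + h, f(x + h))$ (if $f(x)$ is plotted on the y axis and $x$ on the x axis). As the change in $x$, $h$, approaches zero, this becomes the slope exactly at the point $x$.
So the slope of the graph of $f$ at any point $x$ is given by $f'(x)$. For a function of several variables, such as $f(x, y)$, the gradient $\nabla f$ is the vector of partial derivatives, one per variable: $\nabla f = \left(\frac{\partial f}{\partial x}, \frac{\partial f}{\partial y}\right)$.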
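To see the limit definition in action, here is a small sketch that replaces the limit with a fixed, small $h$ (a finite difference). The helper name `approx_derivative` and the default `h=1e-6` are assumptions for illustration; the result is only an approximation, and making $h$ too small eventually makes floating-point rounding error worse rather than better.

```python
from math import sin, cos


def approx_derivative(f, x, h=1e-6):
    """Approximate f'(x) with the difference quotient (f(x + h) - f(x)) / h.

    This is the limit definition of the derivative with a small, finite h
    instead of taking h -> 0, so the result is only approximate.
    """
    return (f(x + h) - f(x)) / h


# The exact derivative of sin is cos, so these two numbers should be close.
estimate = approx_derivative(sin, 1.0)
exact = cos(1.0)
print(f"estimate={estimate:.6f}, exact={exact:.6f}")
```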
For more detail than we have space for here, try reading this post.
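Effectively, we assume that whenever some function $f(\theta)$ of our parameters $\theta$ is minimized, we have solved our problem. So we repeatedly set a new value for $\theta$ based on the old one, taking a small step (whose size is controlled by $\eta$) against the gradient so that the value of $f$ decreases:
$$\theta \leftarrow \theta - \eta \nabla_\theta f(\theta)$$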
Calculating this gradient can be a pain when $f$ becomes non-trivial, so we adopt a method for determining gradients automatically.
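Before automating anything, here is a minimal sketch of the update loop on a single parameter, assuming a toy objective $f(\theta) = (\theta - 3)^2$ whose derivative $2(\theta - 3)$ we have worked out by hand. The names `gradient_descent` and `toy_grad` and the settings `lr=0.1`, `steps=100` are hypothetical choices; the hand-written `toy_grad` is exactly the part that becomes painful for larger functions.

```python
def gradient_descent(grad, theta, lr=0.1, steps=100):
    """Repeatedly apply the update theta <- theta - lr * grad(theta)."""
    for _ in range(steps):
        theta = theta - lr * grad(theta)
    return theta


# Toy objective f(theta) = (theta - 3)^2, minimized at theta = 3.
# Its derivative, worked out by hand, is f'(theta) = 2 * (theta - 3).
def toy_grad(theta):
    return 2 * (theta - 3)


theta_final = gradient_descent(toy_grad, theta=0.0)
print(theta_final)  # close to 3.0
```

With `lr=0.1`, each step multiplies the distance to the minimum by $1 - 2 \cdot 0.1 = 0.8$, so the parameter approaches 3 geometrically.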
Automatic differentiation (autodiff) is a fancy way of saying the chain rule, but the computer does the work for you.
We only consider reverse-mode differentiation here, but there is also (the conceptually easier) forward-mode differentiation[^1].
The worked example that follows is adapted, with thanks, from StackOverflow[^2].
Say, for example, we want to minimize the function:
$$f(x, y) = \sin(x) + xy$$
where $x$ and $y$ are parameters that we can update.
It turns out that we can express our loss function as a sequence of primitive operations. Drawing this as a graph, it looks like this:
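Each node (i.e., one of the circles) is either an operation or a parameter which we can change. The edges (arrows) show the direction in which values flow from the inputs to the output. We label the nodes of each graph with $w_i$, such that:
$$
\begin{aligned}
w_1 &= x \\
w_2 &= y \\
w_3 &= w_1 w_2 \\
w_4 &= \sin(w_1) \\
w_5 &= w_3 + w_4 = f(x, y)
\end{aligned}
$$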
You’ll notice this is the same as the above computation graph, except that the direction of each edge (arrow) is reversed.
We can consider the gradient of a node to be a function of only the gradient(s) of its parent(s) in this reversed graph (the nodes that use its output) and the inputs to it, which we store in the node during the forward pass. Let’s use this to work out the gradients of the parameters.
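One way to make this concrete is to give each primitive operation its own local backward rule: it receives the adjoint of the node's output (`grad_out`) plus the inputs saved from the forward pass, and returns a contribution for each input. The function names below are hypothetical, and only the three operations in our example are covered.

```python
from math import cos

# Each primitive operation gets a local backward rule. Given the adjoint of
# the node's output (grad_out) and the inputs saved from the forward pass, it
# returns the adjoint contribution for each input via the chain rule.


def add_backward(grad_out):
    # w = a + b  =>  dw/da = 1, dw/db = 1 (no saved inputs needed)
    return grad_out, grad_out


def mul_backward(a, b, grad_out):
    # w = a * b  =>  dw/da = b, dw/db = a (needs the saved inputs a and b)
    return grad_out * b, grad_out * a


def sin_backward(a, grad_out):
    # w = sin(a)  =>  dw/da = cos(a) (needs the saved input a)
    return grad_out * cos(a)


print(mul_backward(1.0, 2.0, 1.0))  # (2.0, 1.0)
```

Note that the addition rule needs nothing from the forward pass, while multiplication and sine do; this is why intermediate values have to be stored.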
We will use a shorthand for the gradient of the output with respect to a node, called the adjoint:
$$\bar{w}_i = \frac{\partial f}{\partial w_i}$$
To start us off, recall that the derivative of something with respect to itself is 1, so, since $w_5 = f$:
$$\bar{w}_5 = \frac{\partial f}{\partial w_5} = 1$$
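Now, for each dependency of $w_5$ (here $w_3$ and $w_4$), we need to calculate its own adjoint. By the chain rule, a node's adjoint is the sum, over every node $w_j$ that uses it, of that node's adjoint times the local derivative:
$$\bar{w}_i = \sum_{j} \bar{w}_j \frac{\partial w_j}{\partial w_i}$$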
Since we already know that $w_5 = w_3 + w_4$, using symbolic differentiation we find $\frac{\partial w_5}{\partial w_4} = 1$, so $\bar{w}_4 = \bar{w}_5 \cdot 1 = 1$.
Similarly, $\bar{w}_3 = \bar{w}_5 \frac{\partial w_5}{\partial w_3} = 1$, since it follows the same equation, just with the index changed.
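Next, $w_2$ is used only by $w_3 = w_1 w_2$, hence $\bar{w}_2 = \bar{w}_3 \frac{\partial w_3}{\partial w_2} = \bar{w}_3 w_1$, which evaluated at $w_1 = x$ gives us $\bar{w}_2 = x$. Since $w_2 = y$, this is $\frac{\partial f}{\partial y}$, as expected from differentiating $f$ directly. We can substitute this value into the update equation to update $y$:
$$y \leftarrow y - \eta \bar{w}_2 = y - \eta x$$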
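Finally, $w_1$ is used by both $w_3$ and $w_4$, so its adjoint sums two contributions. From before, $\bar{w}_3 = \bar{w}_4 = 1$. Via symbolic differentiation, $\frac{\partial w_3}{\partial w_1} = w_2$, evaluated at $w_2 = y$, and $\frac{\partial w_4}{\partial w_1} = \cos(w_1)$, evaluated at $w_1 = x$ (here the value of $w_2$ is not needed, but in general this is not the case). So:
$$\bar{w}_1 = \bar{w}_3 \frac{\partial w_3}{\partial w_1} + \bar{w}_4 \frac{\partial w_4}{\partial w_1} = y + \cos(x)$$
evaluated at the current values of $x$ and $y$, which we need to have saved from the forward pass, since they change at every step.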
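Putting these steps together, here is a hand-written sketch of the forward and reverse passes for $f(x, y) = \sin(x) + xy$, using the node labels $w_1, \dots, w_5$. The function name `forward_backward` is an assumption for illustration; the gradients it returns should equal $\cos(x) + y$ and $x$.

```python
from math import sin, cos


def forward_backward(x, y):
    """Forward and reverse pass for f(x, y) = sin(x) + x * y, written out by hand."""
    # Forward pass: compute and keep every intermediate value.
    w1, w2 = x, y
    w3 = w1 * w2
    w4 = sin(w1)
    w5 = w3 + w4

    # Reverse pass: adjoints, starting from the output.
    w5_bar = 1.0                  # df/df = 1
    w4_bar = w5_bar * 1.0         # w5 = w3 + w4, so dw5/dw4 = 1
    w3_bar = w5_bar * 1.0         # and dw5/dw3 = 1
    w2_bar = w3_bar * w1          # w3 = w1 * w2, so dw3/dw2 = w1
    # w1 feeds both w3 and w4, so its two contributions are summed.
    w1_bar = w3_bar * w2 + w4_bar * cos(w1)
    return w5, (w1_bar, w2_bar)


value, (df_dx, df_dy) = forward_backward(1.0, 2.0)
print(value, df_dx, df_dy)  # df_dx equals cos(1) + 2 and df_dy equals 1
```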
We can now update $x$ based on the update equation:
$$x \leftarrow x - \eta \bar{w}_1 = x - \eta\left(\cos(x) + y\right)$$
where $\eta$ is the learning rate, which sets how far the parameter vector $(x, y)$ moves against the gradient at each step.
If we initialized our optimizer with $x = 1$ and $y = 2$, we would get an initial loss of $\sin(1) + 2 \approx 2.8415$.
Here’s a small Python implementation of this:
```python
from math import sin, cos

x = 1
y = 2
lr = 1e-2  # learning rate (eta)

print("Initial loss:", sin(x) + x * y)  # Initial loss: 2.8414709848078967

for _ in range(1000):
    # Calculate updated parameters from the adjoints:
    # df/dx = cos(x) + y and df/dy = x
    new_x = x - lr * (cos(x) + y)
    new_y = y - lr * x
    # Set updated parameters (both updates above used the old x and y)
    x = new_x
    y = new_y

# Calculate loss function
print("Final loss:", sin(x) + x * y)  # Final loss: -193992792.6350034
print("x:", x)  # x: -13928.111031430128
print("y:", y)  # y: 13928.148130601208
```
The final loss is much lower than the initial loss, showing that our naive optimizer has done its job of travelling downhill on the loss function. This particular function actually has no minimum: along the line $y = -x$ it equals $\sin(x) - x^2$, which decreases without bound, so the loss keeps falling for as long as we iterate and the parameters diverge. For objective functions that are bounded below, such as a norm (which is never negative), our method should work similarly.
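As a sketch of the bounded-below case, the loop below swaps in the squared Euclidean norm $f(x, y) = x^2 + y^2$ (our choice for illustration, not part of the original example), keeping the same starting point, learning rate and number of steps. Its gradient $(2x, 2y)$ is written out by hand.

```python
# Same optimizer as above, but on the squared Euclidean norm
# f(x, y) = x**2 + y**2, which is never negative and is minimized at (0, 0).
# Gradient, worked out by hand: df/dx = 2 * x, df/dy = 2 * y.
x = 1.0
y = 2.0
lr = 1e-2

for _ in range(1000):
    # Compute both updates from the old values before overwriting them.
    new_x = x - lr * 2 * x
    new_y = y - lr * 2 * y
    x, y = new_x, new_y

loss = x**2 + y**2
print("Final loss:", loss)  # a tiny positive number, close to 0
```

With these settings each step scales both parameters by $1 - 2\eta = 0.98$, so they shrink geometrically towards the minimum at $(0, 0)$ instead of diverging.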
You can see that in a single (backward) pass, reverse mode computes the gradients with respect to all the inputs of a given computation graph. This comes at the expense of having to store the intermediate results of the forward pass (or recalculate them), which means more memory is used. For ease of use, we will use a dynamic graph (i.e., one that is rebuilt on every forward pass).