How Are Gradients Calculated?#

Warning

This part is more mathematics-focused. If you simply want to grasp the intuition behind deep learning, feel free to skip this section.

Note

This explanation focuses on how PyTorch calculates gradients. TensorFlow has since switched to a similar model, which suggests the method works well in practice.

Chain rule#

\[ \frac{d f}{d x} = \frac{d f}{d y} \frac{d y}{d x} \]

The chain rule is a way to calculate the derivative of a composite function, that is, a function built by applying one function to the output of another. With the chain rule at hand, we can take the derivative of a function that looks complicated but is made of familiar pieces.

Example usage of chain rule#

It’s very easy to calculate the derivative of \( 5 x \), which is \( 5 \). It’s also well known that the derivative of \( x^3 \) is \( 3 x^2 \). However, what’s the derivative of \( (5 x)^3 \)?

It’s easy, you say. \( (5 x)^3 = 125 x^3 \), so the derivative is \( 125 (3 x^2) = 375 x^2 \). You’re right. For demonstration purposes, let’s see how the chain rule arrives at the same answer.

First, let \( y = 5 x \), so that \( f = y^3 = (5 x)^3 \). Then, the chain rule:

\[ \frac{d f}{d x} = \frac{d f}{d y} \frac{d y}{d x} \]

will reduce to:

\[ \frac{d (5 x)^3}{d x} = \frac{d (5 x)^3}{d (5 x)} \frac{d (5 x)}{d x} \]

We know that \( \frac{d x^3}{d x} = 3 x^2 \) and \( \frac{d (5 x)}{d x} = 5 \).

So:

\[ \frac{d (5 x)^3}{d (5 x)} \frac{d (5 x)}{d x} = 3 (5 x)^2 \cdot 5 = 375 x^2 \]
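As a quick sanity check, the chain rule result can be compared against a numerical estimate of the derivative. This is only a minimal sketch in plain Python: the helper names `chain_rule_derivative` and `numerical_derivative`, the test point `x = 2.0`, and the step size `h` are all choices made for illustration.

```python
def f(x):
    """The function from the example: (5x)^3."""
    return (5 * x) ** 3


def chain_rule_derivative(x):
    """df/dx = df/dy * dy/dx, with y = 5x."""
    y = 5 * x
    df_dy = 3 * y ** 2  # derivative of y^3 with respect to y
    dy_dx = 5           # derivative of 5x with respect to x
    return df_dy * dy_dx


def numerical_derivative(func, x, h=1e-5):
    """Central finite-difference approximation of d func / d x."""
    return (func(x + h) - func(x - h)) / (2 * h)


x = 2.0
analytic = chain_rule_derivative(x)  # from the chain rule
expanded = 375 * x ** 2              # from expanding to 125 x^3 first
approx = numerical_derivative(f, x)  # numerical estimate

print(analytic, expanded, approx)
```

The first two numbers agree exactly, and the numerical estimate should match them up to a small approximation error.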

Both approaches give the same answer. So why use the chain rule? Imagine that the function is not \( (5 x)^3 \), but \( \cos(\sin (5 x) 8 x) \). This is where the chain rule comes in handy. It decomposes a complicated function step by step and arrives at a solution that might otherwise be too difficult to work out. I’m not going to derive the derivative of \( \cos(\sin (5 x) 8 x) \). Try it yourself!
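If you want to check your own attempt, here is one possible step-by-step decomposition in plain Python. It is a sketch under two assumptions: the expression is read as \( \cos(\sin(5 x) \cdot 8 x) \), and the product rule (not covered above) is used for the inner product. The names `df_dx` and `numerical_derivative` are made up for illustration.

```python
import math


def f(x):
    # Assumed reading of the expression: cos(sin(5x) * 8x)
    return math.cos(math.sin(5 * x) * 8 * x)


def df_dx(x):
    # Inner function: u = sin(5x) * 8x
    u = math.sin(5 * x) * 8 * x
    # Chain rule with y = 5x: d sin(5x)/dx = cos(5x) * 5
    dsin_dx = math.cos(5 * x) * 5
    # Product rule: du/dx = (d sin(5x)/dx) * 8x + sin(5x) * 8
    du_dx = dsin_dx * 8 * x + math.sin(5 * x) * 8
    # Outer function: d cos(u)/du = -sin(u)
    df_du = -math.sin(u)
    # Chain rule: df/dx = df/du * du/dx
    return df_du * du_dx


def numerical_derivative(func, x, h=1e-6):
    """Central finite-difference approximation of d func / d x."""
    return (func(x + h) - func(x - h)) / (2 * h)


x = 0.3
print(df_dx(x), numerical_derivative(f, x))
```

The two printed values should agree to several decimal places, which suggests the decomposition is consistent with the function as read here.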

So how does PyTorch calculate gradients?#

Gradients are multi-dimensional derivatives. The gradient of a number \( y \) with respect to a list of parameters \( X = (x_1, x_2, \ldots, x_n) \) can be defined as:

\[\begin{split} \begin{bmatrix} \frac{d y}{d x_1} \\ \frac{d y}{d x_2} \\ \vdots \\ \frac{d y}{d x_n} \end{bmatrix} \end{split}\]
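To make the definition concrete, here is a small sketch that asks PyTorch for such a gradient. The function \( y = x_1^2 + x_2^2 + x_3^2 \) and the values in `X` are arbitrary choices for illustration; with this function, each entry \( \frac{d y}{d x_i} \) should equal \( 2 x_i \).

```python
import torch

# X is the list of parameters; requires_grad=True asks PyTorch
# to track the operations applied to it.
X = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# y is a single number computed from X: y = x_1^2 + x_2^2 + x_3^2
y = (X ** 2).sum()

# Compute the gradient of y with respect to X and store it in X.grad
y.backward()

print(X.grad)  # tensor([2., 4., 6.])
```

`X.grad` has the same shape as `X`, with one derivative per parameter, matching the column vector above.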

Gradients are calculated using the chain rule. It turns out that most of the functions PyTorch provides are composed of simple functions, and PyTorch derives the gradients for the