
Fixing A Bug in micrograd BackProp (As Explained by Karpathy)

Hi there! I'm Shrijith Venkatrama, founder of Hexmos. Right now, I’m building LiveAPI, a tool that makes generating API docs from your code ridiculously easy.

A Bug In Our Code

In the previous post, we got automatic gradient calculation going for the whole expression graph.

However, it has a tricky bug. Here's a sample program that triggers it:

a = Value(3.0, label='a')
b = a + a  ;  b.label = 'b'

b.backward()
draw_dot(b)

(Image: the buggy expression graph for b = a + a, where a.grad shows 1)

In the above, the forward pass looks alright:

b = a + a = 3 + 3 = 6

But think about the backward pass:

b = a + a
db/da = 1 + 1 = 2

The answer should be 2, but we've got 1 as the a.grad value.
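
As a quick sanity check (this snippet is my own addition: plain Python, independent of the Value class), a finite-difference estimate agrees that the derivative is 2:

# Estimate d(a + a)/da numerically, without any autograd machinery
def f(x):
    return x + x

x, h = 3.0, 1e-6
print((f(x + h) - f(x)) / h)  # ~2.0, so a.grad should come out as 2, not 1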

The problem is in the __add__ operation of the Value class:

class Value:
  def __init__(self, data, _children=(), _op='', label=''):
    self.data = data
    self._prev = set(_children)
    self._op = _op
    self.label = label
    self.grad = 0.0
    self._backward = lambda: None # by default doesn't do anything (for a leaf
                                  # node for ex)

  def __repr__(self):
    return f"Value(data={self.data})"

  def __add__(self, other):
    out = Value(self.data + other.data, (self, other), '+')
    # in the b = a + a example, out is b, and out.grad becomes 1
    # once b.backward() runs

    # derivative of '+' is just distributing the grad of the output to inputs
    def backward():
      self.grad = 1.0 * out.grad  # sets a.grad = 1
      other.grad = 1.0 * out.grad # self and other are both a here, so this
                                  # just writes a.grad = 1 again

    out._backward = backward

    return out
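
To see why the overwrite happens for b = a + a, note that self and other refer to the same object a inside backward(). A minimal check (assuming the rest of the Value class from the previous post is unchanged):

a = Value(3.0, label='a')
b = a + a

# simulate what the full backward pass does at the '+' node
b.grad = 1.0
b._backward()

print(a.grad)  # 1.0 -- the second assignment wrote 1.0 over the first 1.0,
               # instead of adding the two contributions to get 2.0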

Here is another example of the same bug:

a = Value(-2.0, label='a')
b = Value(3.0, label='b')

d = a * b   ;   d.label = 'd'
e = a + b   ;   e.label = 'e'
f = d * e   ;   f.label = 'f'

f.backward()
draw_dot(f)

(Image: expression graph for the second example)

We know that for the multiplication operation, each input's gradient is the other input's data times the output's gradient:

self.grad = other.data * out.grad

Applying this at f = d * e, where f.grad = 1:

d.grad = e.data * f.grad = 1 * 1 = 1

e.grad = d.data * f.grad = -6 * 1 = -6

So far, so good.

Let's look at the next stage:

Backpropagating through d = a * b with the same rule:

self.grad = other.data * out.grad

b.grad = a.data * d.grad = -2 * 1 = -2

But if we consider the addition

e = a + b

the '+' node simply passes e.grad through to both inputs:

a.grad = b.grad = e.grad = -6

So we have a conflict: b.grad = -6 (from the addition) versus b.grad = -2 (from the multiplication).

The general problem is that when a Value is used more than once in the expression graph, it receives a gradient contribution from each use, and our current _backward functions overwrite instead of combining them.

For example, the addition's _backward may set b.grad first, and then the multiplication's _backward sets it again, wiping out the previous value. The sketch below shows the difference between replacing and accumulating.
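
To make the difference concrete, here is a small standalone sketch (plain Python, not the Value class, with the contribution values taken from the example above):

# b receives one contribution from the '+' node and one from the '*' node
contributions = [-6.0, -2.0]

grad = 0.0
for c in contributions:
    grad = c            # replace: each write clobbers the previous one
print(grad)             # -2.0 -- only the last contribution survives

grad = 0.0
for c in contributions:
    grad += c           # accumulate: contributions from all paths add up
print(grad)             # -8.0 -- the combined gradient from both paths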

Solving the Bug: Accumulate Gradients Rather Than Replacing Them

The Wikipedia page for the chain rule has a section on the multivariable case.

The gist of the general solution is that gradient contributions from different paths must be accumulated, not replaced.
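
As a sketch of that rule (my own summary in LaTeX, applied to the first example): when a feeds the output f through several uses x_1, ..., x_n, the per-path contributions add up:

\frac{df}{da} = \sum_{i=1}^{n} \frac{\partial f}{\partial x_i} \, \frac{\partial x_i}{\partial a}

For b = a + a, treating the two uses of a as x_1 and x_2 gives db/da = 1*1 + 1*1 = 2, which is exactly the value the buggy code failed to produce.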

So here is the new Value class, where each _backward accumulates gradients (using +=) rather than replacing them:

import math

class Value:
  def __init__(self, data, _children=(), _op='', label=''):
    self.data = data
    self._prev = set(_children)
    self._op = _op
    self.label = label
    self.grad = 0.0
    self._backward = lambda: None # by default doesn't do anything (for a leaf
                                  # node for ex)

  def __repr__(self):
    return f"Value(data={self.data})"

  def __add__(self, other):
    out = Value(self.data + other.data, (self, other), '+')

    # derivative of '+' is just distributing the grad of the output to inputs
    def backward():
      self.grad += 1.0 * out.grad
      other.grad += 1.0 * out.grad

    out._backward = backward

    return out

  def __mul__(self, other):
    out = Value(self.data * other.data, (self, other), '*')

    # derivative of `mul` is gradient of result multiplied by sibling's data
    def backward():
      self.grad += other.data * out.grad
      other.grad += self.data * out.grad

    out._backward = backward

    return out

  def tanh(self):
      x = self.data
      t = (math.exp(2*x) - 1) / (math.exp(2*x) + 1)
      out = Value(t, (self, ), 'tanh')

      # derivative of tanh = 1 - (tanh)^2
      def backward():
        self.grad += (1 - t**2) * out.grad

      out._backward = backward
      return out

  def backward(self):
    # build a topological ordering of the graph so that every node comes
    # after all of its children (inputs)
    topo = []
    visited = set()
    def build_topo(v):
        if v not in visited:
            visited.add(v)
            for child in v._prev:
                build_topo(child)
            topo.append(v)
    build_topo(self)

    # seed the output's gradient, then run _backward from the output
    # back towards the leaves
    self.grad = 1.0
    for node in reversed(topo):
        node._backward()


Now the gradient calculations are correct:

(Images: the two example graphs redrawn with the corrected gradient values)
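
As a textual check of what the graphs show, here is a minimal rerun of both examples with the fixed class (my own check; the numbers follow from the accumulation rule above):

a = Value(3.0, label='a')
b = a + a ; b.label = 'b'
b.backward()
print(a.grad)           # 2.0 -- both uses of a contribute 1.0 each

a = Value(-2.0, label='a')
b = Value(3.0, label='b')
d = a * b ; d.label = 'd'
e = a + b ; e.label = 'e'
f = d * e ; f.label = 'f'
f.backward()
print(a.grad, b.grad)   # -3.0 and -8.0, e.g. b.grad = (-2 * 1) + (-6) = -8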

Reference

The spelled-out intro to neural networks and backpropagation: building micrograd (Andrej Karpathy, YouTube)
