Shrijith Venkatramana

Representing Math Expressions As Graphs in micrograd

Hi there! I'm Shrijith Venkatramana, founder of Hexmos. Right now, I'm building LiveAPI, a tool that makes generating API docs from your code ridiculously easy.

What We Want to Represent

In the last post, we calculated a slope by hand, like so:

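h = 0.001  # small nudge applied to one input

# inputs
a = 2.0
b = -3.0
c = 10.0

d1 = a*b + c  # value before the nudge
c += h        # nudge c only; a and b stay fixed
d2 = a*b + c  # value after the nudge

print(f"d1 = {d1}")
print(f"d2 = {d2}")
print(f"w.r.t c (d2 - d1) / h = {(d2 - d1) / h}")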

The goal is to represent the expression L = a*b + c in a convenient form, and then perform key operations on it, such as finding dL/da, dL/db, and dL/dc. This matters because training a neural network means repeatedly adjusting values at the nodes of the expression graph, guided by these derivatives, until the inputs map to the desired outputs.
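As a rough sketch of what finding dL/da, dL/db, and dL/dc means numerically, the nudge-and-measure trick from the last post can be repeated for each input in turn. The names `L` and `numerical_grad` are made up for this illustration and are not part of micrograd.

```python
def L(a, b, c):
    # The expression discussed in this post
    return a * b + c

def numerical_grad(f, inputs, h=1e-3):
    """Approximate the slope of f with respect to each input by nudging it by h.

    Hypothetical helper for illustration only; not part of micrograd.
    """
    base = f(*inputs)
    grads = []
    for i in range(len(inputs)):
        bumped = list(inputs)
        bumped[i] += h  # nudge only the i-th input
        grads.append((f(*bumped) - base) / h)
    return grads

dL_da, dL_db, dL_dc = numerical_grad(L, [2.0, -3.0, 10.0])
print(f"dL/da ~ {dL_da:.4f}")  # analytically equal to b
print(f"dL/db ~ {dL_db:.4f}")  # analytically equal to a
print(f"dL/dc ~ {dL_dc:.4f}")  # analytically equal to 1
```

This works for three inputs, but it needs one extra evaluation of the expression per input, which is part of the motivation for storing the expression as a graph.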

Building the Value class (the foundation for neural networks)

The first step toward that goal is to represent a single value.

Iteration 1: Represent a single Value

class Value:
  def __init__(self, data):
    self.data = data

  def __repr__(self):
    return f"Value(data={self.data})"

Setting an example Value
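A minimal usage sketch of the class at this stage, with the class repeated so the snippet runs on its own. The variable name `a` is just an example.

```python
class Value:
    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return f"Value(data={self.data})"

a = Value(2.0)
print(a)  # prints Value(data=2.0), thanks to __repr__
```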

Iteration 2: Represent Multiple Values And Operations on Them

We add addition and multiplication support to the Value class so that we can write a + b or a * b + c.

class Value:
  def __init__(self, data):
    self.data = data

  def __repr__(self):
    return f"Value(data={self.data})"

  def __add__(self, other):
    return Value(self.data + other.data)

  def __mul__(self, other):
    return Value(self.data * other.data)

a = Value(2.0)
b = Value(-3.0)
print(a*b)
c = Value(10)
print(a * b + c)
print((a.__mul__(b)).__add__(c)) # same as above

Value Operators and Expressions

Iteration 3: Store whole expressions

The next step is to store the "whole chain" of values and operations in a nice graph.

This is done by introducing two new object attributes: _prev and _op. For each node, _prev records the nodes that were combined to produce it (its children), and _op records the operation that was applied to those children. Leaf values such as a, b, and c have an empty _prev and an empty _op.

class Value:
  def __init__(self, data, _children=(), _op=''):
    self.data = data
    self._prev = set(_children)
    self._op = _op

  def __repr__(self):
    return f"Value(data={self.data})"

  def __add__(self, other):
    return Value(self.data + other.data, (self, other), '+')

  def __mul__(self, other):
    return Value(self.data * other.data, (self, other), '*')

a = Value(2.0)
b = Value(-3.0)
c = Value(10)
e = a * b
d = e + c
print(d._prev)
print(d._op)
print("---")
print(e._prev)
print(e._op)

The Whole Expression
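To see how _prev and _op let us recover the whole chain from the final node alone, here is a small sketch that walks the graph recursively. The function `describe` is a hypothetical helper written for this post, not part of micrograd, and the Value class is repeated (with '*' recorded for multiplication) so the snippet runs on its own.

```python
class Value:
    def __init__(self, data, _children=(), _op=''):
        self.data = data
        self._prev = set(_children)
        self._op = _op

    def __repr__(self):
        return f"Value(data={self.data})"

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), '+')

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), '*')

def describe(node, depth=0):
    """Return one indented line per node: its data and the op that produced it.

    Hypothetical helper for illustration only.
    """
    label = f"{'  ' * depth}{node.data}"
    if node._op:
        label += f"  (result of '{node._op}')"
    lines = [label]
    # _prev is a set, so child order is not guaranteed; sort by data for stable output
    for child in sorted(node._prev, key=lambda v: v.data):
        lines.extend(describe(child, depth + 1))
    return lines

a = Value(2.0)
b = Value(-3.0)
c = Value(10)
d = a * b + c

lines = describe(d)
print("\n".join(lines))
```

The root d appears first, with the nodes it was built from indented beneath it. Because _prev is a set, the original left-to-right order of the operands is not stored.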

Visualizing the Expression Graph

Karpathy shares a nice bit of code built on top of Graphviz to display an expression as a graph. Values are drawn as rectangles and operations as ellipses:

from graphviz import Digraph

def trace(root):
    # Builds a set of all nodes and edges in a graph
    nodes, edges = set(), set()

    def build(v):
        if v not in nodes:
            nodes.add(v)
            for child in v._prev:
                edges.add((child, v))
                build(child)

    build(root)
    return nodes, edges

def draw_dot(root):
    dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'})  # LR = left to right

    nodes, edges = trace(root)
    for n in nodes:
        uid = str(id(n))
        # For any value in the graph, create a rectangular ('record') node for it
        dot.node(name=uid, label="{ data %.4f }" % (n.data,), shape='record')

        if n._op:
            # If this value is a result of some operation, create an op node for it
            dot.node(name=uid + n._op, label=n._op)
            # And connect this node to it
            dot.edge(uid + n._op, uid)

    for n1, n2 in edges:
        # Connect n1 to the op node of n2
        dot.edge(str(id(n1)), str(id(n2)) + n2._op)

    return dot

I can do the following to get an image of the graph:

draw_dot(d) # where d is the expression defined above

Graph Image
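If Graphviz is not installed, the trace step can still be checked on its own. The sketch below repeats the Value class (with '*' recorded for multiplication) and the trace function so that it runs standalone, then counts what would be drawn for d = a*b + c. The variable names `op_nodes` and `drawn_nodes` are made up for this illustration.

```python
class Value:
    def __init__(self, data, _children=(), _op=''):
        self.data = data
        self._prev = set(_children)
        self._op = _op

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), '+')

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), '*')

def trace(root):
    # Builds a set of all nodes and edges in a graph
    nodes, edges = set(), set()

    def build(v):
        if v not in nodes:
            nodes.add(v)
            for child in v._prev:
                edges.add((child, v))
                build(child)

    build(root)
    return nodes, edges

a, b, c = Value(2.0), Value(-3.0), Value(10)
d = a * b + c

nodes, edges = trace(d)
# Values that came out of an operation get an extra op node when drawn
op_nodes = [n for n in nodes if n._op]
drawn_nodes = len(nodes) + len(op_nodes)

print(len(nodes), len(edges), drawn_nodes)
```

Here trace finds five Value nodes (a, b, c, a*b, and d) and four child-to-parent edges. Since two of those values are results of operations, draw_dot would add two op nodes, giving seven drawn nodes in total.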

Reference

The spelled-out intro to neural networks and backpropagation: building micrograd