Tuesday, May 16, 2023

Haskell is, in my opinion, one of the most unique languages ever made. It was originally developed for teaching and research, but it introduced a number of now-foundational ideas such as type classes and monadic IO (shoutout Wikipedia). Basically, Haskell is super fun to use!

Standard disclaimer: This is an exploration of Haskell and automatic differentiation, certainly not code anyone would want to use in production.

What We Are Building
====================

Automatic differentiation is the heart and soul of libraries like TensorFlow, PyTorch, JAX, and almost every other deep learning library. During the forward pass of training a feed-forward neural network, the operations performed are recorded on what is typically called a tape. At the end of the forward pass, the tape is played back, or backpropagated, to get the gradients for the network. Today we will implement this recording of operations and the backpropagation over them in a small Haskell program.

Defining Our Types
==================

The goal of our program is to perform some mathematical computations and get the gradients of those computations. Most deep learning libraries use the abstraction of tensors, and we will be no different.

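```
{-# LANGUAGE DatatypeContexts #-}
-- DatatypeContexts is deprecated in GHC, but the contexts on our data declarations need it

import Control.Monad.State (State, evalState, get, put)
import qualified Data.Map as Map

data (Fractional a, Eq a) => Tensor0D a = Tensor0D
  { tid :: Int,
    value :: a
  }
  deriving (Show, Eq)
```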

Our tensor is incredibly simple: it has a `tid` and a `value`. Note that we allow any type that is an instance of both `Fractional` and `Eq`. We will stay in zero dimensions, so every tensor holds a single scalar.

Tensors should be able to perform operations, and we need a way to store those operations for use later.

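```
-- MP = multiply, DV = divide, AD = add, NA = no operation (marks a leaf tensor)
data Operator = MP | DV | AD | NA
  deriving (Eq)

-- An operator followed by the result tensor and its left and right operands
data (Fractional a, Eq a) => Operation a = Operation Operator (Tensor0D a) (Tensor0D a) (Tensor0D a)
  deriving (Eq)

data (Fractional a, Eq a) => Tape a = Tape
  { operations :: [Operation a],
    nextTensorId :: Int
  }
```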

The above code gives a few more types that we use to store tensor operations. An `Operation` holds the `Operator` followed by the result tensor and its two operands. As mentioned above, like most deep learning libraries, we have our own `Tape`. In our case the `Tape` also stores the `nextTensorId`. If we were writing in an imperative language like Rust, we would probably keep the `nextTensorId` in an atomic counter. Haskell can do something similar with an `IORef`, but that would pull every tensor operation into `IO`, so instead we let the `Tape` carry the `nextTensorId` as plain state.
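For contrast, here is a minimal sketch of what the mutable-counter approach could look like in Haskell. The names `newCounter` and `freshId` are hypothetical and are not used anywhere else in this post; the point is only that every caller would now have to live in `IO`.

```haskell
import Data.IORef (IORef, atomicModifyIORef', newIORef)

-- Hypothetical alternative: a mutable id counter living in IO.
newCounter :: IO (IORef Int)
newCounter = newIORef 0

-- Atomically hand out the current id and store the incremented value.
freshId :: IORef Int -> IO Int
freshId ref = atomicModifyIORef' ref (\n -> (n + 1, n))
```

Keeping the counter on the `Tape` instead lets the tensor code stay pure and easy to run with `evalState`.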

We will also write a helper function to make creating tensors easier.

```
createTensor :: (Fractional a, Eq a) => a -> State (Tape a) (Tensor0D a)
createTensor value = do
  tape <- get
  let tensorId = nextTensorId tape
  put $ tape {nextTensorId = tensorId + 1}
  return $ Tensor0D tensorId value
```

Notice that we are working with the `State` monad. Using the `State` monad to wrap the tape is similar in idea to TensorFlow's `with tf.GradientTape()`. For instance, creating a tensor in TensorFlow while monitoring operations might look like the following:

```
import tensorflow as tf

with tf.GradientTape() as tape:
    newTensor = tf.constant(1.0)
    tape.watch(newTensor)  # constants are not watched automatically
    # Some series of operations that will be added to the tape
```
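To see how the `State` monad threads the tape on the Haskell side, here is a small self-contained sketch. It uses simplified stand-ins for our types (fixed to `Double`, with a tape that only tracks the next id), and `twoTensors` and `demo` are hypothetical names introduced just for this illustration.

```haskell
import Control.Monad.State (State, get, put, runState)

-- Simplified stand-ins for the post's types (no class constraints, no operations list).
data Tensor0D = Tensor0D {tid :: Int, value :: Double} deriving (Show, Eq)

newtype Tape = Tape {nextTensorId :: Int} deriving (Show, Eq)

createTensor :: Double -> State Tape Tensor0D
createTensor v = do
  tape <- get
  let tensorId = nextTensorId tape
  put tape {nextTensorId = tensorId + 1}
  return (Tensor0D tensorId v)

-- Two creations in sequence: the tape is passed along implicitly.
twoTensors :: State Tape (Tensor0D, Tensor0D)
twoTensors = do
  a <- createTensor 1
  b <- createTensor 2
  return (a, b)

-- Run against a tape whose next free id is 0.
demo :: ((Tensor0D, Tensor0D), Tape)
demo = runState twoTensors (Tape 0)
```

Evaluating `demo` gives tensors with ids 0 and 1, plus a final tape whose `nextTensorId` is 2, without either call ever mentioning the tape explicitly.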

Our Operations
==============

We will pursue relative simplicity and only implement three operations:

- Add
- Multiply
- Divide

```
tAdd :: (Fractional a, Eq a) => Tensor0D a -> Tensor0D a -> State (Tape a) (Tensor0D a)
tAdd t1@Tensor0D {value = value1} t2@Tensor0D {value = value2} = do
  tape <- get
  let tensorId = nextTensorId tape
  let ops = operations tape
  let newTensor = Tensor0D tensorId (value1 + value2)
  put $ tape {nextTensorId = tensorId + 1, operations = Operation AD newTensor t1 t2 : ops}
  return newTensor

tMul :: (Fractional a, Eq a) => Tensor0D a -> Tensor0D a -> State (Tape a) (Tensor0D a)
tMul t1@Tensor0D {value = value1} t2@Tensor0D {value = value2} = do
  tape <- get
  let tensorId = nextTensorId tape
  let ops = operations tape
  let newTensor = Tensor0D tensorId (value1 * value2)
  put $ tape {nextTensorId = tensorId + 1, operations = Operation MP newTensor t1 t2 : ops}
  return newTensor

tDiv :: (Fractional a, Eq a) => Tensor0D a -> Tensor0D a -> State (Tape a) (Tensor0D a)
tDiv t1@Tensor0D {value = value1} t2@Tensor0D {value = value2} = do
  tape <- get
  let tensorId = nextTensorId tape
  let ops = operations tape
  let newTensor = Tensor0D tensorId (value1 / value2)
  put $ tape {nextTensorId = tensorId + 1, operations = Operation DV newTensor t1 t2 : ops}
  return newTensor
```

Each operation performs the same process:

- Get the `tape`
- Get the `nextTensorId`
- Create the new tensor
- Store the operation in the state's `tape`
- Return the new tensor inside the `State` context
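The steps above can be sketched as a single helper that all three operations share. This is only an illustration, not the code the rest of the post uses: the types are simplified stand-ins fixed to `Double`, and `binaryOp` and `demo` are hypothetical names.

```haskell
import Control.Monad.State (State, evalState, get, put)

-- Simplified stand-ins for the post's types, fixed to Double for brevity.
data Tensor0D = Tensor0D {tid :: Int, value :: Double} deriving (Show, Eq)

data Operator = MP | DV | AD | NA deriving (Show, Eq)

data Operation = Operation Operator Tensor0D Tensor0D Tensor0D deriving (Show, Eq)

data Tape = Tape {operations :: [Operation], nextTensorId :: Int} deriving (Show, Eq)

-- Hypothetical helper: the shared steps, parameterised by the operator tag
-- and the arithmetic function it stands for.
binaryOp :: Operator -> (Double -> Double -> Double) -> Tensor0D -> Tensor0D -> State Tape Tensor0D
binaryOp op f t1 t2 = do
  -- Steps 1 and 2: get the tape and the next free id.
  tape <- get
  let tensorId = nextTensorId tape
      -- Step 3: create the new tensor.
      newTensor = Tensor0D tensorId (f (value t1) (value t2))
  -- Step 4: store the operation on the tape (newest first).
  put tape {nextTensorId = tensorId + 1, operations = Operation op newTensor t1 t2 : operations tape}
  -- Step 5: return the new tensor.
  return newTensor

tAdd, tMul, tDiv :: Tensor0D -> Tensor0D -> State Tape Tensor0D
tAdd = binaryOp AD (+)
tMul = binaryOp MP (*)
tDiv = binaryOp DV (/)

-- Usage: (1 * 2) + 3, starting from a tape whose next free id is 3.
demo :: Tensor0D
demo = evalState computation (Tape [] 3)
  where
    computation = do
      p <- tMul (Tensor0D 0 1) (Tensor0D 1 2)
      tAdd p (Tensor0D 2 3)
```

Here `demo` evaluates to a tensor with `tid = 4` and `value = 5.0`. Writing the three functions out by hand, as we do above, trades this brevity for code that is easier to read one function at a time.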

Going Backwards
===============

We have our `Tape` storing our `Operation`s; now we need to go backwards through those operations to get our gradients. What does it mean to go backwards? Let's say we have the following Haskell code (this is valid code in the context of our program):

```
doComputations :: (Fractional a, Eq a) => State (Tape a) (Tensor0D a)
doComputations = do
  t0 <- createTensor 1
  t1 <- createTensor 2
  t2 <- createTensor 3
  t3 <- tMul t0 t1
  t4 <- tMul t3 t2
  -- t4 is the final result of the computation
  return t4
```

Let's imagine computing the gradients for this by hand. We might choose to draw out a parse tree.

```
      *
     / \
    *   t2
   / \
  t0  t1
```

With the values filled in (tensor | operation, value):

```
           (*, 6)
           /    \
          /      \
         /        \
     (*, 2)     (t2, 3)
     /    \
    /      \
   /        \
(t0, 1)   (t1, 2)
```

If we start at the top, we can go backwards (represented as b) down the tree filling in the derivatives, multiplying through operations exactly how the chain rule teaches us. This is really just a different way to view the chain rule.

```
                (*, 6)
                /    \
        (b, 3) /      \ (b, 2)
              /        \
          (*, 2)     (t2, 3)
          /    \
  (b, 2) /      \ (b, 1)
        /        \
    (t0, 1)    (t1, 2)
```

To calculate the derivative for a tensor, we simply follow the chain from the top, multiplying each (b, value) together.

- t0: 3 * 2 = 6
- t1: 3 * 1 = 3
- t2: 2
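Written out with the chain rule, using $t_3 = t_0 t_1$ and $t_4 = t_3 t_2$ from the example above (with $t_0 = 1$, $t_1 = 2$, $t_2 = 3$), the same numbers fall out:

```latex
\begin{aligned}
\frac{\partial t_4}{\partial t_3} &= t_2 = 3, &
\frac{\partial t_4}{\partial t_2} &= t_3 = 2, \\
\frac{\partial t_3}{\partial t_0} &= t_1 = 2, &
\frac{\partial t_3}{\partial t_1} &= t_0 = 1, \\
\frac{\partial t_4}{\partial t_0} &= \frac{\partial t_4}{\partial t_3} \cdot \frac{\partial t_3}{\partial t_0} = 3 \cdot 2 = 6, &
\frac{\partial t_4}{\partial t_1} &= \frac{\partial t_4}{\partial t_3} \cdot \frac{\partial t_3}{\partial t_1} = 3 \cdot 1 = 3.
\end{aligned}
```

Each (b, value) label on the tree is one of the local partial derivatives in the first two rows.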

We can use this exact method to calculate the derivatives programmatically. Recall that `Tape` stores a list of `Operation`s. We want to convert that list into a tree that follows the structure we drew above, and then go backwards down the tree to get the derivatives.

Let's first build the tree.

```
data (Fractional a, Eq a) => TensorTree a
  = Empty
  | Cons (Tensor0D a) Operator (TensorTree a) (TensorTree a)
  deriving (Eq)

-- Attach an operation to the leaf that holds its result tensor.
appendTree :: (Fractional a, Eq a) => Operation a -> TensorTree a -> TensorTree a
-- The first operation we see becomes the root of the tree.
appendTree (Operation op t1 t2 t3) Empty = Cons t1 op (Cons t2 NA Empty Empty) (Cons t3 NA Empty Empty)
appendTree
  fullOp@(Operation op t1@Tensor0D {tid = opId} t2 t3)
  (Cons treeTop treeOp leftTree@(Cons Tensor0D {tid = leftId} _ _ _) rightTree@(Cons Tensor0D {tid = rightId} _ _ _))
    | opId == leftId = Cons treeTop treeOp (Cons t1 op (Cons t2 NA Empty Empty) (Cons t3 NA Empty Empty)) rightTree
    | opId == rightId = Cons treeTop treeOp leftTree (Cons t1 op (Cons t2 NA Empty Empty) (Cons t3 NA Empty Empty))
    | otherwise =
        -- Neither child is the result tensor, so search both subtrees.
        let newLeftTree = appendTree fullOp leftTree
            newRightTree = appendTree fullOp rightTree
         in if newLeftTree /= leftTree
              then Cons treeTop treeOp newLeftTree rightTree
              else Cons treeTop treeOp leftTree newRightTree
-- A leaf that does not match is returned unchanged.
appendTree _ tree@(Cons _ _ Empty Empty) = tree

-- Fold the operations into the tree, newest operation first.
buildTree :: (Fractional a, Eq a) => [Operation a] -> TensorTree a -> TensorTree a
buildTree (x : y) tree = buildTree y $ appendTree x tree
buildTree _ tree = tree
```

We introduced one new type, `TensorTree`: a recursive data structure that is either `Empty` or a `Cons` node holding a tensor, the `Operator` that produced it, and a left and right subtree. Leaf tensors carry the `NA` operator and two `Empty` children.

The function `buildTree` takes a list of `Operation`s and a current `TensorTree`, and returns a new `TensorTree`. The function itself is pretty boring and kind of gross, so further exploration of this monstrosity doesn't feel necessary.
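For the curious, one way the tree building could be made less gross is to index the operations by the id of the tensor they produce and expand the tree from the root. The sketch below is a hypothetical alternative, not the code used in the rest of this post: the types are simplified stand-ins fixed to `Double`, and `indexOps`, `expand`, and `buildTree'` are made-up names. It assumes the tape was built by our operations, so the newest operation is at the head of the list and operands always have smaller ids than results.

```haskell
import qualified Data.Map as Map

-- Simplified stand-ins for the post's types, fixed to Double for brevity.
data Tensor0D = Tensor0D {tid :: Int, value :: Double} deriving (Show, Eq)

data Operator = MP | DV | AD | NA deriving (Show, Eq)

data Operation = Operation Operator Tensor0D Tensor0D Tensor0D deriving (Show, Eq)

data TensorTree = Empty | Cons Tensor0D Operator TensorTree TensorTree deriving (Show, Eq)

-- Index every operation by the id of the tensor it produced.
indexOps :: [Operation] -> Map.Map Int Operation
indexOps ops = Map.fromList [(tid out, o) | o@(Operation _ out _ _) <- ops]

-- Expand a tensor into a subtree: an operation node if something on the tape
-- produced it, otherwise a leaf tagged NA.
expand :: Map.Map Int Operation -> Tensor0D -> TensorTree
expand table t = case Map.lookup (tid t) table of
  Just (Operation op _ l r) -> Cons t op (expand table l) (expand table r)
  Nothing -> Cons t NA Empty Empty

-- Hypothetical replacement for buildTree: the newest operation becomes the root.
buildTree' :: [Operation] -> TensorTree
buildTree' [] = Empty
buildTree' ops@(Operation _ out _ _ : _) = expand (indexOps ops) out
```

One behavioural difference worth knowing about: this version expands a tensor at every place it is used, whereas `appendTree` only expands the first match it finds.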

```
-- Gradients sent to the left and right operands of one operation, given the
-- gradient arriving from the parent.
applyGrads :: (Fractional a) => Operator -> a -> a -> a -> (a, a)
applyGrads op parentGrads leftValue rightValue
  | op == MP = (parentGrads * rightValue, parentGrads * leftValue)
  | op == DV = (parentGrads * (1 / rightValue), parentGrads * (-1) * (leftValue / (rightValue * rightValue)))
  | op == AD = (parentGrads, parentGrads)
  | otherwise = error "applyGrads: NA marks a leaf and has no operands"

backTree :: (Fractional a, Eq a) => TensorTree a -> Map.Map Int a -> Map.Map Int a
backTree (Cons Tensor0D {tid = nodeId} op leftTree@(Cons Tensor0D {tid = leftId, value = leftValue} _ _ _) rightTree@(Cons Tensor0D {tid = rightId, value = rightValue} _ _ _)) grads =
  let pGrads = Map.findWithDefault 1 nodeId grads -- the root has no entry, so it is seeded with 1
      (leftGrads, rightGrads) = applyGrads op pGrads leftValue rightValue
      -- Each child receives only its own gradient. Passing the whole map down
      -- both branches would let unionsWith (+) count ancestor entries twice.
      leftResult = backTree leftTree (Map.singleton leftId leftGrads)
      rightResult = backTree rightTree (Map.singleton rightId rightGrads)
   in Map.unionsWith (+) [grads, leftResult, rightResult]
-- Leaves have nothing to propagate.
backTree _ grads = grads
```

The function `backTree` takes a `TensorTree` and a map, and returns an updated map with the gradients of the tensors in the `TensorTree`.

We have also created a helper function, `applyGrads`, which takes an `Operator`, the gradient flowing in from the parent, and the left and right operand values, and returns the gradients for the left and right operands of that operation.
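The three cases in `applyGrads` are just the local partial derivatives of each operation, scaled by the parent's gradient. Writing $l$ and $r$ for the left and right values and $g$ for the gradient arriving from the parent (symbols introduced here only for illustration):

```latex
\begin{aligned}
z = l \cdot r &: \quad g \cdot \frac{\partial z}{\partial l} = g \cdot r, &
g \cdot \frac{\partial z}{\partial r} &= g \cdot l, \\
z = \frac{l}{r} &: \quad g \cdot \frac{\partial z}{\partial l} = g \cdot \frac{1}{r}, &
g \cdot \frac{\partial z}{\partial r} &= -g \cdot \frac{l}{r^2}, \\
z = l + r &: \quad g \cdot \frac{\partial z}{\partial l} = g, &
g \cdot \frac{\partial z}{\partial r} &= g.
\end{aligned}
```

These line up with the `MP`, `DV`, and `AD` guards respectively.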

Tying It All Together
=====================

Let's augment our `doComputations` function to include more computations and return the gradients and final tensor. We will also write a helper function, aptly called `backward`, to facilitate building the `TensorTree` and going backwards through it.

```
backward :: (Fractional a, Eq a) => State (Tape a) (Map.Map Int a)
backward = do
  tape <- get
  let ops = operations tape
  let tree = buildTree ops Empty
  return $ backTree tree Map.empty

doComputations :: (Fractional a, Eq a) => State (Tape a) (Tensor0D a, Map.Map Int a)
doComputations = do
  t0 <- createTensor 1.5
  t1 <- createTensor 2.5
  t2 <- createTensor 3.5
  t3 <- createTensor 4.5
  t4 <- tMul t0 t1
  t5 <- tDiv t4 t2
  t6 <- tAdd t5 t3
  grads <- backward
  return (t6, grads)
```

To execute this code, we include this very simple `main` function:

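```
-- An empty tape: no recorded operations, and tensor ids start at 0.
newTape :: (Fractional a, Eq a) => Tape a
newTape = Tape {operations = [], nextTensorId = 0}

main :: IO ()
main = do
  let (tensor, grads) = evalState doComputations newTape
  print tensor
  print grads
```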

Running the final program produces:

```
Tensor0D {tid = 6, value = 5.571428571428571}
fromList [(0,0.7142857142857142),(1,0.42857142857142855),(2,-0.30612244897959184),(3,1.0),(4,0.2857142857142857),(5,1.0)]
```

Comparing these values against https://www.derivative-calculator.net/ confirms they are correct!
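If you would rather check the numbers locally, a central finite difference is a quick sanity test. The sketch below is independent of the program above: `f`, `numericGrad`, and `numericGrads` are hypothetical names, and the step size is an arbitrary choice.

```haskell
-- The final computation from the post as a plain function: t0 * t1 / t2 + t3.
f :: [Double] -> Double
f [t0, t1, t2, t3] = t0 * t1 / t2 + t3
f _ = error "f expects exactly four inputs"

-- Central finite difference of g with respect to input i.
numericGrad :: ([Double] -> Double) -> [Double] -> Int -> Double
numericGrad g xs i = (g (bump h) - g (bump (negate h))) / (2 * h)
  where
    h = 1e-6 -- arbitrary small step
    bump d = [if j == i then x + d else x | (j, x) <- zip [0 ..] xs]

-- Approximate gradients for t0..t3 at the inputs used in the post.
numericGrads :: [Double]
numericGrads = map (numericGrad f [1.5, 2.5, 3.5, 4.5]) [0 .. 3]
```

The four entries of `numericGrads` should agree with the gradients for tensors 0 through 3 above, up to floating-point noise.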

Thank you for reading!


© 2024 Silas Marvin. No tracking, no cookies, just plain HTML and CSS.