Folds are a common stumbling point for people learning about the functional paradigm.
I remember being pretty confused about the difference between a left fold and a right fold, and how either of them differs from a reduce.
I’m going to try and explain them in a way that’s easy for non-functional programmers to get.
If you don’t get it, I’ve screwed up – feel free to let me know!
To keep things simple, the example will just use singly linked lists in Java and Haskell.
Let’s get warmed up and write out our List class.
We’re trying on our functional programming hats, so we’ll keep it fairly simple:
class List<T> {
    public final T head;
    public final List<T> tail;
    private List(T head, List<T> tail) {
        this.head = head;
        this.tail = tail;
    }
    public static <T> List<T> Cons(T item, List<T> rest) {
        return new List<T>(item, rest);
    }
    public static <T> List<T> Nil() {
        return new List<>(null, null);
    }
    // The later examples call isNil(); only the empty list has a null tail.
    public boolean isNil() {
        return tail == null;
    }
}
We’ve got two constructors: Cons for putting an item on the front of a list, and Nil for an empty list. We can build a simple list like:
Cons(1, Cons(2, Cons(3, Nil())))
And the same in Haskell:
data [t] = t : [t]
         | []
If you’re unfamiliar with Haskell, this defines two constructors: the infix constructor : and the empty list constructor [].
t is a type parameter.
We can make 1 : 2 : 3 : [] like this, though the language gives us syntax sugar to write [1, 2, 3] instead.
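If you want to poke at the Java version, here’s a minimal runnable sketch. The wrapper class ListSketch and the show helper are names I’m making up for illustration, and isNil() assumes that the empty list is the only cell with a null tail.

```java
final class ListSketch {
    // Same shape as the List class above, nested so the sketch is one file.
    static final class List<T> {
        final T head;
        final List<T> tail;

        private List(T head, List<T> tail) {
            this.head = head;
            this.tail = tail;
        }

        static <T> List<T> Cons(T item, List<T> rest) {
            return new List<>(item, rest);
        }

        static <T> List<T> Nil() {
            return new List<>(null, null);
        }

        // Assumption: only the empty list has a null tail.
        boolean isNil() {
            return tail == null;
        }
    }

    // Hypothetical helper: render a list the way Haskell writes it without sugar.
    static <T> String show(List<T> list) {
        if (list.isNil()) {
            return "[]";
        }
        return list.head + " : " + show(list.tail);
    }

    public static void main(String[] args) {
        List<Integer> xs = List.Cons(1, List.Cons(2, List.Cons(3, List.Nil())));
        System.out.println(show(xs)); // prints 1 : 2 : 3 : []
    }
}
```

Walking the list means asking "is this Nil?" at every cell, which is the same two-case split that every function in the rest of this post uses.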
To continue warming up, let’s write out map and filter for lists.
This will help with our intuition on folds later on!
map is a function or method usually defined on lists and arrays, though you can define it for all kinds of types.
Our intuition for a map is that, for a structure like List<A>, we’ll have a Function<A, B>.
We’ll take all the A values in the list and return a new list with B values instead.
Recursive stuff works best if you can think of the possible cases.
For lists, we’ve got two constructors: Nil and Cons.
If we have a Nil, we return another Nil.
If we have a Cons, we return another Cons after applying the function to the head and mapping over the tail:
static <A, B> List<B> map(Function<A, B> f, List<A> list) {
    if (list.isNil()) {
        return List.Nil();
    }
    return List.Cons(f.apply(list.head), map(f, list.tail));
}
Haskell gives us pattern matching, so instead of an if, we match on the constructors:
map f [] = []
map f (head : tail) = (f head) : (map f tail)
This is pretty terse, and we can make it clearer by naming some of the intermediate steps.
static <A, B> List<B> map(Function<A, B> f, List<A> list) {
    if (list.isNil()) {
        return List.Nil();
    }
    A val = list.head;
    List<A> rest = list.tail;
    B newVal = f.apply(val);
    List<B> newRest = map(f, rest);
    return List.Cons(newVal, newRest);
}

map f [] = []
map f (head : tail) =
  let newVal = f head
      newRest = map f tail
  in newVal : newRest
Cool. Alright, let’s do filter now.
Filter takes a predicate and returns a new list containing only the elements for which the predicate returned true.
Filtering an empty list gives us an empty list.
When we filter a Cons, we have two cases:
1. Applying the predicate to the head returns true.
2. Applying the predicate to the head returns false.
In both cases, we’ll want to continue filtering the rest of the list. In the first case, we want to keep the item. In the second case, we don’t want to keep it.
static <A> List<A> filter(Function<A, Boolean> predicate, List<A> list) {
    if (list.isNil()) {
        return List.Nil();
    }
    if (predicate.apply(list.head)) {
        return List.Cons(list.head, filter(predicate, list.tail));
    }
    return filter(predicate, list.tail);
}

filter pred [] = []
filter pred (head : tail) =
  if pred head
    then head : (filter pred tail)
    else filter pred tail
Finally, let’s write sum to sum all the numbers in a list.
The sum of an empty list is 0, and the sum of a non-empty list is the value of the head plus the sum of the tail.
Easy enough, let’s write it out:
static Integer sum(List<Integer> list) {
    if (list.isNil()) {
        return 0;
    }
    return list.head + sum(list.tail);
}
sum [] = 0
sum (x:xs) = x + sum xs
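To see the three functions side by side, here’s a self-contained sketch that runs them on the list 1, 2, 3. It assumes an isNil() check on the list (true only for Nil), and RecursionSketch and show are hypothetical names used just for this demo.

```java
import java.util.function.Function;

final class RecursionSketch {
    static final class List<T> {
        final T head;
        final List<T> tail;

        private List(T head, List<T> tail) {
            this.head = head;
            this.tail = tail;
        }

        static <T> List<T> Cons(T item, List<T> rest) {
            return new List<>(item, rest);
        }

        static <T> List<T> Nil() {
            return new List<>(null, null);
        }

        // Assumption: only the empty list has a null tail.
        boolean isNil() {
            return tail == null;
        }
    }

    static <A, B> List<B> map(Function<A, B> f, List<A> list) {
        if (list.isNil()) {
            return List.Nil();
        }
        return List.Cons(f.apply(list.head), map(f, list.tail));
    }

    static <A> List<A> filter(Function<A, Boolean> predicate, List<A> list) {
        if (list.isNil()) {
            return List.Nil();
        }
        if (predicate.apply(list.head)) {
            return List.Cons(list.head, filter(predicate, list.tail));
        }
        return filter(predicate, list.tail);
    }

    static Integer sum(List<Integer> list) {
        if (list.isNil()) {
            return 0;
        }
        return list.head + sum(list.tail);
    }

    // Hypothetical helper for printing results.
    static <T> String show(List<T> list) {
        return list.isNil() ? "[]" : list.head + " : " + show(list.tail);
    }

    public static void main(String[] args) {
        List<Integer> xs = List.Cons(1, List.Cons(2, List.Cons(3, List.Nil())));
        List<Integer> bumped = map(x -> x + 1, xs);
        List<Integer> odds = filter(x -> x % 2 == 1, xs);
        System.out.println(show(bumped)); // 2 : 3 : 4 : []
        System.out.println(show(odds));   // 1 : 3 : []
        System.out.println(sum(xs));      // 6
    }
}
```

Notice that every function starts with the same isNil() check and ends by recursing on list.tail; only the middle bit changes.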
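Map, filter, and sum all have some things in common:
1. They return a fixed value for the empty list.
2. For a non-empty list, they recurse on the tail.
3. They combine the head with the result of that recursion.
Alright, with that, we’re ready to conquer folds.
A fold has three arguments:
1. A combining function, k.
2. A starting value, z, which is also the result for the empty list.
3. The list to fold over.
foldr is a right fold, foldl is a left fold, and they’re defined like this: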
foldr k z [] = z
foldr k z (x:xs) = k x (foldr k z xs)
foldl k z [] = z
foldl k z (x:xs) = foldl k (k z x) xs
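The variable k is our combining function, and z is the starting value (the accumulator).
foldr combines the head of the list with the result of folding the tail.
foldl is tail recursive: it combines the accumulator z with the current item on the list and passes the result along as the new accumulator.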
“but matt, this doesn’t help me understand”
Right. We’re getting there.
Haskell lets us use functions infix if we surround them with backticks, so we can also write foldr like this:
foldr k z [] = z
foldr k z (x:xs) = x `k` (foldr k z xs)
The infix isn’t superfluous.
We can get a nice intuition for how foldr works on a list with it.
Let’s see how Haskell would write out [1..3] without any sugar:
1 : 2 : 3 : []
Take every : and replace it with k, and take the [] and replace it with z:
1 `k` 2 `k` 3 `k` z
Now we can substitute + for k and 0 for z and see that this is sum:
1 + 2 + 3 + 0
We can get map by replacing k with (\x acc -> f x : acc), and we can get filter by replacing k with (\x acc -> if p x then x : acc else acc).
In both cases z is the empty list [].
Let’s walk through an example, step by step:
-- initial function call
foldr (+) 0 (1 : 2 : 3 : [])
-- recurse:
-- foldr k z (x:xs) = x `k` foldr k z xs
= 1 + (foldr (+) 0 (2 : 3 : []))
-- recurse:
= 1 + (2 + (foldr (+) 0 (3 : [])))
-- recurse:
= 1 + (2 + (3 + (foldr (+) 0 [])))
-- base case:
-- foldr k z [] = z
= 1 + (2 + (3 + 0))
Yup! What about foldl? There must be some magic there, right?
Nope, though it might look a little strange:
-- initial function call
foldl (+) 0 (1 : 2 : 3 : [])
-- recurse:
-- foldl k z (x:xs) = foldl k (k z x) xs
foldl (+) (0 + 1) (2 : 3 : [])
-- recurse:
foldl (+) ((0 + 1) + 2) (3 : [])
-- recurse:
foldl (+) (((0 + 1) + 2) + 3) []
-- base case:
-- foldl k z [] = z
(((0 + 1) + 2) + 3)
Interesting! This has nearly the same shape as what foldr ended up looking like, but the parentheses are nested differently.
With foldr, we directly replace [] with our z value.
foldl prepends the z value to the list and just drops the [] entirely, so our foldr “replace the : with k” trick needs to be adjusted slightly.
With addition, it doesn’t really matter, since you can swap arguments and parentheses around. Let’s try it with subtraction and see the difference:
foldr (-) 0 [1..3]
-- desugar the list:
= 1 : 2 : 3 : []
-- replace : and [] with right associating k and z
= 1 `k` (2 `k` (3 `k` z))
-- replace with args:
= 1 - (2 - (3 - 0))
-- evaluate:
= 1 - (2 - 3)
= 1 - (-1)
= 2
foldl (-) 0 [1..3]
= 1 : 2 : 3 : []
-- prepend with z
= z : 1 : 2 : 3 : []
-- replace : with left associating k and drop the []
= ((z `k` 1) `k` 2) `k` 3
-- replace z with 0 and k with -
= ((0 - 1) - 2) - 3
-- evaluate:
= (-1 - 2) - 3
= -3 - 3
= -6
If you’re curious about the evaluation of these functions, you can use their cousins scanr and scanl.
Instead of returning a single end result, they return a list of all the intermediate results.
-- In GHCi,
λ> scanr (+) 0 [1..3]
[6,5,3,0]
λ> scanl (+) 0 [1..3]
[0,1,3,6]
λ> scanr (-) 0 [1..3]
[2,-1,3,0]
λ> scanl (-) 0 [1..3]
[0,-1,-3,-6]
Alright, enough Haskell for now, let’s implement these two in Java:
static <A, B> B foldRight(BiFunction<A, B, B> k, B z, List<A> list) {
    if (list.isNil()) {
        return z;
    }
    return k.apply(list.head, foldRight(k, z, list.tail));
}

static <A, B> B foldLeft(BiFunction<B, A, B> k, B z, List<A> list) {
    if (list.isNil()) {
        return z;
    }
    return foldLeft(k, k.apply(z, list.head), list.tail);
}
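To check that the Java folds behave like their Haskell cousins, here’s a small runnable sketch that repeats the subtraction example. FoldSketch is a name I’ve invented for the wrapper class, and isNil() is assumed to be true only for Nil. Folding into a String is just a trick to make the nesting of the calls visible.

```java
import java.util.function.BiFunction;

final class FoldSketch {
    static final class List<T> {
        final T head;
        final List<T> tail;

        private List(T head, List<T> tail) {
            this.head = head;
            this.tail = tail;
        }

        static <T> List<T> Cons(T item, List<T> rest) {
            return new List<>(item, rest);
        }

        static <T> List<T> Nil() {
            return new List<>(null, null);
        }

        // Assumption: only the empty list has a null tail.
        boolean isNil() {
            return tail == null;
        }
    }

    static <A, B> B foldRight(BiFunction<A, B, B> k, B z, List<A> list) {
        if (list.isNil()) {
            return z;
        }
        return k.apply(list.head, foldRight(k, z, list.tail));
    }

    static <A, B> B foldLeft(BiFunction<B, A, B> k, B z, List<A> list) {
        if (list.isNil()) {
            return z;
        }
        return foldLeft(k, k.apply(z, list.head), list.tail);
    }

    public static void main(String[] args) {
        List<Integer> xs = List.Cons(1, List.Cons(2, List.Cons(3, List.Nil())));

        Integer right = foldRight((x, acc) -> x - acc, 0, xs);
        Integer left = foldLeft((acc, x) -> acc - x, 0, xs);
        System.out.println(right); // 2
        System.out.println(left);  // -6

        // Folding into a String shows how the parentheses nest.
        String rightShape = foldRight((x, acc) -> "(" + x + " - " + acc + ")", "0", xs);
        String leftShape = foldLeft((acc, x) -> "(" + acc + " - " + x + ")", "0", xs);
        System.out.println(rightShape); // (1 - (2 - (3 - 0)))
        System.out.println(leftShape);  // (((0 - 1) - 2) - 3)
    }
}
```

Note the argument order of the two lambdas: foldRight hands you (item, accumulator), while foldLeft hands you (accumulator, item), matching the BiFunction types in the signatures.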
And now, let’s rewrite map and filter in terms of these:
static <A, B> List<B> mapR(Function<A, B> f, List<A> list) {
    return foldRight(
        (x, acc) -> List.Cons(f.apply(x), acc),
        List.Nil(),
        list
    );
}

static <A> List<A> filterR(Function<A, Boolean> p, List<A> list) {
    return foldRight(
        (x, acc) -> p.apply(x) ? List.Cons(x, acc) : acc,
        List.Nil(),
        list
    );
}
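Here’s a runnable sketch of the fold-based versions, so you can confirm they keep the elements in their original order. MapFilterFoldSketch and show are hypothetical names, isNil() is assumed to be true only for Nil, and I’ve written the type arguments on Nil out explicitly (List.<B>Nil()) purely to keep type inference out of the picture.

```java
import java.util.function.BiFunction;
import java.util.function.Function;

final class MapFilterFoldSketch {
    static final class List<T> {
        final T head;
        final List<T> tail;

        private List(T head, List<T> tail) {
            this.head = head;
            this.tail = tail;
        }

        static <T> List<T> Cons(T item, List<T> rest) {
            return new List<>(item, rest);
        }

        static <T> List<T> Nil() {
            return new List<>(null, null);
        }

        // Assumption: only the empty list has a null tail.
        boolean isNil() {
            return tail == null;
        }
    }

    static <A, B> B foldRight(BiFunction<A, B, B> k, B z, List<A> list) {
        if (list.isNil()) {
            return z;
        }
        return k.apply(list.head, foldRight(k, z, list.tail));
    }

    // k conses the transformed head onto the already-mapped tail; z is Nil.
    static <A, B> List<B> mapR(Function<A, B> f, List<A> list) {
        return foldRight((x, acc) -> List.Cons(f.apply(x), acc), List.<B>Nil(), list);
    }

    // k keeps the head only when the predicate accepts it; z is Nil.
    static <A> List<A> filterR(Function<A, Boolean> p, List<A> list) {
        return foldRight((x, acc) -> p.apply(x) ? List.Cons(x, acc) : acc, List.<A>Nil(), list);
    }

    // Hypothetical helper for printing results.
    static <T> String show(List<T> list) {
        return list.isNil() ? "[]" : list.head + " : " + show(list.tail);
    }

    public static void main(String[] args) {
        List<Integer> xs = List.Cons(1, List.Cons(2, List.Cons(3, List.Cons(4, List.Nil()))));
        List<Integer> tens = mapR(x -> x * 10, xs);
        List<Integer> evens = filterR(x -> x % 2 == 0, xs);
        System.out.println(show(tens));  // 10 : 20 : 30 : 40 : []
        System.out.println(show(evens)); // 2 : 4 : []
    }
}
```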
So, at first glance, you might think that foldl is superior.
Its recursive call is in tail position, so a clever enough compiler could easily optimize it into a loop (unfortunately, Java doesn’t do tail call elimination as of now).
Without that optimization, though, both folds have to do about the same amount of work, and they seem to be equivalent.
There are two reasons why foldRight is useful.
We can see the first by implementing map in terms of foldLeft:
static <A, B> List<B> mapL(Function<A, B> f, List<A> list) {
    return foldLeft(
        (acc, x) -> acc.append(f.apply(x)),
        List.Nil(),
        list
    );
}
Alas! This is quadratic in the size of the input list!
Appending to the end of a singly linked list takes $O(n)$ time, and mapL does it once for every element.
We can verify this by looking at the simplest definition of append:
public List<T> append(T elem) {
    // Rebuild every existing cell in front of a new one-element list.
    return foldRight(List::Cons, Cons(elem, Nil()), this);
}
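If you’d rather measure than take my word for it, here’s a rough sketch that counts how many Cons cells mapL allocates. The consCount counter, the range helper, and the AppendCostSketch class are all instrumentation I’ve made up for this demo; isNil() is assumed to be true only for Nil.

```java
import java.util.function.BiFunction;
import java.util.function.Function;

final class AppendCostSketch {
    // Hypothetical instrumentation: bumped every time Cons is called.
    static int consCount = 0;

    static final class List<T> {
        final T head;
        final List<T> tail;

        private List(T head, List<T> tail) {
            this.head = head;
            this.tail = tail;
        }

        static <T> List<T> Cons(T item, List<T> rest) {
            consCount++;
            return new List<>(item, rest);
        }

        static <T> List<T> Nil() {
            return new List<>(null, null);
        }

        // Assumption: only the empty list has a null tail.
        boolean isNil() {
            return tail == null;
        }

        // Copies every existing cell, then adds one more at the end.
        List<T> append(T elem) {
            return foldRight((x, acc) -> Cons(x, acc), Cons(elem, List.<T>Nil()), this);
        }
    }

    static <A, B> B foldRight(BiFunction<A, B, B> k, B z, List<A> list) {
        if (list.isNil()) {
            return z;
        }
        return k.apply(list.head, foldRight(k, z, list.tail));
    }

    static <A, B> B foldLeft(BiFunction<B, A, B> k, B z, List<A> list) {
        if (list.isNil()) {
            return z;
        }
        return foldLeft(k, k.apply(z, list.head), list.tail);
    }

    static <A, B> List<B> mapL(Function<A, B> f, List<A> list) {
        return foldLeft((acc, x) -> acc.append(f.apply(x)), List.<B>Nil(), list);
    }

    // Hypothetical helper: the list 1, 2, ..., n.
    static List<Integer> range(int n) {
        List<Integer> out = List.Nil();
        for (int i = n; i >= 1; i--) {
            out = List.Cons(i, out);
        }
        return out;
    }

    // Number of Cons calls made by mapL on a list of length n.
    static int consCellsForMapL(int n) {
        List<Integer> input = range(n);
        consCount = 0; // ignore the cells used to build the input
        mapL(x -> x + 1, input);
        return consCount;
    }

    public static void main(String[] args) {
        System.out.println(consCellsForMapL(3));   // 6
        System.out.println(consCellsForMapL(10));  // 55
        System.out.println(consCellsForMapL(100)); // 5050
    }
}
```

Appending to an accumulator that already holds k items allocates k + 1 cells, so a list of n items costs 1 + 2 + ... + n = n(n + 1) / 2 cells in total.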
So foldRight can be useful when constructing new lists.
In fact, we can easily write concat using foldRight:
public static <T> List<T> concat(List<T> first, List<T> second) {
    return foldRight(List::Cons, second, first);
}
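A quick sketch of concat in action. One thing worth noticing is that second is used as the z value, so it is never copied: the result simply ends in the very same cells. ConcatSketch and show are hypothetical names, and I’ve spelled the combining function as a lambda rather than List::Cons, which should be equivalent.

```java
import java.util.function.BiFunction;

final class ConcatSketch {
    static final class List<T> {
        final T head;
        final List<T> tail;

        private List(T head, List<T> tail) {
            this.head = head;
            this.tail = tail;
        }

        static <T> List<T> Cons(T item, List<T> rest) {
            return new List<>(item, rest);
        }

        static <T> List<T> Nil() {
            return new List<>(null, null);
        }

        // Assumption: only the empty list has a null tail.
        boolean isNil() {
            return tail == null;
        }
    }

    static <A, B> B foldRight(BiFunction<A, B, B> k, B z, List<A> list) {
        if (list.isNil()) {
            return z;
        }
        return k.apply(list.head, foldRight(k, z, list.tail));
    }

    // Replace the Nil at the end of first with the whole of second.
    static <T> List<T> concat(List<T> first, List<T> second) {
        return foldRight((x, acc) -> List.Cons(x, acc), second, first);
    }

    // Hypothetical helper for printing results.
    static <T> String show(List<T> list) {
        return list.isNil() ? "[]" : list.head + " : " + show(list.tail);
    }

    public static void main(String[] args) {
        List<Integer> first = List.Cons(1, List.Cons(2, List.Nil()));
        List<Integer> second = List.Cons(3, List.Cons(4, List.Nil()));
        List<Integer> joined = concat(first, second);
        System.out.println(show(joined));                // 1 : 2 : 3 : 4 : []
        System.out.println(joined.tail.tail == second);  // true: second is shared, not copied
    }
}
```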
The ease of writing functions like this isn’t a coincidence.
foldRight is theoretically entwined with singly linked lists.
The ordinary definitions of data structures involve “how do I construct this,” but you can also define data structures in terms of “how do I deconstruct this.”
foldRight is that definition.
This is referred to as the Church encoding of a list.
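To make that a bit more concrete, here’s one way you might sketch a list that has no head or tail fields at all and is nothing but its own right fold. The ChurchList interface and everything in ChurchListSketch are names I’ve invented; treat this as an illustration of the idea rather than a canonical encoding.

```java
import java.util.function.BiFunction;
import java.util.function.Function;

final class ChurchListSketch {
    // A list is anything that knows how to right-fold itself.
    interface ChurchList<T> {
        <R> R foldRight(BiFunction<T, R, R> k, R z);
    }

    // The empty list ignores k and returns z.
    static <T> ChurchList<T> nil() {
        return new ChurchList<T>() {
            @Override
            public <R> R foldRight(BiFunction<T, R, R> k, R z) {
                return z;
            }
        };
    }

    // A non-empty list combines its head with the fold of its tail.
    static <T> ChurchList<T> cons(T head, ChurchList<T> tail) {
        return new ChurchList<T>() {
            @Override
            public <R> R foldRight(BiFunction<T, R, R> k, R z) {
                return k.apply(head, tail.foldRight(k, z));
            }
        };
    }

    static int sum(ChurchList<Integer> xs) {
        return xs.<Integer>foldRight((x, acc) -> x + acc, 0);
    }

    static <T> int length(ChurchList<T> xs) {
        return xs.<Integer>foldRight((x, acc) -> acc + 1, 0);
    }

    static <A, B> ChurchList<B> map(Function<A, B> f, ChurchList<A> xs) {
        return xs.<ChurchList<B>>foldRight(
            (x, acc) -> cons(f.apply(x), acc),
            ChurchListSketch.<B>nil()
        );
    }

    public static void main(String[] args) {
        ChurchList<Integer> xs = cons(1, cons(2, cons(3, nil())));
        System.out.println(sum(xs));                  // 6
        System.out.println(length(xs));               // 3
        System.out.println(sum(map(x -> x * 2, xs))); // 12
    }
}
```

The two factory methods line up exactly with the two equations for foldr: nil is the z case, and cons is the k x (foldr k z xs) case.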
Another difference between foldl and foldr is how they work with laziness.
Laziness can be tricky to understand at first, since it defies all of our intuitions about how to evaluate code.
Consider the implementation of map using foldr:
map f xs = foldr (\x acc -> f x : acc) [] xs
One of the map laws is that composing two maps is the same as a single map with the two functions composed.
In fancy math,
$$\mathrm{map}\ f \circ \mathrm{map}\ g = \mathrm{map}\ (f \circ g)$$
In Haskell,
map f . map g = map (f . g)
In Java,
map(f, map(g, list)) = map(compose(f, g), list)
If we can fuse the two maps like this, then we only walk the list once and never build the intermediate list, which can be dramatically more efficient.
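Here’s a small sanity check of the law itself in Java. To keep it short I’m using java.util.List and a plain eager map written with a loop, not the linked list from earlier, and MapLawSketch is a made-up name. Function.compose is the standard library method: f.compose(g) applies g first and then f.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

final class MapLawSketch {
    // A plain eager map: walks the input once and builds a new list.
    static <A, B> List<B> map(Function<A, B> f, List<A> xs) {
        List<B> out = new ArrayList<>();
        for (A x : xs) {
            out.add(f.apply(x));
        }
        return out;
    }

    public static void main(String[] args) {
        Function<Integer, Integer> timesTwo = x -> x * 2;
        Function<Integer, Integer> plusOne = x -> x + 1;
        List<Integer> xs = Arrays.asList(1, 2, 3);

        // Two traversals, with an intermediate list in between.
        List<Integer> twoPasses = map(plusOne, map(timesTwo, xs));
        // One traversal with the composed function.
        List<Integer> onePass = map(plusOne.compose(timesTwo), xs);

        System.out.println(twoPasses);                 // [3, 5, 7]
        System.out.println(onePass);                   // [3, 5, 7]
        System.out.println(twoPasses.equals(onePass)); // true
    }
}
```

Both sides give the same answer, but with this eager map the left-hand side allocates a throwaway list of doubled values along the way.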
Does foldr respect this law with respect to performance?
Let’s watch map (+1) . map (*2) work, using our foldr definitions.
print will demand our values and force their evaluation.
printEach [] = print "Done"
printEach (x:xs) = do
  print x
  printEach xs
printEach (map (+1) (map (*2) [1, 2, 3]))
print is what actually forces evaluation here, so nothing gets evaluated until print forces it, and only as much as is required for print.
First, printEach matches on map (+1) (map (*2) [1,2,3]).
It needs to know if that evaluates to [] or (x:xs).
This causes map (+1) to match on map (*2) [1,2,3], which in turn causes map (*2) to match on [1,2,3].
printEach (
map (+1) (
map (*2) [1, 2, 3]
)
)
-- substitute `foldr` definition for `map`:
printEach (
map (+1) (
foldr (\x acc -> x * 2 : acc) [] [1,2,3]
)
)
-- foldr's (x:xs) case matches, expression becomes:
printEach (
map (+1) (
1 * 2 : foldr (\x acc -> x * 2 : acc) [] [2, 3]
)
)
-- first `map` can pattern match now, so we expand to the
-- foldr definition:
printEach (
foldr (\x acc -> x + 1 : acc) [] (
1 * 2 : foldr (\x acc -> x * 2 : acc) [] [2, 3]
)
)
-- foldr matches on `(x:xs)`, so we evaluate that bit:
printEach (
1 * 2 + 1 : foldr (\x acc -> x + 1 : acc) [] (
foldr (\x acc -> x * 2 : acc) [] [2, 3]
)
)
-- printEach matches on (x:xs), so we can go to `print x`:
printEach (1 * 2 + 1 : xs) = do
print (1 * 2 + 1)
printEach xs
-- `print` does the match and prints it, and then recurses:
printEach (
foldr (\x acc -> x + 1 : acc) [] (
foldr (\x acc -> x * 2 : acc) [] [2, 3]
)
)
This process repeats until eventually the inner foldr yields an empty list, then the outer foldr yields an empty list, and then printEach prints "Done" to finish things off.
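Java isn’t lazy the way Haskell is, but java.util.stream pipelines also pull one element at a time through every stage, which gives a rough analogy for the trace above. To be clear, this is only an analogy and not how Haskell evaluates foldr; LazyPipelineSketch and the log messages are invented for the demo.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Stream;

final class LazyPipelineSketch {
    // Records the order in which each stage sees each element.
    static List<String> trace() {
        List<String> log = new ArrayList<>();
        Stream.of(1, 2, 3)
            .map(x -> { log.add("double " + x); return x * 2; })
            .map(x -> { log.add("inc " + x); return x + 1; })
            .forEach(x -> log.add("print " + x));
        log.add("Done");
        return log;
    }

    public static void main(String[] args) {
        // Prints: double 1, inc 2, print 3, double 2, inc 4, print 5,
        //         double 3, inc 6, print 7, Done (one entry per line)
        trace().forEach(System.out::println);
    }
}
```

Each element goes through both maps and gets printed before the next element is touched, so no intermediate list of doubled values is ever built. That is the same interleaving the foldr-based maps gave us above.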