This blog post is the first in a planned series I'm calling "Rust quickies." In my training sessions, we often come up with quick examples to demonstrate some point. Instead of forgetting about them, I want to put together short blog posts focusing on these examples. Hopefully these will be helpful. Enjoy!
Short circuiting a for loop
Let's say I've got an `Iterator` of `u32`s. I want to double each value and print it. Easy enough:
```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) {
    for x in iter.into_iter().map(|x| x * 2) {
        println!("{}", x);
    }
}

fn main() {
    weird_function(1..10);
}
```
And now let's say we hate the number 8, and want to stop when we hit it. That's a simple one-line change:
```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) {
    for x in iter.into_iter().map(|x| x * 2) {
        if x == 8 { return } // added this line
        println!("{}", x);
    }
}
```
Easy, done, end of story. This is why I recommend using `for` loops when possible, even though, coming from a functional programming background, they can feel overly imperative. However, some people out there want to be more functional, so let's explore that.
for_each vs map
Let's forget about the short-circuiting for a moment and go back to the original version of the program, but without using a `for` loop. That's easy enough with the `for_each` method. It takes a closure, which it runs for each value in the `Iterator`. Let's check it out:
```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) {
    iter.into_iter().map(|x| x * 2).for_each(|x| {
        println!("{}", x);
    })
}
```
But why exactly do we need `for_each`? It seems awfully similar to `map`, which also applies a function to every value in an `Iterator`. Trying to make that change, however, demonstrates the problem. With this code:
```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) {
    iter.into_iter().map(|x| x * 2).map(|x| {
        println!("{}", x);
    })
}
```
we get an error message:
```text
error[E0308]: mismatched types
 --> src\main.rs:2:5
  |
2 | /     iter.into_iter().map(|x| x * 2).map(|x| {
3 | |         println!("{}", x);
4 | |     })
  | |______^ expected `()`, found struct `Map`
```
Undaunted, I "fix" this error by sticking a semicolon at the end of that expression. That generates a warning: "unused `Map` that must be used". And sure enough, running this program produces no output.
The problem is that `map` doesn't drain the `Iterator`. Said another way, `map` is lazy. It adapts one `Iterator` into a new `Iterator`, but unless something comes along and drains (or forces) that `Iterator`, no actions will occur. By contrast, `for_each` always drains an `Iterator`.
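To make the laziness concrete, here is a small sketch (not from the original post) that counts how many times the `map` closure actually runs. The helper name `calls_before_and_after_draining` is hypothetical, and `Cell` is only used so the counter can be read while the adapter still borrows it.

```rust
use std::cell::Cell;

/// Hypothetical helper: returns how many times the `map` closure ran
/// (before draining, after draining).
fn calls_before_and_after_draining() -> (u32, u32) {
    // `Cell` lets the closure bump the counter through a shared borrow,
    // so we can still read it while the adapter is alive.
    let calls = Cell::new(0u32);

    // Building the adapter runs nothing: `map` is lazy.
    let doubled = (1..10u32).map(|x| {
        calls.set(calls.get() + 1);
        x * 2
    });
    let before = calls.get();

    // `for_each` drains the iterator, so the closure runs once per element.
    doubled.for_each(|x| println!("{}", x));
    let after = calls.get();

    (before, after)
}

fn main() {
    let (before, after) = calls_before_and_after_draining();
    println!("closure calls before draining: {}, after: {}", before, after);
}
```

The closure runs zero times until `for_each` pulls values through it, and then once for each of the nine inputs.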
One easy trick to force draining an `Iterator` is the `count()` method. This performs some unnecessary work, counting how many values are in the `Iterator`, but it's not that expensive. Another approach is to use `collect`. This one is a little trickier, since `collect` typically needs some type annotations. But thanks to a fun trick in how `FromIterator` is implemented for the unit type, we can collect a stream of `()`s into a single `()` value. Meaning, this code works:
```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) {
    iter.into_iter().map(|x| x * 2).map(|x| {
        println!("{}", x);
    }).collect()
}
```
Note the lack of a semicolon at the end there. What do you think will happen if we add in the semicolon?
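The `count()` trick mentioned earlier isn't shown in code, so here is a minimal sketch of it. It returns the count only so the result can be inspected; if you don't care about it, a trailing `;` after `.count()` discards it and the function can return `()` again.

```rust
/// Same shape as the `collect` version, but drained with `count()`.
fn weird_function(iter: impl IntoIterator<Item = u32>) -> usize {
    iter.into_iter()
        .map(|x| x * 2)
        .map(|x| {
            println!("{}", x);
        })
        .count() // forces every element through both closures
}

fn main() {
    let drained = weird_function(1..10);
    println!("drained {} values", drained);
}
```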
Short circuiting
EDIT: Enough people have asked "why not use `take_while`?" that I thought I'd address it. Yes, `take_while` would work for the "short circuiting" below. It's probably even a good idea. But the main goal of this post is to explore some funny implementation approaches, not to recommend a best practice. And despite some good arguments for `take_while` here, I still stand by my overall recommendation to prefer `for` loops for simplicity.
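For comparison, a `take_while` version might look like the sketch below. The `printed` vector is an addition for illustration only, so the values that made it through can be checked without reading stdout.

```rust
/// Doubles each value and stops pulling values at the first 8.
/// Returns the values it printed so the behavior is easy to check.
fn weird_function(iter: impl IntoIterator<Item = u32>) -> Vec<u32> {
    let mut printed = Vec::new();
    iter.into_iter()
        .map(|x| x * 2)
        .take_while(|&x| x != 8) // ends the iteration at the first 8
        .for_each(|x| {
            println!("{}", x);
            printed.push(x);
        });
    printed
}

fn main() {
    println!("{:?}", weird_function(1..10));
}
```

Because `take_while` stops yielding as soon as its predicate fails, `for_each` never sees 10, 12, and so on.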
With the `for` loop approach, stopping at the first 8 was a trivial, one-line addition. Let's do the same thing here:
```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return }
        println!("{}", x);
    }).collect()
}
```
Take a guess at what the output will be. Ready? OK, here's the real thing:
```text
2
4
6
10
12
14
16
18
```
We skipped 8, but we didn't stop. It's the difference between a `continue` and a `break` inside the `for` loop. Why did this happen?
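To see that difference side by side, here is a sketch of the `for` loop with each keyword. The function names `with_continue` and `with_break` are made up for this illustration, and each returns the values it printed so the outputs can be compared.

```rust
/// `continue` skips the 8 but keeps looping: the same output as the `map` version.
fn with_continue(iter: impl IntoIterator<Item = u32>) -> Vec<u32> {
    let mut printed = Vec::new();
    for x in iter.into_iter().map(|x| x * 2) {
        if x == 8 { continue } // skip this value only
        println!("{}", x);
        printed.push(x);
    }
    printed
}

/// `break` leaves the loop entirely, like the `return` in the first `for` loop version.
fn with_break(iter: impl IntoIterator<Item = u32>) -> Vec<u32> {
    let mut printed = Vec::new();
    for x in iter.into_iter().map(|x| x * 2) {
        if x == 8 { break } // stop at the first 8
        println!("{}", x);
        printed.push(x);
    }
    printed
}

fn main() {
    println!("continue: {:?}", with_continue(1..10));
    println!("break: {:?}", with_break(1..10));
}
```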
It's important to think about the scope of a `return`. It exits the current function, and in this case, the current function isn't `weird_function`, but the closure inside the `map` call. This is what makes short-circuiting inside `map` so difficult.
The exact same comment applies to `for_each`. The only way to stop a `for_each` from continuing is to panic (or abort the program, if you want to get really aggressive).
But with `map`, we have some ingenious ways of working around this and short-circuiting. Let's see them in action.
collect an Option
`map` needs some draining method to drive it. We've been using `collect`. I've previously discussed the intricacies of this method. One cool feature of `collect` is that, for `Option` and `Result`, it provides short-circuit capabilities. We can modify our program to take advantage of that:
```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) -> Option<()> {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(()) // keep going!
    }).collect()
}
```
I added a return type to `weird_function`, though we could also use a turbofish on `collect` and throw away the result. We just need some type annotation to say what we're trying to collect. Since collecting the underlying `()` values doesn't take up extra memory, this is even pretty efficient! The only cost is the extra `Option`. But that extra `Option` is (arguably) useful: it tells us whether or not we short-circuited.
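The turbofish alternative could look like this sketch. The `printed` vector is my own addition so the short circuit can be checked; the `Option<()>` produced by `collect` is simply discarded with `let _ =`.

```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) -> Vec<u32> {
    let mut printed = Vec::new();
    let _ = iter
        .into_iter()
        .map(|x| x * 2)
        .map(|x| {
            if x == 8 { return None } // short circuit!
            println!("{}", x);
            printed.push(x); // recorded only so the behavior can be checked
            Some(()) // keep going!
        })
        .collect::<Option<()>>(); // turbofish instead of a return type; result discarded
    printed
}

fn main() {
    println!("{:?}", weird_function(1..10));
}
```

Nothing after the 8 is printed or recorded, which shows that collecting into an `Option` stops at the first `None`.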
But the story isn't so rosy with other types. Let's say our closure within `map` returns the `x` value. In other words, replace the last line with `Some(x)` instead of `Some(())`. Now we need to somehow collect up those `u32`s. Something like this would work:
```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) -> Option<Vec<u32>> {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).collect()
}
```
But that incurs a heap allocation that we don't want! And using `count()` from before is useless too, since it doesn't short circuit at all: it counts `None`s just like `Some`s.
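A quick sketch confirms that `count()` keeps going past the `None`. As before, `printed` is an illustrative addition that records which values the closure let through.

```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) -> (usize, Vec<u32>) {
    let mut printed = Vec::new();
    let n = iter
        .into_iter()
        .map(|x| x * 2)
        .map(|x| {
            if x == 8 { return None } // only skips this element
            println!("{}", x);
            printed.push(x);
            Some(x)
        })
        .count(); // counts `None`s and `Some`s alike; never stops early
    (n, printed)
}

fn main() {
    println!("{:?}", weird_function(1..10));
}
```

All nine elements are counted, and every value after 8 is still printed.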
But we do have one other trick.
sum
It turns out there's another draining method on `Iterator` that performs short circuiting: `sum`. This program works perfectly well:
```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) -> Option<u32> {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).sum()
}
```
The downside is that it's unnecessarily summing up the values, which could be a real problem if the sum overflows. Still, this mostly works. But is there some way we can stay functional, short circuit, and get no performance overhead? Sure!
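To check that `sum` really stops at the first `None`, and that it does produce a total when no `None` shows up, here is a sketch that also returns the printed values (an illustrative addition). Keep in mind that with large inputs the running `u32` total can overflow, which panics in debug builds.

```rust
fn weird_function(iter: impl IntoIterator<Item = u32>) -> (Option<u32>, Vec<u32>) {
    let mut printed = Vec::new();
    let total = iter
        .into_iter()
        .map(|x| x * 2)
        .map(|x| {
            if x == 8 { return None } // short circuit!
            println!("{}", x);
            printed.push(x);
            Some(x) // keep going!
        })
        .sum::<Option<u32>>(); // stops at the first `None`
    (total, printed)
}

fn main() {
    println!("{:?}", weird_function(1..10)); // hits 8, so the total is None
    println!("{:?}", weird_function(1..4)); // never hits 8: 2 + 4 + 6
}
```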
Short
The final trick here is to create a new helper type for summing up an `Iterator`. But this thing won't really sum. Instead, it will throw away all of the values, and stop as soon as it sees a `None`. Let's see it in practice:
```rust
#[derive(Debug)]
enum Short {
    Stopped,
    Completed,
}

impl<T> std::iter::Sum<Option<T>> for Short {
    fn sum<I: Iterator<Item = Option<T>>>(iter: I) -> Self {
        for x in iter {
            // Stop consuming the iterator at the first `None`.
            if x.is_none() {
                return Short::Stopped;
            }
        }
        Short::Completed
    }
}

fn weird_function(iter: impl IntoIterator<Item = u32>) -> Short {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).sum()
}

fn main() {
    println!("{:?}", weird_function(1..10));
}
```
And voila! We're done!
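To see both variants of `Short`, here is the same approach with `PartialEq` also derived (an addition, so results can be compared) and a second call whose input never produces an 8.

```rust
#[derive(Debug, PartialEq)]
enum Short {
    Stopped,
    Completed,
}

impl<T> std::iter::Sum<Option<T>> for Short {
    fn sum<I: Iterator<Item = Option<T>>>(iter: I) -> Self {
        for x in iter {
            if x.is_none() {
                return Short::Stopped; // first `None` ends the iteration
            }
        }
        Short::Completed
    }
}

fn weird_function(iter: impl IntoIterator<Item = u32>) -> Short {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).sum()
}

fn main() {
    println!("{:?}", weird_function(1..10)); // 1..10 reaches 8
    println!("{:?}", weird_function(1..4)); // 2, 4, 6: no 8
}
```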
Exercise: It's pretty cheeky to use `sum` here; `collect` makes more sense. Replace `sum` with `collect`, and then change the `Sum` implementation into something else. Solution at the end.
Conclusion
That's a lot of work to be functional. Rust has a great story around short circuiting, and it's not just `return`, `break`, and `continue`. It's also the `?` try operator, which forms the basis of error handling in Rust. There are times when you'll want to use `Iterator` adapters, async streaming adapters, and functional-style code. But unless you have a pressing need, my recommendation is to stick to `for` loops.
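As a sketch of how `?` fits into a plain `for` loop, the version below swaps in `checked_mul` (an assumption for illustration; the examples above use `x * 2`, which panics on overflow in debug builds) so that an overflowing input makes the whole function return `None` immediately, while an 8 still just ends the loop.

```rust
/// Hypothetical for-loop variant: `?` exits `weird_function` itself,
/// not a closure, because there is no closure here.
fn weird_function(iter: impl IntoIterator<Item = u32>) -> Option<Vec<u32>> {
    let mut printed = Vec::new();
    for x in iter {
        let doubled = x.checked_mul(2)?; // overflow: return None right away
        if doubled == 8 { break } // stop at 8 but keep what we have
        println!("{}", doubled);
        printed.push(doubled);
    }
    Some(printed)
}

fn main() {
    println!("{:?}", weird_function(1..10));
    println!("{:?}", weird_function(vec![1, u32::MAX, 3]));
}
```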
If you liked this post and would like to see more Rust quickies, let me know.
Solution
```rust
// `FromIterator` is in the prelude as of the 2021 edition; older editions need this import.
use std::iter::FromIterator;

#[derive(Debug)]
enum Short {
    Stopped,
    Completed,
}

// Collect (rather than sum) into `Short`, stopping at the first `None`.
impl<T> FromIterator<Option<T>> for Short {
    fn from_iter<I: IntoIterator<Item = Option<T>>>(iter: I) -> Self {
        for x in iter {
            if x.is_none() {
                return Short::Stopped;
            }
        }
        Short::Completed
    }
}

fn weird_function(iter: impl IntoIterator<Item = u32>) -> Short {
    iter.into_iter().map(|x| x * 2).map(|x| {
        if x == 8 { return None } // short circuit!
        println!("{}", x);
        Some(x) // keep going!
    }).collect()
}

fn main() {
    println!("{:?}", weird_function(1..10));
}
```