Swatinem Blog Resume

Implementing a GraphQL DataLoader the hard way

— 8 min

Last week, I wrote about the DataLoader pattern, which is commonly used in GraphQL servers to avoid the 1 + N problems that can pop up with a naive implementation of GraphQL resolvers.

You can read up on last week’s post here.


Last week, we established that the Dataloader pattern fundamentally takes advantage of async concurrency, and we used yield_now to split the loading of our data into distinct phases.

We stopped short of actually implementing the full DataLoader the simple way, which we will continue with today. Then later on, we will also implement it the hard way, by digging deeper and implementing our own Future, and messing around with Wakers, and even throwing in some unsafe code for good measure.

# The simple way

As a reminder, we ended up with the following trait for our generic BatchLoader which is responsible for actually loading data in batches:

use std::collections::HashMap;
use std::hash::Hash;

pub trait BatchLoader {
    type K: Hash + Eq + Clone;
    type V: Clone;

    async fn load_batch(&mut self, keys: &[Self::K]) -> HashMap<Self::K, Self::V>;
}

An interesting observation here is that we chose to use &mut self. This gives us more flexibility when implementing the BatchLoader, but it also restricts us a bit when using it. In particular, we have to wrap it in an async Mutex, since the lock has to be held across the .await of the load_batch call.

We could have chosen to use &self instead, in which case the choice of whether or not to use an async Mutex would have been pushed down to the implementor of the BatchLoader. This would have had the additional advantage of allowing multiple batches to be resolved at the same time.
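To make that alternative concrete, here is a minimal sketch of a &self flavor of the trait. SharedBatchLoader, Uppercase and poll_once are hypothetical names for illustration only. The implementor picks a plain std::sync::Mutex for its own bookkeeping, which is fine because the lock is never held across an .await, and the caller no longer needs any Mutex at all.

```rust
use std::collections::HashMap;
use std::future::Future;
use std::hash::Hash;
use std::pin::pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical `&self` flavor of the `BatchLoader` trait.
trait SharedBatchLoader {
    type K: Hash + Eq + Clone;
    type V: Clone;

    async fn load_batch(&self, keys: &[Self::K]) -> HashMap<Self::K, Self::V>;
}

// Hypothetical loader that "resolves" keys by uppercasing them. It chooses
// its own synchronization: a sync Mutex around a batch counter.
struct Uppercase {
    batches_loaded: Mutex<usize>,
}

impl SharedBatchLoader for Uppercase {
    type K = String;
    type V = String;

    async fn load_batch(&self, keys: &[String]) -> HashMap<String, String> {
        // The guard is a temporary released at the end of this statement,
        // so it is never held across an `.await`.
        *self.batches_loaded.lock().unwrap() += 1;
        keys.iter().map(|k| (k.clone(), k.to_uppercase())).collect()
    }
}

// Polls a future exactly once, with a waker that does nothing.
fn poll_once<F: Future>(fut: F) -> Poll<F::Output> {
    struct NoopWake;
    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }
    let waker = Waker::from(Arc::new(NoopWake));
    let mut cx = Context::from_waker(&waker);
    pin!(fut).poll(&mut cx)
}
```

Because both calls only borrow the loader immutably, two load_batch futures can be alive at the same time; with the &mut self version, the borrow checker would reject the second call while the first future still exists.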

But we will actually continue to use our async Mutex as the second building block alongside yield_now to implement our DataLoader.

We thus end up with the following implementation for our DataLoader:

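use std::{collections::HashMap, fmt::Debug, sync::Arc};
// Third-party async building blocks (assuming tokio here):
use tokio::{sync::Mutex, task::yield_now};

struct LoaderInner<B: BatchLoader> {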
    load_batch: B,
    requested_keys: Vec<B::K>,
    resolved_values: HashMap<B::K, B::V>,
}

pub struct DataLoader<B: BatchLoader> {
    inner: Arc<Mutex<LoaderInner<B>>>,
}

impl<B: BatchLoader> DataLoader<B>
where
    B::K: Debug,
{
    pub async fn load(&self, key: B::K) -> B::V {
        // First, register the load request:
        {
            let mut inner = self.inner.lock().await;
            println!("starting to resolve related objects for `{key:?}`");
            inner.requested_keys.push(key.clone());
        }

        // Then yield, which will cause all the other `load` calls to also register their requests
        yield_now().await;

        // Next, load all of the `requested_keys` in a batch
        let mut inner = self.inner.lock().await;
        let requested_keys = std::mem::take(&mut inner.requested_keys);
        if !requested_keys.is_empty() {
            let resolved_values = inner.load_batch.load_batch(&requested_keys).await;
            inner.resolved_values.extend(resolved_values);
        }

        // And last but not least, return the resolved value
        inner
            .resolved_values
            .get(&key)
            .expect("value should have been loaded")
            .clone()
    }
}

Running our GraphQL query with this implementation, we get the following log output. It shows that batch loading works: we end up with a total of 3 queries, corresponding to the depth of our query rather than its fan-out at each level.

As a reminder, our query looks like { authors { name books { title summary } } }.

// Loading authors
starting to resolve authors
finished resolving authors

// Loading books
starting to resolve related objects for `"Hermann Hesse"`
starting to resolve related objects for `"Thomas Mann"`
actually resolving Books by `Hermann Hesse`, `Thomas Mann`
finished resolving Books by `Hermann Hesse`, `Thomas Mann`

// Loading summaries
starting to resolve related objects for `"Siddhartha"`
starting to resolve related objects for `"Das Glasperlenspiel"`
starting to resolve related objects for `"Zauberberg"`
actually resolving summaries for `Siddhartha`, `Das Glasperlenspiel`, `Zauberberg`
finished resolving summaries for `Siddhartha`, `Das Glasperlenspiel`, `Zauberberg`

All in all, we ended up with a very simple and straightforward implementation, which is in fact very close to the upstream dataloader crate.

The only real difference is that the dataloader crate has more features, such as avoiding reloading values that were already loaded before and limiting the batch size.

# The hard way

But I also promised to implement the dataloader pattern the hard way, which in this case means no longer relying on third-party async building blocks like yield_now and an async Mutex (though we still use a sync Mutex).

Instead, we will hook directly into the Future lifecycle, which means implementing our own Future and polling the underlying GraphQL resolvers manually.

The idea is to split our implementation into two parts. The leaf async load function is responsible for registering interest in a particular key, and for returning the resolved value once it is ready.

We end up with this implementation (note that LoaderInner now also needs a pending_wakers: Vec<Waker> field):

pub fn load(&self, key: B::K) -> impl Future<Output = B::V> + '_ {
    poll_fn(move |cx| {
        let mut inner = self.inner.lock().unwrap();

        // Check the resolved value, and return it if it was resolved already
        if let Some(v) = inner.resolved_values.get(&key) {
            return Poll::Ready(v.clone());
        }

        // Otherwise, register the requested key, and its `Waker`
        println!("starting to resolve related objects for `{key:?}`");
        inner.requested_keys.push(key.clone());
        inner.pending_wakers.push(cx.waker().clone());
        Poll::Pending
    })
}

The idea here is that on the first call to poll, we will add our key to requested_keys, along with a Waker, and return Poll::Pending immediately.

Returning Poll::Pending is the equivalent of yield_now above, in the sense that it yields control back to the caller, so that the surrounding concurrency primitive can poll the next load future, which then registers its key as well.
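For comparison, a yield_now can itself be sketched with poll_fn. This is a minimal illustration, not any runtime’s actual code: on its first poll it wakes its own task right away and returns Poll::Pending, and on the second poll it completes. CountingWake is a hypothetical helper that only counts wake-ups, so the behavior can be observed.

```rust
use std::future::{poll_fn, Future};
use std::pin::pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Sketch of `yield_now`: pending once, but immediately asks to be polled again.
fn yield_now() -> impl Future<Output = ()> {
    let mut yielded = false;
    poll_fn(move |cx| {
        if yielded {
            return Poll::Ready(());
        }
        yielded = true;
        // Unlike `load`, nobody else will wake us, so we wake ourselves.
        cx.waker().wake_by_ref();
        Poll::Pending
    })
}

// A waker that only counts how often it was woken.
struct CountingWake(AtomicUsize);

impl Wake for CountingWake {
    fn wake(self: Arc<Self>) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
}
```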

But just returning Poll::Pending is not enough: the underlying async runtime, as well as async concurrency primitives like join_all, will simply not poll the future again until it is woken.

This is why we also store the Waker associated with this poll: once we have actually loaded the data, we can wake it, which causes our Future to be polled again.
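The following self-contained sketch shows that handshake in isolation. Slot, wait_for and resolve are hypothetical stand-ins for the loader state, the load future and the batch loading side: the future parks its Waker and stays pending, and only the explicit wake from resolve signals that polling it again is worthwhile.

```rust
use std::future::{poll_fn, Future};
use std::pin::pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical shared state, standing in for `LoaderInner`.
#[derive(Default)]
struct Slot {
    value: Option<String>,
    pending_wakers: Vec<Waker>,
}

// Stand-in for `load`: ready once a value is present, otherwise park the waker.
fn wait_for(slot: &Mutex<Slot>) -> impl Future<Output = String> + '_ {
    poll_fn(move |cx| {
        let mut state = slot.lock().unwrap();
        if let Some(v) = &state.value {
            return Poll::Ready(v.clone());
        }
        state.pending_wakers.push(cx.waker().clone());
        Poll::Pending
    })
}

// Stand-in for the batch loading side: store the value, then wake all waiters.
fn resolve(slot: &Mutex<Slot>, value: &str) {
    let mut state = slot.lock().unwrap();
    state.value = Some(value.to_string());
    for waker in state.pending_wakers.drain(..) {
        waker.wake();
    }
}

// A waker that only counts how often it was woken.
struct CountingWake(AtomicUsize);

impl Wake for CountingWake {
    fn wake(self: Arc<Self>) {
        self.0.fetch_add(1, Ordering::SeqCst);
    }
}
```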

So far so good, our load function does not look too complex. Using std::future::poll_fn was also an amazing help, as it saved us from having to write our own struct and Future implementation.

We are currently relying on -> impl Future, which means that the resulting future cannot be named, and thus cannot be stored directly in a data structure.

But you know what? We are implementing this the hard way, so let’s do just that.

impl<B: BatchLoader> DataLoader<B> {
    pub fn load(&self, key: B::K) -> LoadFuture<'_, B> {
        LoadFuture { loader: self, key }
    }
}

pub struct LoadFuture<'l, B: BatchLoader> {
    loader: &'l DataLoader<B>,
    key: B::K,
}

impl<'l, B: BatchLoader> Future for LoadFuture<'l, B>
where
    B::K: Debug,
{
    type Output = B::V;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let mut inner = self.loader.inner.lock().unwrap();

        // Check the resolved value, and return it if it was resolved already
        if let Some(v) = inner.resolved_values.get(&self.key) {
            return Poll::Ready(v.clone());
        }

        // Otherwise, register the requested key, and its `Waker`
        println!("starting to resolve related objects for `{:?}`", self.key);
        inner.requested_keys.push(self.key.clone());
        inner.pending_wakers.push(cx.waker().clone());
        Poll::Pending
    }
}

This actually wasn’t half bad. We didn’t even need to resort to unsafe code. The reason is that we have accidentally created a future that can safely be polled again after it has returned Poll::Ready.

Futures usually do not behave like this. They are typically implemented as state machines that internally transition to a finished state, and intentionally panic if poll is called again in that state.

Changing that internal state behind a Pin<&mut Self>, however, would most likely require unsafe code or an external dependency that implements pin projection, which is a complex topic of its own that I won’t dive into right now.
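One escape hatch is worth a short sketch: when all fields of a future are Unpin, Pin::get_mut gives safe mutable access, so the state transition needs no unsafe and no pin projection. Once is a hypothetical type, similar in spirit to std::future::ready, which also panics when polled after completion; it only works because of the T: Unpin bound.

```rust
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// Hypothetical one-shot future: `Some` while pending, `None` once finished.
struct Once<T> {
    value: Option<T>,
}

impl<T: Unpin> Future for Once<T> {
    type Output = T;

    fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<T> {
        // `Once<T>` is Unpin when `T: Unpin`, so `get_mut` is safe here.
        let this = self.get_mut();
        // Move into the finished state; a second poll hits the `expect`.
        let value = this.value.take().expect("`Once` polled after completion");
        Poll::Ready(value)
    }
}

// Waker that does nothing, just enough to build a `Context`.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}
```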


Alright, we have one side of the equation implemented, but the other side is still missing: something has to notice that the load calls are pending, load their keys as a batch, and wake them up again. My idea for that was to wrap the whole GraphQL execution future, so that we can manually control when it is polled.

Our GraphQL endpoint handler now looks like this:

async fn graphql_handler(State(schema): State<FullSchema>, req: GraphQLRequest) -> GraphQLResponse {
    let load_books = DataLoader::new(LoadBooks);
    let load_summaries = DataLoader::new(LoadSummaries);
    let req = req
        .into_inner()
        .data(load_books.clone())
        .data(load_summaries.clone());

    let execute = schema.execute(req);
    let wrapped = load_summaries.wrap(load_books.wrap(execute));
    wrapped.await.into()
}

We first add our DataLoaders to the GraphQL request context, so that the resolvers can use their load part.

Then we wrap the actual execution with each of our DataLoaders, which provides the second part.

A simplified version of wrap looks like this:

pub async fn wrap<O>(&self, fut: impl Future<Output = O>) -> O {
    let mut currently_loading = None;

    let mut fut = pin!(fut);
    poll_fn(move |cx| {
        // If we are currently loading a batch, poll that to completion first
        if let Some(currently_loading_fut) = &mut currently_loading {
            match currently_loading_fut.as_mut().poll(cx) {
                Poll::Ready(resolved_values) => {
                    let mut inner = self.inner.lock().unwrap();

                    inner.resolved_values.extend(resolved_values);

                    // Wake all the `load` calls waiting on this batch
                    for waker in inner.pending_wakers.drain(..) {
                        waker.wake();
                    }

                    currently_loading = None;
                }
                Poll::Pending => return Poll::Pending,
            }
        }

        let res = fut.as_mut().poll(cx);
        if res.is_pending() {
            // We have polled the inner future once, during which it may have registered more
            // keys to load.
            let mut inner = self.inner.lock().unwrap();

            let requested_keys = std::mem::take(&mut inner.requested_keys);
            if !requested_keys.is_empty() {
                currently_loading = Some(inner.load_batch.load_batch(requested_keys));

                // Wake immediately, to instruct the runtime to call `poll` again right away.
                cx.waker().wake_by_ref();
            }
        }
        res
    })
    .await
}

Looking at the second half of that function first, we can see that we are manually polling the wrapped future, and when it returns Poll::Pending with some keys requested, we kick off our batch loading future.

We then have to wake the context’s Waker immediately, so that our runtime will call our poll_fn again right away.

On that second call, we end up in the first half of the function, which is responsible for polling our batch loading future. Once that is complete, we wake all previously registered Wakers, and continue to the second half once more.

That second half will again poll our wrapped future, whose load calls can now return their resolved values, and it may end up doing a second batch load.

Again, thanks to poll_fn, we do not need to worry about pin projection, and we can simply pin!() the wrapped future to the stack and poll it directly, just like in the documentation example of poll_fn.


But there is still one problem with this function as written: it does not compile.

Unfortunately, the type inference in the compiler is not smart enough to infer the type of currently_loading, and it asks us to manually annotate it:

error[E0282]: type annotations needed for `std::option::Option<_>`
  --> src\server\dataloader.rs:87:13
   |
87 |         let mut currently_loading = None;
   |             ^^^^^^^^^^^^^^^^^^^^^
...
92 |                 match currently_loading_fut.as_mut().poll(cx) {
   |                                             ------ type must be known at this point
   |
help: consider giving `currently_loading` an explicit type, where the type for type parameter `T` is specified
   |
87 |         let mut currently_loading: std::option::Option<T> = None;
   |                                  ++++++++++++++++++++++++

Usually, this wouldn’t be a big deal, except that we cannot name this type, as it comes from an async fn, or more precisely an AFIT (async function in trait).

So what should we do in this situation? Just slap a Box<dyn Future> on it, you might say.

But that won’t work either, as our GraphQL handler would then no longer be Send, which axum requires.

So we actually need a Pin<Box<dyn Future + Send>>, and with that, we run into another well-known limitation of AFIT: the Send-bound problem. It is currently not possible to express in stable Rust that the future returned by async fn load_batch has to be Send.

Unstable Rust offers return type notation to express such bounds.

Another alternative is to change the definition (and implementations) of load_batch so that it returns an explicitly bounded impl Future instead of being an async fn:

pub trait BatchLoader {
    type K: Hash + Eq + Clone;
    type V: Clone;

    fn load_batch(
        &mut self,
        keys: Vec<Self::K>,
    ) -> impl Future<Output = HashMap<Self::K, Self::V>> + Send + 'static;
}

As a side note, we also had to change keys to an owned Vec instead of a slice, since a 'static future cannot borrow the slice it was given.
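As a rough sketch of why this helps, assuming the revised trait above, the batch future can now be type-erased into the Pin<Box<dyn Future + Send>> that currently_loading needs. start_batch, Uppercase and NoopWake are hypothetical names for illustration; the point is that the boxed future no longer borrows the loader, so it can be stored and polled later.

```rust
use std::collections::HashMap;
use std::future::Future;
use std::hash::Hash;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Wake, Waker};

// The revised trait: the returned future must be `Send + 'static`.
pub trait BatchLoader {
    type K: Hash + Eq + Clone;
    type V: Clone;

    fn load_batch(
        &mut self,
        keys: Vec<Self::K>,
    ) -> impl Future<Output = HashMap<Self::K, Self::V>> + Send + 'static;
}

// Hypothetical helper: because the future is `Send + 'static`, it can be
// boxed into a nameable type, e.g. for a variable like `currently_loading`.
fn start_batch<B: BatchLoader>(
    loader: &mut B,
    keys: Vec<B::K>,
) -> Pin<Box<dyn Future<Output = HashMap<B::K, B::V>> + Send>> {
    Box::pin(loader.load_batch(keys))
}

// Hypothetical loader that "resolves" keys by uppercasing them.
struct Uppercase {
    batches: usize,
}

impl BatchLoader for Uppercase {
    type K = String;
    type V = String;

    fn load_batch(
        &mut self,
        keys: Vec<String>,
    ) -> impl Future<Output = HashMap<String, String>> + Send + 'static {
        // Synchronous bookkeeping can still use `&mut self`...
        self.batches += 1;
        // ...but the async part owns everything it needs, so it is `'static`.
        async move {
            keys.into_iter()
                .map(|k| {
                    let v = k.to_uppercase();
                    (k, v)
                })
                .collect::<HashMap<_, _>>()
        }
    }
}

// Waker that does nothing, just enough to poll by hand.
struct NoopWake;

impl Wake for NoopWake {
    fn wake(self: Arc<Self>) {}
}
```

Note how the loader stays usable while the boxed future is alive, which would not be possible if the future still borrowed &mut self.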

By restricting the futures returned by our BatchLoader to be Send (which axum pretty much enforces indirectly anyway) and 'static, we end up with an implementation that works, and is entirely safe Rust as well.
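To see both halves interact end to end, here is a std-only sketch of the same idea. It is not the code from the post: Loader, join2 and poll_once are hypothetical, the keys are plain u32s, and the batch "query" is a synchronous stand-in that formats each key, which sidesteps the unnameable-future problem entirely. Because the batch completes synchronously, this wrap simply loops and re-polls the wrapped future right away instead of waking its own Waker and waiting for the next poll.

```rust
use std::collections::HashMap;
use std::future::{poll_fn, Future};
use std::pin::pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Wake, Waker};

// Simplified loader state; `batches` records every batch for illustration.
#[derive(Default)]
struct Inner {
    requested_keys: Vec<u32>,
    resolved_values: HashMap<u32, String>,
    pending_wakers: Vec<Waker>,
    batches: Vec<Vec<u32>>,
}

#[derive(Default)]
struct Loader {
    inner: Mutex<Inner>,
}

impl Loader {
    // Leaf side: return the value if it is resolved, else register key and waker.
    fn load(&self, key: u32) -> impl Future<Output = String> + '_ {
        poll_fn(move |cx| {
            let mut inner = self.inner.lock().unwrap();
            if let Some(v) = inner.resolved_values.get(&key) {
                return Poll::Ready(v.clone());
            }
            inner.requested_keys.push(key);
            inner.pending_wakers.push(cx.waker().clone());
            Poll::Pending
        })
    }

    // Driver side: whenever the wrapped future stalls on requested keys, load
    // them as one batch, wake the waiting `load`s and poll it again right away.
    async fn wrap<O>(&self, fut: impl Future<Output = O>) -> O {
        let mut fut = pin!(fut);
        poll_fn(move |cx: &mut Context<'_>| -> Poll<O> {
            loop {
                if let Poll::Ready(out) = fut.as_mut().poll(cx) {
                    return Poll::Ready(out);
                }
                let mut inner = self.inner.lock().unwrap();
                let keys = std::mem::take(&mut inner.requested_keys);
                if keys.is_empty() {
                    return Poll::Pending; // stalled on something other than us
                }
                // Synchronous stand-in for the actual batch query.
                let values = keys.iter().map(|k| (*k, format!("value-{k}")));
                inner.resolved_values.extend(values);
                inner.batches.push(keys);
                for waker in inner.pending_wakers.drain(..) {
                    waker.wake();
                }
                // `inner` is dropped here, before the wrapped future is polled again.
            }
        })
        .await
    }
}

// Minimal join of two futures, standing in for the concurrency inside GraphQL execution.
async fn join2<A: Future, B: Future>(a: A, b: B) -> (A::Output, B::Output) {
    let mut a = pin!(a);
    let mut b = pin!(b);
    let (mut ra, mut rb): (Option<A::Output>, Option<B::Output>) = (None, None);
    poll_fn(move |cx| {
        if ra.is_none() {
            if let Poll::Ready(v) = a.as_mut().poll(cx) {
                ra = Some(v);
            }
        }
        if rb.is_none() {
            if let Poll::Ready(v) = b.as_mut().poll(cx) {
                rb = Some(v);
            }
        }
        if ra.is_some() && rb.is_some() {
            return Poll::Ready((ra.take().unwrap(), rb.take().unwrap()));
        }
        Poll::Pending
    })
    .await
}

// Polls a future once with a no-op waker; enough here because the stand-in
// batch is synchronous, so `wrap` finishes within a single poll.
fn poll_once<F: Future>(fut: F) -> Poll<F::Output> {
    struct NoopWake;
    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }
    let waker = Waker::from(Arc::new(NoopWake));
    pin!(fut).poll(&mut Context::from_waker(&waker))
}
```

Wrapping an async block that first joins loader.load(1) with loader.load(2) and then awaits loader.load(3) should record two batches, [1, 2] and [3]: one per level of dependency rather than one per key, just like the log output earlier in the post.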

# Conclusion

To summarize, by building on the yield_now and async Mutex primitives, we can create our own very simple implementation of the GraphQL DataLoader pattern.

That implementation ends up very similar to the implementation of the dataloader crate.

If we want to go one level of abstraction lower and rely only on functionality present in std, we can also implement the DataLoader the hard way, by manually hooking into the poll lifecycle of the GraphQL execution.

By changing and restricting the implementation of our BatchLoader trait a bit, we can implement all that even in safe Rust.

It was a fun journey for me, and it highlights that lower-level (hard-mode) async Rust definitely still has some rough edges.

If you are curious to follow all of my experimentation steps, and want to find out why I was teasing some unsafe code in the introduction, you can find everything in the companion GitHub repo.

Hint: I used unsafe to create Send + 'static bounds out of thin air instead of changing the BatchLoader trait, because YOLO.