Implementing a GraphQL DataLoader the hard way
Last week, I wrote about the DataLoader pattern commonly used within GraphQL servers to avoid the 1 + N problems that can pop up with a naive implementation of GraphQL resolvers. You can read up on last week’s post here.

We established that the DataLoader pattern fundamentally takes advantage of async concurrency, and we used `yield_now` to split the loading of our data into distinct phases. We stopped short of actually implementing the full DataLoader the simple way, which is what we will continue with today. Later on, we will also implement it the hard way, by digging deeper: implementing our own `Future`, messing around with `Waker`s, and even throwing in some `unsafe` code for good measure.
# The simple way
As a reminder, we ended up with the following trait for our generic `BatchLoader`, which is responsible for actually loading data in batches:
pub trait BatchLoader {
    type K: Hash + Eq + Clone;
    type V: Clone;

    async fn load_batch(&mut self, keys: &[Self::K]) -> HashMap<Self::K, Self::V>;
}
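To make the trait a bit more concrete before we continue, a hypothetical implementation could look like this: a simplified summaries loader keyed by book title (illustrative only; the real implementation queries actual data and will differ):

// Illustrative only: a summaries loader keyed by book title.
struct LoadSummaries;

impl BatchLoader for LoadSummaries {
    type K = String; // book title
    type V = String; // book summary

    async fn load_batch(&mut self, keys: &[Self::K]) -> HashMap<Self::K, Self::V> {
        println!("actually resolving summaries for {keys:?}");
        // One batched lookup for all requested titles at once.
        keys.iter()
            .map(|title| (title.clone(), format!("A summary of `{title}`")))
            .collect()
    }
}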
An interesting observation here is that we chose to use `&mut self`. This gives us more flexibility when implementing the `BatchLoader`, but it also restricts us a bit when using it. In particular, we have to use an async `Mutex` when calling the `load_batch` function. We could have chosen `&self` instead, in which case the decision whether or not to use an async `Mutex` would have been pushed down to the implementor of the `BatchLoader`. That would also have the advantage of allowing multiple batches to be resolved at the same time.
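For comparison, that alternative only changes the receiver of the trait method. A sketch of the variant we did not pick:

pub trait BatchLoader {
    type K: Hash + Eq + Clone;
    type V: Clone;

    // With `&self`, interior mutability (if any is needed) becomes the
    // implementor's problem, and several batches could be in flight at once.
    async fn load_batch(&self, keys: &[Self::K]) -> HashMap<Self::K, Self::V>;
}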
But we will actually continue to use our async `Mutex` as the second building block, alongside `yield_now`, to implement our `DataLoader`.
We thus end up with the following implementation for our `DataLoader`:
struct LoaderInner<B: BatchLoader> {
    load_batch: B,
    requested_keys: Vec<B::K>,
    resolved_values: HashMap<B::K, B::V>,
}

pub struct DataLoader<B: BatchLoader> {
    inner: Arc<Mutex<LoaderInner<B>>>,
}

impl<B: BatchLoader> DataLoader<B>
where
    B::K: Debug,
{
    pub async fn load(&self, key: B::K) -> B::V {
        // First, register the load request:
        {
            let mut inner = self.inner.lock().await;
            println!("starting to resolve related objects for `{key:?}`");
            inner.requested_keys.push(key.clone());
        }

        // Then yield, which will cause all the other `load` calls to also register their requests
        yield_now().await;

        // Next, load all of the `requested_keys` in a batch
        let mut inner = self.inner.lock().await;
        let requested_keys = std::mem::take(&mut inner.requested_keys);
        if !requested_keys.is_empty() {
            let resolved_values = inner.load_batch.load_batch(&requested_keys).await;
            inner.resolved_values.extend(resolved_values);
        }

        // And last but not least, return the resolved value
        inner
            .resolved_values
            .get(&key)
            .expect("value should have been loaded")
            .clone()
    }
}
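To see where this `load` ends up being called from, here is a sketch of a field resolver using it. This assumes async-graphql, plus the `Author`/`Book` types and the `LoadBooks` loader from last week’s post; the actual resolver will differ in detail:

use async_graphql::{Context, Object};

#[Object]
impl Author {
    async fn books(&self, ctx: &Context<'_>) -> Vec<Book> {
        // Each author resolver only requests its own key; the DataLoader
        // batches all concurrently running resolvers into one `load_batch` call.
        let loader = ctx.data_unchecked::<DataLoader<LoadBooks>>();
        loader.load(self.name.clone()).await
    }
}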
Running our GraphQL query with this implementation, we end up with the following log output. It highlights that our batch loading has been successful: we end up with a total of 3 queries, corresponding to the depth of our query rather than its fan-out at each level. As a reminder, our query looks like `{ authors { name books { title summary } } }`.
// Loading authors
starting to resolve authors
finished resolving authors
// Loading books
starting to resolve related objects for `"Hermann Hesse"`
starting to resolve related objects for `"Thomas Mann"`
actually resolving Books by `Hermann Hesse`, `Thomas Mann`
finished resolving Books by `Hermann Hesse`, `Thomas Mann`
// Loading summaries
starting to resolve related objects for `"Siddhartha"`
starting to resolve related objects for `"Das Glasperlenspiel"`
starting to resolve related objects for `"Zauberberg"`
actually resolving summaries for `Siddhartha`, `Das Glasperlenspiel`, `Zauberberg`
finished resolving summaries for `Siddhartha`, `Das Glasperlenspiel`, `Zauberberg`
So all in all, we ended up with a very simple and straightforward implementation. An implementation which is, in fact, very close to the upstream `dataloader` crate. The only difference is that the `dataloader` crate has more features, such as avoiding reloading values that were already loaded before, and limiting the batch size, among other things.
# The hard way
But I also promised to implement the DataLoader pattern the hard way, which in this case means no longer relying on third-party async building blocks like `yield_now` and the async `Mutex` (though we will still use a sync `Mutex`). Instead, we will hook directly into the `Future` lifecycle, which means that we will implement our own `Future` and `poll` the underlying GraphQL resolvers manually.

The idea is to split our implementation into two parts. The leaf async `load` function’s responsibility is to register interest in a particular key, and to return the resolved value once it is ready. We end up with this implementation:
pub fn load(&self, key: B::K) -> impl Future<Output = B::V> {
    poll_fn(move |cx| {
        let mut inner = self.inner.lock().unwrap();

        // Check the resolved value, and return it if it was resolved already
        if let Some(v) = inner.resolved_values.get(&key) {
            return Poll::Ready(v.clone());
        }

        // Otherwise, register the requested key, and its `Waker`
        println!("starting to resolve related objects for `{key:?}`");
        inner.requested_keys.push(key.clone());
        inner.pending_wakers.push(cx.waker().clone());

        Poll::Pending
    })
}
The idea here is that on the first call to `poll`, we will add our key to `requested_keys`, along with a `Waker`, and return `Poll::Pending` immediately. Returning `Poll::Pending` is the equivalent of `yield_now` above, in the sense that it yields control back to the caller, so that async concurrency will `poll` the next `load` function and let it register its key as well.
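In fact, a hand-rolled `yield_now` is little more than that. A simplified sketch (tokio’s actual implementation has a bit more to it):

use std::{future::{poll_fn, Future}, task::Poll};

// Yields exactly once: `Pending` on the first poll (after asking to be woken
// again right away), `Ready` on the second.
pub fn yield_now() -> impl Future<Output = ()> {
    let mut yielded = false;
    poll_fn(move |cx| {
        if yielded {
            Poll::Ready(())
        } else {
            yielded = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    })
}

Note the `wake_by_ref()` call: `yield_now` immediately re-schedules itself, which is exactly what our `load` future does not do.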
But just returning `Poll::Pending` is not enough, as the underlying async runtime, or async concurrency primitives like `join_all`, will not `poll` the future again unless it is woken. For this reason, we also store the `Waker` associated with this `poll`, so that we can wake it once we have actually loaded the data, and our `Future` is actually `poll`ed again.
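That also means the shared state for the hard way looks slightly different from the simple way: it needs somewhere to store those `Waker`s, and a plain sync `Mutex` is now enough, since we never hold it across an `.await` point. A minimal sketch of that state:

struct LoaderInner<B: BatchLoader> {
    load_batch: B,
    requested_keys: Vec<B::K>,
    resolved_values: HashMap<B::K, B::V>,
    // Wakers of all `load` futures that returned `Poll::Pending` and are
    // waiting for the next batch to be resolved.
    pending_wakers: Vec<Waker>,
}

pub struct DataLoader<B: BatchLoader> {
    // A sync `std::sync::Mutex` is sufficient here, since the lock is never
    // held across an `.await` point in the hard-way implementation.
    inner: Arc<std::sync::Mutex<LoaderInner<B>>>,
}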
So far so good, our `load` function does not look too complex. Using `std::future::poll_fn` was also an amazing help, as it saved us from having to write our own struct and `impl Future` for it. We are currently relying on `-> impl Future`, though, which means the resulting future cannot be named, and it thus cannot be stored directly in a data structure. But you know what? We are implementing this the hard way, so let’s do just that.
impl<B: BatchLoader> DataLoader<B> {
    pub fn load(&self, key: B::K) -> LoadFuture<'_, B> {
        LoadFuture { loader: self, key }
    }
}

pub struct LoadFuture<'l, B: BatchLoader> {
    loader: &'l DataLoader<B>,
    key: B::K,
}

impl<'l, B: BatchLoader> Future for LoadFuture<'l, B>
where
    B::K: Debug,
{
    type Output = B::V;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let mut inner = self.loader.inner.lock().unwrap();

        // Check the resolved value, and return it if it was resolved already
        if let Some(v) = inner.resolved_values.get(&self.key) {
            return Poll::Ready(v.clone());
        }

        // Otherwise, register the requested key, and its `Waker`
        println!("starting to resolve related objects for `{:?}`", self.key);
        inner.requested_keys.push(self.key.clone());
        inner.pending_wakers.push(cx.waker().clone());

        Poll::Pending
    }
}
This wasn’t actually half bad. We didn’t even need to resort to `unsafe` code. The reason is that we have accidentally created a future that can safely be `poll`ed again after it has returned `Poll::Ready`. You see, `Future`s do not usually behave like this. Usually, futures are implemented as state machines which internally change to a finished state, and which will intentionally panic if `poll` is called again in that state. Changing that internal state behind a `Pin<&mut Self>`, however, would most likely require `unsafe` or an external dependency that implements pin projection, which is a complex topic on its own that I won’t dive into right now.
Alright, we have one side of the equation implemented, but the other side is still missing. We want to somehow hook into the async lifecycle. My idea for that was to wrap the whole GraphQL execution future, so we can manually control when to `poll` it. Our GraphQL endpoint / handler now looks like this:
async fn graphql_handler(State(schema): State<FullSchema>, req: GraphQLRequest) -> GraphQLResponse {
    let load_books = DataLoader::new(LoadBooks);
    let load_summaries = DataLoader::new(LoadSummaries);

    let req = req
        .into_inner()
        .data(load_books.clone())
        .data(load_summaries.clone());

    let execute = schema.execute(req);
    let wrapped = load_summaries.wrap(load_books.wrap(execute));
    wrapped.await.into()
}
We add our `DataLoader`s to the `GraphQLRequest` context first, so we can use their `load` part. And then we `wrap` the actual execution with our `DataLoader`s yet again, for the second part. A simplified version of `wrap` looks like this:
pub async fn wrap<O>(&self, fut: impl Future<Output = O>) -> O {
    let mut currently_loading = None;
    let mut fut = pin!(fut);

    poll_fn(move |cx| {
        // If we are currently loading a batch, poll that to completion first
        if let Some(currently_loading_fut) = &mut currently_loading {
            match currently_loading_fut.as_mut().poll(cx) {
                Poll::Ready(resolved_values) => {
                    let mut inner = self.inner.lock().unwrap();
                    inner.resolved_values.extend(resolved_values);

                    // Wake all the `load` calls waiting on this batch
                    for waker in inner.pending_wakers.drain(..) {
                        waker.wake();
                    }

                    currently_loading = None;
                }
                Poll::Pending => return Poll::Pending,
            }
        }

        let res = fut.as_mut().poll(cx);
        if res.is_pending() {
            // We have polled the inner future once, during which it may have registered more
            // keys to load.
            let mut inner = self.inner.lock().unwrap();
            let requested_keys = std::mem::take(&mut inner.requested_keys);
            if !requested_keys.is_empty() {
                currently_loading = Some(inner.load_batch.load_batch(requested_keys));

                // Wake immediately, to instruct the runtime to call `poll` again right away.
                cx.waker().wake_by_ref();
            }
        }

        res
    })
    .await
}
If we take a look at the second half of that function first, we can see that we are manually `poll`ing the wrapped future, and when it returns `Poll::Pending`, we kick off our batch-loading future. We then have to immediately wake the context’s `Waker`, so that our runtime will immediately call our `poll_fn` again. On that second call, we end up in the first half of the function, which is responsible for `poll`ing our batch-loading future. Once that is complete, we wake all our previously registered `Waker`s and continue to the second part once more. That second part will again `poll` our wrapped future, which may end up triggering a second batch load.

Again, with the help of `poll_fn`, we do not need to worry about pin projection, and we can just `pin!()` the wrapped future to the stack so we can `poll` it directly, just like in the documentation example of `poll_fn`.
But there is still one problem with this function as written: it does not compile as-is. Unfortunately, the compiler’s type inference is not smart enough to infer the type of `currently_loading`, and it asks us to manually annotate it:
error[E0282]: type annotations needed for `std::option::Option<_>`
  --> src\server\dataloader.rs:87:13
   |
87 |         let mut currently_loading = None;
   |             ^^^^^^^^^^^^^^^^^^^^^
...
92 |             match currently_loading_fut.as_mut().poll(cx) {
   |                                                  ------ type must be known at this point
   |
help: consider giving `currently_loading` an explicit type, where the type for type parameter `T` is specified
   |
87 |         let mut currently_loading: std::option::Option<T> = None;
   |                                  ++++++++++++++++++++++++
Usually, this wouldn’t be a big deal, except that we cannot name this type, as it comes from an `async fn`. An AFIT (async function in trait), to be precise. So what should we do in this situation? Just slap a `Box<dyn Future>` on it, you might say. But that won’t work either, as our GraphQL handler would then no longer be `Send`, which is a requirement of axum. So we actually need a `Pin<Box<dyn Future + Send>>`. And with that, we run into another well-known limitation of AFIT: the `Send`-bound problem. You see, it is currently not possible in stable Rust to express that my `async fn load_batch` has to return a `Send` future. Unstable Rust offers return type notation to express such bounds. Another alternative is to change the definition (and implementations) of our `async fn load_batch`:
pub trait BatchLoader {
    type K: Hash + Eq + Clone;
    type V: Clone;

    fn load_batch(
        &mut self,
        keys: Vec<Self::K>,
    ) -> impl Future<Output = HashMap<Self::K, Self::V>> + Send + 'static;
}
As a side note, we also had to change `keys` to an owned `Vec` instead of a slice, because otherwise we would run into different borrow checker problems as well. By restricting our `BatchLoader` to be `Send` (which is pretty much enforced indirectly by the usage of axum anyway) and `'static`, we end up with an implementation that works, and that is actually safe as well.
# Conclusion
To summarize: by building on the `yield_now` and async `Mutex` primitives, we can create our own very simple implementation of the GraphQL DataLoader pattern. That implementation ends up very similar to the implementation of the `dataloader` crate.

If we want to drop one level of abstraction down and rely only on functionality present in `std`, we can also implement the DataLoader the hard way, by manually hooking into the `poll` lifecycle of the GraphQL execution. By changing and restricting the implementation of our `BatchLoader` trait a bit, we can implement all of that even in safe Rust.

It was a fun journey for me, and it highlights that lower-level (hard-mode) async Rust definitely still has some rough edges.
If you are curious to follow all of my experimentation steps, and want to find out why I was teasing some `unsafe` code in the introduction, you can find all of that in the companion GitHub repo. Hint: I used `unsafe` to create `Send + 'static` bounds out of thin air instead of changing the `BatchLoader` trait, because YOLO.