Rate Limiting Middleware Golang
Problem
Getting too many requests can be a problem for API servers. One client can send more requests than your server can handle, which causes problems for the other clients. Luckily, we can solve this problem pretty easily.
Basic Server
Let’s start with a basic server. We’ll be using Golang for this example, and we’ll only need the standard library.
We have the initial boilerplate:
package main

import (
	"fmt"
	"log"
	"net/http"
	"time"
)

func main() {
	// Optionally parse flags here
	if err := serve(); err != nil {
		log.Fatalln(err)
	}
}
As for the main serve function, we can define it as follows:
func index(w http.ResponseWriter, r *http.Request) {
	fmt.Fprintf(w, "%d\n", time.Now().Nanosecond())
}

func serve() error {
	mux := http.NewServeMux()
	// index is a plain function, so register it with HandleFunc
	mux.HandleFunc("/", index)
	log.Println("Listening on :8080")
	return http.ListenAndServe(":8080", mux)
}
Here, we have a simple web server that listens on port 8080. Any request to the / path returns the nanosecond part of the current time, which serves as our timestamp.
We can test that by curling our server.
Start the server by running go run main.go, then run something like this in another terminal:
for i in {1..10}
do
echo "Curling"
curl http://localhost:8080
sleep 1
done
Then we should see the timestamp print to our terminal every second. But what if we want to only allow requests every 5 seconds?
Middleware
Golang has a concept called middleware, which we discussed previously. Basically, a middleware function takes in an http.HandlerFunc and returns another. The classic, simple example for getting familiar with the concept is a middleware that logs each request.
For example:
func Logging(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		fmt.Println(r)
		next.ServeHTTP(w, r)
	}
}
Then we can just wrap the index call with the logger like so:
func serve() error {
	mux := http.NewServeMux()
	mux.Handle("/", Logging(index))
	log.Println("Listening on :8080")
	return http.ListenAndServe(":8080", mux)
}
The main bit to focus on here is the next.ServeHTTP part. Here, next refers to the index func, so the functionality stays the same. Before the next call, we log the request. Rerunning the server and the curl client, we should see the server log the request and the client log the timestamp.
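Because a middleware both takes and returns an http.HandlerFunc, several of them can be stacked. The sketch below illustrates the order in which stacked middleware run. The Middleware type, the Chain helper, and the tag and callOrder functions are hypothetical names made up for this illustration; they are not part of the standard library or of the server above.

```go
package main

import (
	"net/http"
	"net/http/httptest"
)

// Middleware is a hypothetical name for the function shape used in this post.
type Middleware func(http.HandlerFunc) http.HandlerFunc

// Chain wraps final so that the first middleware listed runs first.
func Chain(final http.HandlerFunc, mws ...Middleware) http.HandlerFunc {
	// Wrap from the last middleware to the first so mws[0] ends up outermost.
	for i := len(mws) - 1; i >= 0; i-- {
		final = mws[i](final)
	}
	return final
}

// tag returns a middleware that records its name before calling next.
func tag(name string, calls *[]string) Middleware {
	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			*calls = append(*calls, name)
			next.ServeHTTP(w, r)
		}
	}
}

// callOrder sends one in-memory request through the chain and
// reports the order in which each layer was reached.
func callOrder() []string {
	var calls []string
	final := func(w http.ResponseWriter, r *http.Request) {
		calls = append(calls, "index")
	}
	h := Chain(final, tag("outer", &calls), tag("inner", &calls))
	h(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/", nil))
	return calls
}
```

Calling callOrder() yields outer, inner, index: the outermost wrapper sees the request first, and the original handler sees it last.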
Cool. Now that we know how to create middleware, let’s handle the rate-limiting.
Rate Limiting
Let’s start with a function that looks like so:
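func HttpDebounce(next http.HandlerFunc, d time.Duration) http.HandlerFunc {
	// Zero value, so the first request is always allowed through.
	var threshold time.Time
	return func(w http.ResponseWriter, r *http.Request) {
		// Drop the request if we are still inside the quiet period.
		if time.Now().Before(threshold) {
			fmt.Println("Get bounced")
			return
		}
		next.ServeHTTP(w, r)
		// Start a new quiet period of length d.
		threshold = time.Now().Add(d)
	}
}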
This one is a little more complex, so let’s step through it.
First, we declare a Time variable, which is initialized to its zero value, a time far in the past. Then, instead of calling next directly, we capture it in a closure. Our closure has the same signature as the http.HandlerFunc type, so we’re allowed to return the closure. Now, after the return, the server will call our closure instead of next. Then, inside our closure, we compare the current time to the captured threshold time. If not enough time has passed since our most recent call, we return before we call next. If enough time has passed, then (and only then) do we call next. After calling next, we update threshold to be some time in the future.
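One caveat with the closure described above: the server handles each request in its own goroutine, so concurrent requests read and write threshold without any synchronization, and a bounced client receives an empty 200 response. The following is a minimal sketch of one way to address both, assuming we are happy to answer bounced requests with 429 Too Many Requests. SafeDebounce and statusCodes are hypothetical names for this illustration, and unlike the version above, this variant sets the threshold before calling next.

```go
package main

import (
	"net/http"
	"net/http/httptest"
	"sync"
	"time"
)

// SafeDebounce is a hypothetical, mutex-guarded variant of the debounce idea.
func SafeDebounce(next http.HandlerFunc, d time.Duration) http.HandlerFunc {
	var (
		mu        sync.Mutex
		threshold time.Time
	)
	return func(w http.ResponseWriter, r *http.Request) {
		mu.Lock()
		now := time.Now()
		if now.Before(threshold) {
			mu.Unlock()
			// Tell the client explicitly that it was rate limited.
			http.Error(w, "too many requests", http.StatusTooManyRequests)
			return
		}
		// Reserve the window before calling next, so requests arriving
		// while next is still running are rejected as well.
		threshold = now.Add(d)
		mu.Unlock()
		next.ServeHTTP(w, r)
	}
}

// statusCodes fires n back-to-back in-memory requests at a debounced
// handler with a one-minute window and collects the response codes.
func statusCodes(n int) []int {
	h := SafeDebounce(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}, time.Minute)
	codes := make([]int, 0, n)
	for i := 0; i < n; i++ {
		rec := httptest.NewRecorder()
		h(rec, httptest.NewRequest(http.MethodGet, "/", nil))
		codes = append(codes, rec.Code)
	}
	return codes
}
```

With a one-minute window, statusCodes(3) returns 200 for the first request and 429 for the next two.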
We would also need to wrap index with our new middleware like so:
func serve() error {
	mux := http.NewServeMux()
	mux.Handle("/", Logging(HttpDebounce(index, 5*time.Second)))
	log.Println("Listening on :8080")
	return http.ListenAndServe(":8080", mux)
}
If we pass index and a duration (like 5 seconds), we can see that the client only logs the timestamp about every 5 seconds.
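Note that this threshold is shared by every caller, so one busy client also locks out all the others. If the goal is to keep a single client from crowding out the rest, one possible direction is to keep a threshold per client. The sketch below assumes the client can be identified by the host part of r.RemoteAddr, which does not hold behind a proxy or load balancer. PerClientDebounce and clientCodes are hypothetical names for this illustration.

```go
package main

import (
	"net"
	"net/http"
	"net/http/httptest"
	"sync"
	"time"
)

// PerClientDebounce is a hypothetical variant that tracks one threshold per client host.
func PerClientDebounce(next http.HandlerFunc, d time.Duration) http.HandlerFunc {
	var mu sync.Mutex
	thresholds := make(map[string]time.Time)
	return func(w http.ResponseWriter, r *http.Request) {
		// RemoteAddr is "host:port"; fall back to the raw value if it does not parse.
		host, _, err := net.SplitHostPort(r.RemoteAddr)
		if err != nil {
			host = r.RemoteAddr
		}
		mu.Lock()
		now := time.Now()
		// A missing key yields the zero time, so new clients are allowed.
		limited := now.Before(thresholds[host])
		if !limited {
			thresholds[host] = now.Add(d)
		}
		mu.Unlock()
		if limited {
			http.Error(w, "too many requests", http.StatusTooManyRequests)
			return
		}
		next.ServeHTTP(w, r)
	}
}

// clientCodes sends one in-memory request per address, in order,
// through a handler with a one-minute window and collects the codes.
func clientCodes(addrs ...string) []int {
	h := PerClientDebounce(func(w http.ResponseWriter, r *http.Request) {}, time.Minute)
	var codes []int
	for _, addr := range addrs {
		req := httptest.NewRequest(http.MethodGet, "/", nil)
		req.RemoteAddr = addr
		rec := httptest.NewRecorder()
		h(rec, req)
		codes = append(codes, rec.Code)
	}
	return codes
}
```

For example, clientCodes("10.0.0.1:1111", "10.0.0.1:2222", "10.0.0.2:3333") returns 200, 429, 200: the second request from 10.0.0.1 is bounced, while 10.0.0.2 is unaffected. The map in this sketch never shrinks, so a long-running server would also need to evict old entries.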
Here we can see the final resulting code.
main.go
/**
* File: main.go
* Written by: Stephen M. Reaves
* Created on: Wed, 21 Jun, 2023
* Description: Example showing how to implement rate-limiting using middleware
in golang.
*/
package main
import (
	"fmt"
	"log"
	"net/http"
	"time"
)

func main() {
	// Optionally parse flags here
	if err := serve(); err != nil {
		log.Fatalln(err)
	}
}
func HttpDebounce(next http.HandlerFunc, d time.Duration) http.HandlerFunc {
	// Zero value, so the first request is always allowed through.
	var threshold time.Time
	return func(w http.ResponseWriter, r *http.Request) {
		// Drop the request if we are still inside the quiet period.
		if time.Now().Before(threshold) {
			fmt.Println("Get bounced")
			return
		}
		next.ServeHTTP(w, r)
		// Start a new quiet period of length d.
		threshold = time.Now().Add(d)
	}
}

func Logging(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		fmt.Println(r)
		next.ServeHTTP(w, r)
	}
}

func DefaultMiddleware(final http.HandlerFunc) http.HandlerFunc {
	return Logging(HttpDebounce(final, 2*time.Second))
}

func index(w http.ResponseWriter, r *http.Request) {
	fmt.Fprintf(w, "%d\n", time.Now().Nanosecond())
}

func serve() error {
	mux := http.NewServeMux()
	mux.Handle("/", Logging(HttpDebounce(index, 5*time.Second)))
	log.Println("Listening on :8080")
	return http.ListenAndServe(":8080", mux)
}
example_curl.sh
#!/bin/bash
for i in {1..10}
do
echo "Curling"
curl http://localhost:8080
sleep 1
done