Refactored Telegram

It’s all just ones and zeros under the cover

A Simple Source IP Address Filter in Go

I’ve found that it’s occasionally useful to have something that allows through, or blocks, requests to your web application based on the source IP address. There are a number of reasons why you may want to do this: maybe you’d like to put something online that only you have access to, or maybe you’re building something that is publicly available, but certain endpoints should only be accessible to certain machines for security or privacy reasons. For me, the motivation was to protect something I was building that was not quite ready to share with the outside world.

Either way, a simple IP address filter might be a useful thing to keep in your toolkit. This article shows you how to build one.

A Pattern For Middleware

The filter will be implemented as middleware for a Go web-app. There are a few ways to build middleware in Go, but the pattern that I prefer is to implement it as a function that takes an upstream handler as an argument, and returns a new handler which wraps it:

func SourceIPFilter(upstreamHandler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// http.Handler is an interface, so the upstream handler is invoked via ServeHTTP
		upstreamHandler.ServeHTTP(w, r)
	})
}

Applying the middleware to a service handler is as simple as passing the handler itself to this new function:

func main() {
	// newServiceHandler stands in for whatever builds your application's handler
	var serviceHandler http.Handler = newServiceHandler()

	log.Fatal(http.ListenAndServe(":8080", SourceIPFilter(serviceHandler)))
}

Because this only uses types defined in the standard net/http package, middleware built this way can wrap any http.Handler, which gives the maximum level of flexibility available. This holds even when using a framework like Gin, as most of these frameworks have a way of working with Go’s standard handler types.

Getting The Source IP Of A Request

The first thing the handler needs to do is get the source IP address of the request. This is a bit more involved than it might first appear, so it’s a good idea to do it in a separate function.

func requestSourceIp(req *http.Request) (string, error) {
	// TODO: work out the source IP address of the request
	return "", nil
}

The simplest case is when the client connects directly to the Go application. Its IP address is available from the RemoteAddr field of the http.Request struct. The value includes the port as well as the IP address, so we can use net.SplitHostPort to discard the port:

func requestSourceIp(req *http.Request) (string, error) {
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return "", err
	}
	
	return host, nil
}
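RemoteAddr has the form "host:port", and IPv6 hosts arrive in square brackets. The sketch below, using a hypothetical helper named hostFromRemoteAddr, shows that net.SplitHostPort strips both the port and the brackets, and that it returns an error when there is no port at all. It relies on httptest.NewRequest, which fills in RemoteAddr as "192.0.2.1:1234".

```go
package main

import (
	"fmt"
	"net"
	"net/http/httptest"
)

// hostFromRemoteAddr strips the port from a RemoteAddr-style "host:port"
// string. For IPv6, e.g. "[2001:db8::1]:443", the brackets are removed too.
func hostFromRemoteAddr(remoteAddr string) (string, error) {
	host, _, err := net.SplitHostPort(remoteAddr)
	return host, err
}

func main() {
	// httptest.NewRequest sets RemoteAddr to "192.0.2.1:1234".
	req := httptest.NewRequest("GET", "/", nil)
	for _, addr := range []string{req.RemoteAddr, "[2001:db8::1]:443", "192.0.2.1"} {
		host, err := hostFromRemoteAddr(addr)
		fmt.Printf("%q -> %q (err: %v)\n", addr, host, err)
	}
}
```

The last input has no port, so SplitHostPort reports an error; that is the error path that requestSourceIp passes back to its caller.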

This works if the Go application is accepting connections directly. However, it begins to break down as soon as intermediaries are placed between the client and the Go application. Some examples of these might be:

  • Reverse proxies, like Apache or Nginx, that will accept external connections and forward them to the Go application;
  • Load balancers[1], which will route inbound connections amongst multiple instances of the application; and
  • CDNs, which will provide caching services and DDoS protection.

Each of these may accept the incoming connection itself and then connect to your service using a separate connection, with a different source IP address. In that case, RemoteAddr holds the address of the intermediary rather than that of the original client, so it can’t be relied on by itself.

Fortunately, many of these proxies provide the source IP address in a header on the forwarded request. The standard way to do so is the Forwarded header, defined in RFC 7239. This header contains details of the request as seen by each of the intermediaries, with the first element describing the original client. Each of these “forwarded elements” contains attributes of the forwarded request as key-value pairs, such as the IP address that the request was forwarded for (for), whether the original request was made using HTTP or HTTPS (proto), and the Host header of the original request (host).

We’re only interested in the source IP address of the original client, so we need the value of the for parameter in the first element of the Forwarded header:

func requestSourceIp(req *http.Request) (string, error) {
	// Check the Forwarded header, e.g. "for=192.0.2.60;proto=http, for=198.51.100.17"
	forwardedHeader := req.Header.Get("Forwarded")
	if forwardedHeader != "" {
		// Elements are separated by commas; the first describes the original client
		parts := strings.Split(forwardedHeader, ",")
		firstPart := strings.TrimSpace(parts[0])
		// Key-value pairs within an element are separated by semicolons
		subParts := strings.Split(firstPart, ";")
		for _, part := range subParts {
			normalisedPart := strings.ToLower(strings.TrimSpace(part))
			if strings.HasPrefix(normalisedPart, "for=") {
				return normalisedPart[len("for="):], nil
			}
		}
	}

	// Fall back to the address of the connection itself
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return "", err
	}

	return host, nil
}
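The for value isn’t always a bare address. RFC 7239 lets it carry a port, requires IPv6 addresses to be bracketed and quoted (for example for="[2001:db8:cafe::17]:4711"), and also allows "unknown" or obfuscated identifiers such as "_gazonk". Returning the raw text after for= would keep the quotes, brackets and port, and the string would then never match a plain IP address. Here is a sketch of a helper, hypothetically named forwardedForIP, that normalises one value using net/netip (Go 1.18 or later):

```go
package main

import (
	"fmt"
	"net/netip"
	"strings"
)

// forwardedForIP turns one RFC 7239 "for" value into an address. It
// accepts a bare IPv4 address (192.0.2.60), a quoted address with a port
// ("192.0.2.60:4711"), and a quoted, bracketed IPv6 address with or
// without a port ("[2001:db8:cafe::17]:4711"). It reports false for
// "unknown", obfuscated identifiers such as "_gazonk", and anything else
// that is not an IP address.
func forwardedForIP(value string) (netip.Addr, bool) {
	v := strings.Trim(strings.TrimSpace(value), `"`)

	// Address with a port: "192.0.2.60:4711" or "[2001:db8::17]:4711".
	if ap, err := netip.ParseAddrPort(v); err == nil {
		return ap.Addr(), true
	}

	// Address without a port; an IPv6 address may still carry brackets.
	v = strings.TrimSuffix(strings.TrimPrefix(v, "["), "]")
	addr, err := netip.ParseAddr(v)
	if err != nil {
		return netip.Addr{}, false
	}
	return addr, true
}

func main() {
	for _, value := range []string{
		`192.0.2.60`,
		`"192.0.2.60:4711"`,
		`"[2001:db8:cafe::17]:4711"`,
		`unknown`,
		`"_gazonk"`,
	} {
		addr, ok := forwardedForIP(value)
		fmt.Println(value, addr, ok)
	}
}
```

A helper like this could be called on the text after for= inside requestSourceIp, with a false result treated as "no usable address".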

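Note that the code above checks for the header before falling back to RemoteAddr. This is because the presence of the header indicates that the request was proxied. Bear in mind, though, that clients can set this header themselves, so its value can only be trusted when the application is reachable solely through a proxy that replaces any client-supplied value.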

This should work for modern proxies. However, the Forwarded header is a relatively recent addition, and before it, the de facto standard was the X-Forwarded-For header. This is a lot simpler than the Forwarded header: it only contains a list of IP addresses separated by commas. As with the Forwarded header, the first one is the IP address of the original client.

func requestSourceIp(req *http.Request) (string, error) {
	// Check the Forwarded header (RFC 7239)
	forwardedHeader := req.Header.Get("Forwarded")
	if forwardedHeader != "" {
		parts := strings.Split(forwardedHeader, ",")
		firstPart := strings.TrimSpace(parts[0])
		subParts := strings.Split(firstPart, ";")
		for _, part := range subParts {
			normalisedPart := strings.ToLower(strings.TrimSpace(part))
			if strings.HasPrefix(normalisedPart, "for=") {
				return normalisedPart[len("for="):], nil
			}
		}
	}

	// Check the X-Forwarded-For header; the first entry is the original client
	xForwardedForHeader := req.Header.Get("X-Forwarded-For")
	if xForwardedForHeader != "" {
		parts := strings.Split(xForwardedForHeader, ",")
		firstPart := strings.TrimSpace(parts[0])
		return firstPart, nil
	}

	// Fall back to the address of the connection itself
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return "", err
	}

	return host, nil
}
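One way to check the order of precedence (Forwarded first, then X-Forwarded-For, then RemoteAddr) is to feed the function requests built with httptest. This sketch reproduces requestSourceIp in compact form so it runs on its own; newRequest is a hypothetical test helper and the addresses are documentation examples.

```go
package main

import (
	"fmt"
	"net"
	"net/http"
	"net/http/httptest"
	"strings"
)

// requestSourceIp: Forwarded first, then X-Forwarded-For, then RemoteAddr.
func requestSourceIp(req *http.Request) (string, error) {
	if forwarded := req.Header.Get("Forwarded"); forwarded != "" {
		firstElement := strings.TrimSpace(strings.Split(forwarded, ",")[0])
		for _, pair := range strings.Split(firstElement, ";") {
			pair = strings.ToLower(strings.TrimSpace(pair))
			if strings.HasPrefix(pair, "for=") {
				return pair[len("for="):], nil
			}
		}
	}
	if xff := req.Header.Get("X-Forwarded-For"); xff != "" {
		return strings.TrimSpace(strings.Split(xff, ",")[0]), nil
	}
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return "", err
	}
	return host, nil
}

// newRequest builds a GET request with the given headers; httptest sets
// RemoteAddr to "192.0.2.1:1234".
func newRequest(headers map[string]string) *http.Request {
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	for k, v := range headers {
		req.Header.Set(k, v)
	}
	return req
}

func main() {
	requests := []*http.Request{
		newRequest(nil),
		newRequest(map[string]string{"X-Forwarded-For": "203.0.113.195, 70.41.3.18"}),
		newRequest(map[string]string{
			"Forwarded":       "for=198.51.100.17;proto=https, for=203.0.113.60",
			"X-Forwarded-For": "203.0.113.195, 70.41.3.18",
		}),
	}
	for _, req := range requests {
		ip, err := requestSourceIp(req)
		fmt.Println(ip, err)
	}
}
```

Run as-is, this should print 192.0.2.1, 203.0.113.195 and 198.51.100.17 in turn: each header, when present, takes priority over everything after it.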

Building Out the Filter

We now have something that will find the source IP address of a request, whether or not it has been proxied. This can be added to our middleware as a function call.

func SourceIPFilter(upstreamHandler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sourceIp, err := requestSourceIp(r)
		if err != nil {
			http.Error(w, "Internal server error", http.StatusInternalServerError)
			return
		}
		_ = sourceIp // not checked yet; Go rejects unused variables

		upstreamHandler.ServeHTTP(w, r)
	})
}

The final piece can now be added, which is to configure the IP address that is permitted to access the service. This version deals with a single IP address that is passed into the middleware function as a parameter, but it should be relatively easy to extend it to deal with a set of permitted IP addresses:

func SourceIPFilter(allowedIpAddress string, upstreamHandler http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sourceIp, err := requestSourceIp(r)
		if err != nil {
			// Unfortunately it's possible to get an error, e.g. a malformed RemoteAddr
			http.Error(w, "Internal server error", http.StatusInternalServerError)
			return
		}

		if sourceIp != allowedIpAddress {
			// We could return 403 Forbidden here, but I prefer to return 404 Not Found
			// to plausibly deny that there is anything here at all.
			http.Error(w, "Not found", http.StatusNotFound)
			return
		}

		upstreamHandler.ServeHTTP(w, r)
	})
}

The IP address to allow through would be the public IP address you’re connecting from. For those on regular ISP plans without a fixed public IP address, it can usually be found by running a web search for “what is my IP address”.

It might also be a good idea to make the permitted IP address configurable by storing it in an environment variable. That way, when the IP address changes, no code changes are needed.

Doing this will also make it easy to disable the filter based on the environment variable’s value. For example, setting the environment variable to the empty string can indicate that the filter should be bypassed, allowing all public traffic access to the resource.

func SourceIPFilter(allowedIpAddress string, upstreamHandler http.Handler) http.Handler {
	if allowedIpAddress == "" {
		// We don't need the filter, so simply return the upstream handler
		return upstreamHandler
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sourceIp, err := requestSourceIp(r)
		if err != nil {
			// Unfortunately it's possible to get an error, e.g. a malformed RemoteAddr
			http.Error(w, "Internal server error", http.StatusInternalServerError)
			return
		}

		if sourceIp != allowedIpAddress {
			// We could return 403 Forbidden here, but I prefer to return 404 Not Found
			// to plausibly deny that there is anything here at all.
			http.Error(w, "Not found", http.StatusNotFound)
			return
		}

		upstreamHandler.ServeHTTP(w, r)
	})
}

func main() {
	var serviceHandler http.Handler = newServiceHandler()

	// An empty or unset ALLOWED_IP_ADDRESS disables the filter
	allowedIpAddress := os.Getenv("ALLOWED_IP_ADDRESS")
	log.Fatal(http.ListenAndServe(":8080", SourceIPFilter(allowedIpAddress, serviceHandler)))
}

That’s pretty much it. You now have a simple IP address filter that can be used to protect access to handlers based on the source IP address, even if requests are passed through load balancers or reverse proxies. I’ve found that this is one of those useful utilities that can be kept in an internal library, and pulled out when necessary. I hope you find this useful as well.
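Since the filter is just an http.Handler wrapper, it doesn’t have to guard the whole application. The sketch below mounts it on a single subtree of an http.ServeMux, which covers the case where only certain endpoints should be restricted. The routes "/" and "/admin/" and the helpers newMux and status are hypothetical, and requestSourceIp is cut down to RemoteAddr so the example runs on its own.

```go
package main

import (
	"fmt"
	"io"
	"net"
	"net/http"
	"net/http/httptest"
)

// requestSourceIp is cut down to RemoteAddr only to keep this sketch
// short; use the full header-aware version in a real application.
func requestSourceIp(req *http.Request) (string, error) {
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	return host, err
}

// SourceIPFilter mirrors the final version: an empty allowed address
// disables the filter, and any other source IP gets a 404.
func SourceIPFilter(allowedIpAddress string, upstreamHandler http.Handler) http.Handler {
	if allowedIpAddress == "" {
		return upstreamHandler
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		sourceIp, err := requestSourceIp(r)
		if err != nil {
			http.Error(w, "Internal server error", http.StatusInternalServerError)
			return
		}
		if sourceIp != allowedIpAddress {
			http.Error(w, "Not found", http.StatusNotFound)
			return
		}
		upstreamHandler.ServeHTTP(w, r)
	})
}

// newMux serves a public page at "/" and guards only "/admin/".
func newMux(allowedIpAddress string) *http.ServeMux {
	mux := http.NewServeMux()
	mux.Handle("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, "public")
	}))
	admin := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		io.WriteString(w, "admin")
	})
	mux.Handle("/admin/", SourceIPFilter(allowedIpAddress, admin))
	return mux
}

// status sends a GET for path from remoteAddr and returns the status code.
func status(h http.Handler, path, remoteAddr string) int {
	req := httptest.NewRequest(http.MethodGet, path, nil)
	req.RemoteAddr = remoteAddr
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, req)
	return rec.Code
}

func main() {
	mux := newMux("203.0.113.10")
	// In a real program: log.Fatal(http.ListenAndServe(":8080", mux))
	fmt.Println("/ from 192.0.2.1:", status(mux, "/", "192.0.2.1:1234"))
	fmt.Println("/admin/ from 192.0.2.1:", status(mux, "/admin/", "192.0.2.1:1234"))
	fmt.Println("/admin/ from 203.0.113.10:", status(mux, "/admin/", "203.0.113.10:4444"))
}
```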


  [1] This will depend on the actual load balancer that you use, and which layer of the OSI network model it operates at. If you’re deploying your application to AWS, for example, Application Load Balancers terminate the client’s TCP connection, which changes the source IP address seen by your application (they add an X-Forwarded-For header instead), whereas Network Load Balancers can preserve the client’s source IP address.