Spring Boot interface rate limiting: AOP + token bucket

Generally speaking, a throughput threshold can be estimated for a system. Once traffic reaches this threshold, it must be limited to keep the system running stably. Typical measures are delaying requests, rejecting them, or rejecting a portion of them. Otherwise the server can easily go down.

Common rate limiting algorithms

  • Counter algorithm

The counter algorithm is the simplest and crudest approach. It is mainly used to cap total concurrency. The size of a database connection pool, the size of a thread pool, and the number of concurrent calls to an interface are all limited with counters.

For example, use an AtomicInteger to count the requests that are currently executing. If a new request would exceed the threshold, it is rejected immediately with a "system busy" response.
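A minimal sketch of that counter follows. It assumes a single JVM, and `ConcurrencyCounterLimiter` is a hypothetical class written only for illustration. It uses `compareAndSet`, so two threads cannot both take the last free slot.

```java
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

// Counter limiter: caps how many requests may be executing at the same time.
// ConcurrencyCounterLimiter is a hypothetical class written for illustration.
class ConcurrencyCounterLimiter {
    private final int maxConcurrent;
    private final AtomicInteger inFlight = new AtomicInteger();

    ConcurrencyCounterLimiter(int maxConcurrent) {
        this.maxConcurrent = maxConcurrent;
    }

    // Reserves a slot, or returns false ("system busy") when the threshold is reached.
    boolean tryAcquire() {
        while (true) {
            int current = inFlight.get();
            if (current >= maxConcurrent) {
                return false;
            }
            // compareAndSet prevents two threads from both taking the last slot
            if (inFlight.compareAndSet(current, current + 1)) {
                return true;
            }
        }
    }

    // Frees a slot; call exactly once for every successful tryAcquire.
    void release() {
        inFlight.decrementAndGet();
    }

    int currentCount() {
        return inFlight.get();
    }

    // Runs task if a slot is free, otherwise returns the "busy" fallback.
    <T> T execute(Supplier<T> task, Supplier<T> busy) {
        if (!tryAcquire()) {
            return busy.get();
        }
        try {
            return task.get();
        } finally {
            release();
        }
    }
}
```

The slot must be released when the request finishes, which is why `execute` releases it in a `finally` block. A missed release permanently shrinks the limit.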

  • Leaky Bucket Algorithm

The idea of the leaky bucket algorithm is very simple. Think of requests as water and the leaky bucket as the limit of the system's processing capacity. Water first flows into the bucket and leaks out of it at a fixed rate. Sometimes water flows in faster than it leaks out. Because the bucket's capacity is limited, it eventually fills up, and any further water overflows directly, which means the request is rejected. This is how the leaky bucket limits traffic.
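The overflow rule above can be sketched in the "bucket as a meter" form. The water level drains at a constant rate, and each request pours in one unit. This is a minimal single-JVM sketch, and `LeakyBucket` is a hypothetical class. The caller supplies the time in milliseconds so the behaviour is deterministic; real code would read a clock such as `System.nanoTime()`. The sketch only rejects requests. It does not cover the queue-based variant, which holds requests and releases them at the leak rate.

```java
// Leaky bucket, "bucket as a meter" form: the water level drains at a constant
// rate and every request pours in one unit; a request that would make the
// bucket overflow is rejected. LeakyBucket is a hypothetical class; time is
// supplied by the caller in milliseconds to keep the example deterministic.
class LeakyBucket {
    private final double capacity;      // how much water the bucket can hold
    private final double leakPerMilli;  // outflow rate
    private double water;               // current water level
    private long lastMillis;            // time of the last update

    LeakyBucket(double capacity, double leakPerSecond, long startMillis) {
        this.capacity = capacity;
        this.leakPerMilli = leakPerSecond / 1000.0;
        this.lastMillis = startMillis;
    }

    synchronized boolean tryAcquire(long nowMillis) {
        // Drain the water that has leaked out since the last update
        if (nowMillis > lastMillis) {
            water = Math.max(0, water - (nowMillis - lastMillis) * leakPerMilli);
            lastMillis = nowMillis;
        }
        if (water + 1 <= capacity) {
            water += 1;     // the request's water enters the bucket
            return true;
        }
        return false;       // the bucket would overflow: reject the request
    }
}
```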

  • token bucket algorithm

The principle of the token bucket algorithm is also fairly simple. Think of registering at a hospital: you can see the doctor only after you have been given a number.

The system maintains a bucket of tokens and adds tokens to it at a constant rate. A request must take a token from the bucket before it can be processed. When no token is available, the request is denied service. The token bucket algorithm therefore limits requests through two settings: the capacity of the bucket and the rate at which tokens are issued.
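A minimal token bucket sketch follows, under the same assumptions: a single JVM, a hypothetical class name `TokenBucket`, and caller-supplied time in milliseconds. The bucket starts full. Up to `capacity` requests can therefore pass in a burst, and after that requests pass at `tokensPerSecond`.

```java
// Token bucket: tokens are added at a constant rate up to a fixed capacity,
// and each request must take one token to proceed. TokenBucket is a
// hypothetical class; time is supplied by the caller in milliseconds.
class TokenBucket {
    private final double capacity;        // maximum number of stored tokens
    private final double tokensPerMilli;  // refill rate
    private double tokens;                // tokens currently in the bucket
    private long lastRefillMillis;

    TokenBucket(double capacity, double tokensPerSecond, long startMillis) {
        this.capacity = capacity;
        this.tokensPerMilli = tokensPerSecond / 1000.0;
        this.tokens = capacity;           // start full, so an initial burst is allowed
        this.lastRefillMillis = startMillis;
    }

    synchronized boolean tryAcquire(long nowMillis) {
        // Add the tokens issued since the last call, never beyond capacity
        if (nowMillis > lastRefillMillis) {
            tokens = Math.min(capacity, tokens + (nowMillis - lastRefillMillis) * tokensPerMilli);
            lastRefillMillis = nowMillis;
        }
        if (tokens >= 1) {
            tokens -= 1;   // take a token: the request may be processed
            return true;
        }
        return false;      // no token available: deny service
    }
}
```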

Stand-alone mode

Google's open-source toolkit Guava provides a rate limiter class, RateLimiter, which implements rate limiting based on the token bucket algorithm. It is convenient to use and efficient.

  • Add the Maven dependency (pom.xml)
<dependency>
    <groupId>com.google.guava</groupId>
    <artifactId>guava</artifactId>
    <version>30.1-jre</version>
</dependency>
  • Create Annotation Limit
package com.example.demo.common.annotation;

import java.lang.annotation.*;
import java.util.concurrent.TimeUnit;

@Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD})
@Documented
public @interface Limit {

    // resource key
    String key() default "";

    // permits issued per second (the token refill rate)
    double permitsPerSecond();

    // maximum time to wait for a permit before the request is rejected
    long timeout();

    // unit of timeout
    TimeUnit timeunit() default TimeUnit.MILLISECONDS;

    // message returned when the request is rejected
    String msg() default "The system is busy, please try again later";

}
  • Annotation use

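permitsPerSecond is the number of permits (tokens) issued per second, i.e. the refill rate

timeout is the longest a request may wait for a permit, measured in timeunit (milliseconds by default)

Suppose the aspect behind @Limit passes these values to Guava's RateLimiter.create(permitsPerSecond) and rateLimiter.tryAcquire(timeout, timeunit). In that case a request that cannot obtain a permit within timeout is rejected with msg. Requests that do get a permit proceed at roughly permitsPerSecond per second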
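The interplay of permitsPerSecond and timeout can be made concrete with a simplified model written with the JDK only. It is not Guava's implementation. `SpacedRateLimiterSketch` is a hypothetical class that spaces permits exactly 1/permitsPerSecond apart. Unlike Guava's RateLimiter, it stores no extra permits while idle. A request is granted if its permit becomes available within timeout. Otherwise it is rejected at once and nothing is reserved.

```java
// Simplified "rate + wait timeout" model, in the spirit of
// RateLimiter.tryAcquire(timeout, unit) but NOT Guava's code: permits are
// spaced 1/permitsPerSecond apart and no idle-time burst is stored.
// SpacedRateLimiterSketch is a hypothetical class; times are in microseconds.
class SpacedRateLimiterSketch {
    private final long intervalMicros;  // time between two permits
    private long nextFreeMicros;        // earliest time the next permit is available

    SpacedRateLimiterSketch(double permitsPerSecond) {
        this.intervalMicros = (long) (1_000_000 / permitsPerSecond);
    }

    // Returns how long the caller must wait (micros) for its permit,
    // or -1 if that wait would exceed timeoutMicros (rejected, nothing reserved).
    synchronized long tryReserve(long nowMicros, long timeoutMicros) {
        long start = Math.max(nowMicros, nextFreeMicros);
        long wait = start - nowMicros;
        if (wait > timeoutMicros) {
            return -1;
        }
        nextFreeMicros = start + intervalMicros;  // the following permit comes one interval later
        return wait;
    }
}
```

Take the values from the commented-out @Limit in TestController: permitsPerSecond = 1 and timeout = 500 ms. A second request arriving immediately after the first is rejected. One arriving 600 ms after the first is granted after a 400 ms wait.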

  • test

Start the project and refresh /cachingTest quickly several times in a row.

Requests beyond the limit are rejected with the configured message.

This approach is application-level rate limiting. If the application is deployed on several machines, each instance limits only its own request rate, so the limit cannot be enforced globally. Solving this requires distributed rate limiting, or rate limiting at the access layer (for example, in Nginx).
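A small simulation shows the problem. Several instances sit behind a round-robin load balancer, and each keeps its own in-memory counter. `LocalLimiter` and `ClusterDemo` are hypothetical names. The counter ignores time windows to keep the point visible.

```java
import java.util.ArrayList;
import java.util.List;

// One application instance with its own, unshared limit (hypothetical class).
class LocalLimiter {
    private final int limit;
    private int used;

    LocalLimiter(int limit) {
        this.limit = limit;
    }

    synchronized boolean tryAcquire() {
        if (used >= limit) {
            return false;
        }
        used++;
        return true;
    }
}

class ClusterDemo {
    // Sends `requests` requests round-robin to `instances` instances, each
    // limited to `perInstanceLimit`, and returns how many were admitted in total.
    static int admitted(int instances, int perInstanceLimit, int requests) {
        List<LocalLimiter> nodes = new ArrayList<>();
        for (int i = 0; i < instances; i++) {
            nodes.add(new LocalLimiter(perInstanceLimit));
        }
        int ok = 0;
        for (int r = 0; r < requests; r++) {
            if (nodes.get(r % instances).tryAcquire()) {
                ok++;
            }
        }
        return ok;
    }
}
```

With three instances that each allow two requests, the cluster admits six requests in total rather than two. The effective global limit grows with the number of instances.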

Distributed mode

Distributed rate limiting based on Redis + Lua scripts

The key to distributed rate limiting is to make the limiting check atomic. It can be implemented with Redis + Lua or Nginx + Lua, and both combinations offer high concurrency and high performance.

First, let's use Redis + Lua to limit the number of requests an interface accepts within a time window. Once that works, the same approach can be adapted to limit total concurrency or the total amount of a resource. Lua is a full programming language, so it can also implement more complex token bucket or leaky bucket algorithms.

The whole check runs inside a single Lua script, and Redis executes a script atomically on its single-threaded command loop. No other command can interleave with the script, so the check is safe under concurrency.

Compared with Redis transactions (MULTI/EXEC), Lua scripts have the following advantages:

Reduced network overhead: without Lua, the logic needs several requests to Redis, whereas a script is sent once;
Atomicity: Redis executes the entire script as a single atomic unit, so there is no concurrent interleaving to worry about and no transaction is needed;
Reuse: Redis caches a loaded script until a restart or SCRIPT FLUSH, and any client can run it again by its SHA1 digest with EVALSHA.

  • Create annotation RedisLimit
package com.example.demo.common.annotation;

import com.example.demo.common.enums.LimitType;

import java.lang.annotation.*;

@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Inherited
@Documented
public @interface RedisLimit {

    // resource name (used in log messages)
    String name() default "";

    // resource key (used when limitType is CUSTOMER)
    String key() default "";

    // prefix prepended to the Redis key
    String prefix() default "";

    // window length in seconds
    int period();

    // maximum number of requests allowed within period
    int count();

    // limit dimension: per client IP, per custom key, or otherwise per method name
    LimitType limitType() default LimitType.CUSTOMER;

    // message returned when the request is rejected
    String msg() default "The system is busy, please try again later";

}
  • AOP implementation of the annotation
package com.example.demo.common.aspect;

import com.example.demo.common.annotation.RedisLimit;
import com.example.demo.common.exception.LimitException;
import com.google.common.collect.ImmutableList;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;

import javax.servlet.http.HttpServletRequest;
import java.lang.reflect.Method;
import java.util.Objects;

@Slf4j
@Aspect
@Component
public class RedisLimitAspect {

    // Sliding-window log: the list at KEYS[1] keeps the timestamps (microseconds)
    // of the latest ARGV[1] (= count) accepted requests; ARGV[2] is the period in
    // seconds. Returns 1 if the request is allowed, 0 if it must be rejected.
    private static final String LIMIT_LUA =
            // Effect replication (the default since Redis 5) lets a script write after calling TIME
            "redis.replicate_commands()\n" +
            "local listLen = redis.call('LLEN', KEYS[1])\n" +
            "local t = redis.call('TIME')\n" +
            "local now = t[1] * 1000000 + t[2]\n" +
            // Fewer than count entries recorded: write the time directly and allow
            "if tonumber(listLen) < tonumber(ARGV[1]) then\n" +
            "  redis.call('LPUSH', KEYS[1], now)\n" +
            "else\n" +
            // Take the earliest recorded time; if it is still inside the period, the limit is exceeded
            "  local oldest = redis.call('LINDEX', KEYS[1], -1)\n" +
            "  if now - tonumber(oldest) < tonumber(ARGV[2]) * 1000000 then\n" +
            "    return 0\n" +
            "  end\n" +
            "  redis.call('LPUSH', KEYS[1], now)\n" +
            "  redis.call('LTRIM', KEYS[1], 0, tonumber(ARGV[1]) - 1)\n" +
            "end\n" +
            "return 1";

    // Built once; Long maps to a Redis integer reply
    private static final DefaultRedisScript<Long> LIMIT_SCRIPT =
            new DefaultRedisScript<>(LIMIT_LUA, Long.class);

    private final RedisTemplate<String, Object> redisTemplate;

    public RedisLimitAspect(RedisTemplate<String, Object> redisTemplate) {
        this.redisTemplate = redisTemplate;
    }

    @Around("@annotation(com.example.demo.common.annotation.RedisLimit)")
    public Object around(ProceedingJoinPoint pjp) throws Throwable {
        Method method = ((MethodSignature) pjp.getSignature()).getMethod();
        RedisLimit annotation = method.getAnnotation(RedisLimit.class);

        String key;
        switch (annotation.limitType()) {
            case IP:
                // one window per client IP
                key = getIpAddress();
                break;
            case CUSTOMER:
                // one window shared by all callers, identified by the annotation key
                key = annotation.key();
                break;
            default:
                key = StringUtils.upperCase(method.getName());
        }
        ImmutableList<String> keys = ImmutableList.of(StringUtils.join(annotation.prefix(), key));

        // ARGV values go through the template's value serializer, which must write
        // numbers as plain text (e.g. a Jackson JSON serializer; JDK serialization does not)
        Long result = redisTemplate.execute(LIMIT_SCRIPT, keys, annotation.count(), annotation.period());
        log.info("Access try result is {} for name = {} and key = {}", result, annotation.name(), key);
        if (result == null || result != 1L) {
            log.debug("limit key = {}, request rejected", key);
            throw new LimitException(annotation.msg());
        }
        // Exceptions thrown by the business method now propagate unchanged
        return pjp.proceed();
    }

    public String getIpAddress() {
        HttpServletRequest request = ((ServletRequestAttributes)
                Objects.requireNonNull(RequestContextHolder.getRequestAttributes())).getRequest();
        String ip = request.getHeader("x-forwarded-for");
        if (isUnknown(ip)) {
            ip = request.getHeader("Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getHeader("WL-Proxy-Client-IP");
        }
        if (isUnknown(ip)) {
            ip = request.getRemoteAddr();
        }
        // X-Forwarded-For may hold "client, proxy1, proxy2"; the first entry is the client.
        // These headers are client-controlled unless a trusted proxy overwrites them.
        int comma = ip.indexOf(',');
        return comma > 0 ? ip.substring(0, comma).trim() : ip;
    }

    private static boolean isUnknown(String ip) {
        return ip == null || ip.isEmpty() || "unknown".equalsIgnoreCase(ip);
    }

}
  • Annotation use

count is the maximum number of requests

period is the window length in seconds

That is, within any sliding window of period seconds, at most count requests are allowed, and requests beyond that are rejected
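The following in-memory port shows what the Lua script decides, using the same sliding-window log in plain Java. It is for illustration only, and `SlidingLogLimiter` is a hypothetical name. The real check must run inside Redis so that all instances share one list. Each step is annotated with the Redis command it mirrors. Timestamps are in microseconds, like the `t[1] * 1000000 + t[2]` value the script computes.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// In-memory port of the sliding-window log used by the Lua script (illustration
// only; SlidingLogLimiter is hypothetical). The deque plays the Redis list:
// newest timestamp at the front, oldest at the back.
class SlidingLogLimiter {
    private final int count;
    private final long periodMicros;
    private final Deque<Long> log = new ArrayDeque<>();

    SlidingLogLimiter(int count, int periodSeconds) {
        if (count < 1) {
            throw new IllegalArgumentException("count must be at least 1");
        }
        this.count = count;
        this.periodMicros = periodSeconds * 1_000_000L;
    }

    synchronized boolean tryAcquire(long nowMicros) {
        if (log.size() < count) {            // LLEN < ARGV[1]
            log.addFirst(nowMicros);         // LPUSH
            return true;
        }
        long oldest = log.peekLast();        // LINDEX key -1
        if (nowMicros - oldest < periodMicros) {
            return false;                    // count requests already inside the period
        }
        log.addFirst(nowMicros);             // LPUSH
        while (log.size() > count) {         // LTRIM key 0 count-1
            log.removeLast();
        }
        return true;
    }
}
```

cachingTest uses count = 2 and period = 2. With those values, requests at 0 s and 0.1 s pass and a third at 0.5 s is rejected. The next request passes only once the oldest recorded request is at least 2 s old.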

package com.example.demo.module.test;

import com.example.demo.common.annotation.Limit;
import com.example.demo.common.annotation.RedisLimit;
import com.example.demo.common.dto.R;
import lombok.extern.slf4j.Slf4j;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.ArrayList;
import java.util.List;

@Slf4j
@RestController
public class TestController {

    @RedisLimit(key = "cachingTest", count = 2, period = 2, msg = "There are too many people in the queue, please try again later!")
    // To try the stand-alone Guava limiter instead, replace the line above with:
    // @Limit(key = "cachingTest", permitsPerSecond = 1, timeout = 500, msg = "There are too many people in the queue, please try again later!")
    @GetMapping("cachingTest")
    public R cachingTest(){
        log.info("------read local------");
        List<String> list = new ArrayList<>();
        list.add("Crayon Shin-Chan");
        list.add("Doraemon");
        list.add("Four-wheel drive brothers");

        return R.ok(list);
    }

}
  • test

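Start the project and refresh /cachingTest quickly several times in a row. With count = 2 and period = 2, the third request within any two-second window is rejected with the configured message.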