[SpringBoot] Rate limiting an SMS interface with Spring AOP and Redis

Article directory

  • Foreword
  • 1. Preparation
  • 2. Rate limiting annotation
  • 3. Customize or select a RedisTemplate
    • 1. Customize RedisTemplate (depending on the need, I use the second option)
    • 2. Use StringRedisTemplate directly
  • 4. Write the Lua script
  • 5. Annotation analysis
  • 6. Interface test

Foreword

Scenario: to limit the number of requests to the SMS verification code interface and prevent it from being abused, AOP and Redis are combined to rate limit requests by user IP.

1. Preparation

First, create a Spring Boot project and add the Web and Redis dependencies. Because rate limiting on an interface is usually declared with an annotation, and the annotation is processed through AOP, we also need the AOP dependency. The final dependencies are as follows:

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-aop</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>

Next, have a Redis instance ready. Once the project is created, configure the basic Redis connection information in application.properties, for example (on Spring Boot 3.x these properties use the spring.data.redis prefix instead):

spring.redis.host=localhost
spring.redis.port=6379
spring.redis.password=123

2. Rate limiting annotation

Next we create a rate limiting annotation. We distinguish two kinds of rate limiting:
1: A global limit on the interface. For example, the interface can be accessed 100 times in 1 minute in total.
2: A limit per IP address. For example, each IP address can access the interface 100 times in 1 minute.

For these two cases, we create an enumeration class:

public enum LimitType {
    /**
     * Default policy restricts traffic globally
     */
    DEFAULT,
    /**
     * Limit traffic based on requester IP
     */
    IP
}

Next we create the rate limiting annotation:

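import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RateLimiter {
    /**
     * Rate limiting key (used as a prefix)
     */
    String key() default "rate_limit:";

    /**
     * Length of the rate limiting window, in seconds
     */
    int time() default 60;

    /**
     * Maximum number of requests allowed within the window
     */
    int count() default 100;

    /**
     * Rate limiting type
     */
    LimitType limitType() default LimitType.DEFAULT;
}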

The first parameter is the rate limiting key. It is only a prefix: the complete key stored in Redis is this prefix, then the client IP (for LimitType.IP), then the fully qualified class name and the method name of the interface method.

Of the other three parameters, time is the length of the window in seconds, count is the maximum number of requests allowed within that window, and limitType selects global or per-IP limiting.

From now on, any interface that needs rate limiting only has to carry the @RateLimiter annotation with the relevant parameters.
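To make it concrete how the annotation's parameters are picked up at runtime, here is a minimal sketch in plain Java that reads them through reflection, which is roughly what the aspect receives later on. The annotation and enum are repeated as nested types only so the sketch compiles on its own; SmsController and describe are hypothetical names used for illustration.

```java
import java.lang.annotation.Documented;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

final class RateLimiterAnnotationDemo {

    enum LimitType { DEFAULT, IP }

    // Same shape as the article's annotation, nested here to keep the sketch self-contained.
    @Target(ElementType.METHOD)
    @Retention(RetentionPolicy.RUNTIME)
    @Documented
    @interface RateLimiter {
        String key() default "rate_limit:";
        int time() default 60;
        int count() default 100;
        LimitType limitType() default LimitType.DEFAULT;
    }

    // Hypothetical controller, used only for this illustration.
    static class SmsController {
        @RateLimiter(time = 60, count = 1, limitType = LimitType.IP)
        public String sendMessage(String phone) {
            return "ok";
        }
    }

    /** Describes the limit declared on a public method, or "unlimited" if it is not annotated. */
    static String describe(Class<?> type, String methodName, Class<?>... paramTypes) {
        try {
            Method method = type.getMethod(methodName, paramTypes);
            RateLimiter limiter = method.getAnnotation(RateLimiter.class);
            if (limiter == null) {
                return "unlimited";
            }
            return limiter.key() + " " + limiter.count() + "/" + limiter.time() + "s " + limiter.limitType();
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException(e);
        }
    }

    public static void main(String[] args) {
        // Prints: rate_limit: 1/60s IP
        System.out.println(describe(SmsController.class, "sendMessage", String.class));
    }
}
```

RetentionPolicy.RUNTIME is what makes this work: without it the annotation would not be visible to reflection or to AOP.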

3. Customize or select a RedisTemplate

1. Customize RedisTemplate (depending on the need, I use the second option)

In Spring Boot we usually operate Redis through Spring Data Redis, but the default RedisTemplate has a small pitfall: it serializes with JdkSerializationRedisSerializer. You may have noticed that with this serializer the keys and values stored in Redis carry some extra leading bytes, which causes problems when you read them from the command line.

For example, you store the key name with the value javaboy, but on the command line get name does not return the data. The reason is that the key saved in Redis has some extra bytes in front of name, so it can only be read back through RedisTemplate.
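The extra bytes can be seen without Redis at all. The sketch below serializes the string name with standard Java serialization, which is the mechanism a JDK-based serializer relies on, and prints the result next to the plain UTF-8 bytes. The class and method names are illustrative only.

```java
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;

final class JdkSerializationPrefixDemo {

    /** Serializes a value with standard Java serialization. */
    static byte[] jdkSerialize(Object value) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(buffer)) {
            out.writeObject(value);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return buffer.toByteArray();
    }

    /** Renders bytes as lowercase hex. */
    static String hex(byte[] bytes) {
        StringBuilder sb = new StringBuilder();
        for (byte b : bytes) {
            sb.append(String.format("%02x", b));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        // Java serialization: stream header + type tag + length, then the text
        System.out.println(hex(jdkSerialize("name")));                    // aced00057400046e616d65
        // Plain string bytes, which is what "get name" on the command line looks up
        System.out.println(hex("name".getBytes(StandardCharsets.UTF_8))); // 6e616d65
    }
}
```

The seven leading bytes (aced0005 74 0004) are why the key typed on the command line never matches the key that was actually stored.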

Rate limiting with Redis relies on a Lua script, and the script runs into the same problem, so we need to change the serialization scheme of RedisTemplate.

The configuration below switches RedisTemplate to Jackson2JsonRedisSerializer. Remember that this serializer needs the Jackson dependency, which spring-boot-starter-web already brings in. A reference example:

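import com.fasterxml.jackson.annotation.JsonAutoDetect;
import com.fasterxml.jackson.annotation.PropertyAccessor;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.jsontype.impl.LaissezFaireSubTypeValidator;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;

@Configuration
public class RedisConfig {

    @Bean
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(connectionFactory);
        // Use Jackson2JsonRedisSerializer to replace the default serialization (the default is JDK serialization)
        Jackson2JsonRedisSerializer<Object> jackson2JsonRedisSerializer = new Jackson2JsonRedisSerializer<>(Object.class);
        ObjectMapper om = new ObjectMapper();
        om.setVisibility(PropertyAccessor.ALL, JsonAutoDetect.Visibility.ANY);
        // Replaces enableDefaultTyping(...), which is deprecated since Jackson 2.10
        om.activateDefaultTyping(LaissezFaireSubTypeValidator.instance, ObjectMapper.DefaultTyping.NON_FINAL);
        jackson2JsonRedisSerializer.setObjectMapper(om);
        redisTemplate.setKeySerializer(jackson2JsonRedisSerializer);
        redisTemplate.setValueSerializer(jackson2JsonRedisSerializer);
        redisTemplate.setHashKeySerializer(jackson2JsonRedisSerializer);
        redisTemplate.setHashValueSerializer(jackson2JsonRedisSerializer);
        return redisTemplate;
    }
}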

2. Use StringRedisTemplate directly

StringRedisTemplate is a template that Spring Data Redis provides for operating Redis. It extends RedisTemplate and uses the string serializer by default, so both keys and values are stored as plain strings. Customizing a RedisTemplate is nothing more than choosing which serializer to store data with; the first option above is a JSON serializer. For the rate limiting in this article the two options make no difference.

  • I chose this option because it needs no configuration.
  • Note that keys and values must be converted to String first (for example with String.valueOf()), otherwise an error is reported.

4. Write the Lua script

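  1. A Lua script is used so that several Redis operations run atomically: Redis executes the whole script as a single unit, and no other command can run in between. Note that this is not a transaction with rollback. If a command inside the script fails, the error is returned to Spring Boot as an exception, but writes the script has already made are not undone. If you are not familiar with Lua scripts, the Lua section of the Black Horse (Heima) Redis course on Bilibili is a good way to catch up.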

  2. The script works roughly as follows:

  • First, read the incoming key, the limit count, and the window time.
  • Read the current value of the key with get. This value is the number of times the interface has been accessed within the current time window.
  • On the first visit the result is nil; otherwise it is a number. If it is a number and already greater than count, the limit has been exceeded, so the script returns that value directly.
  • Otherwise the key is incremented by 1. If the result of the increment is 1, this was the first access in the window, so an expiration time is set on the key.
  • Finally, the incremented value is returned.
-- Redis rate limiting script
-- Key of the counter
local key = KEYS[1]
-- Maximum number of requests allowed in the window
local limitCount = tonumber(ARGV[1])
-- Length of the window in seconds
local limitTime = tonumber(ARGV[2])
-- Read the current count (nil if the key does not exist yet)
local currentCount = redis.call('get', key)
-- Already over the limit: return the current count without incrementing
if currentCount and tonumber(currentCount) > limitCount then
    return tonumber(currentCount)
end
-- Increment the counter by 1
currentCount = redis.call("incr", key)
-- First request in this window: start the expiration timer
if tonumber(currentCount) == 1 then
    redis.call("expire", key, limitTime)
end
-- Return the new count
return tonumber(currentCount)
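
To check the script's logic without a Redis instance, the same steps can be mirrored in plain Java. This is only a sketch of the fixed-window idea: an in-memory map stands in for Redis, the current time is passed in explicitly so the behaviour can be tested, and the class and method names are made up for illustration.

```java
import java.util.HashMap;
import java.util.Map;

final class FixedWindowCounterSketch {

    /** One counter plus the moment it expires, standing in for a Redis key with a TTL. */
    private static final class Window {
        long count;
        long expiresAtMillis;
    }

    private final Map<String, Window> store = new HashMap<>();

    /**
     * Mirrors the Lua script and returns the counter value after this call.
     * A result greater than limitCount means the request should be rejected.
     */
    synchronized long hit(String key, int limitCount, int limitSeconds, long nowMillis) {
        Window window = store.get(key);
        // Emulate key expiry: an expired window behaves like a missing key.
        if (window != null && nowMillis >= window.expiresAtMillis) {
            store.remove(key);
            window = null;
        }
        // Already over the limit: return the stored value without incrementing.
        if (window != null && window.count > limitCount) {
            return window.count;
        }
        if (window == null) {
            // First request in this window: create the counter and start the timer.
            window = new Window();
            window.expiresAtMillis = nowMillis + limitSeconds * 1000L;
            store.put(key, window);
        }
        window.count++;
        return window.count;
    }

    public static void main(String[] args) {
        FixedWindowCounterSketch limiter = new FixedWindowCounterSketch();
        System.out.println(limiter.hit("sms", 1, 60, 0L));      // 1: allowed
        System.out.println(limiter.hit("sms", 1, 60, 1_000L));  // 2: over the limit, rejected
        System.out.println(limiter.hit("sms", 1, 60, 2_000L));  // 2: still rejected, not incremented
        System.out.println(limiter.hit("sms", 1, 60, 60_000L)); // 1: window expired, allowed again
    }
}
```

One consequence of a fixed window is visible here: the window starts at the first request and resets completely when it expires, so a burst just before and just after the reset can briefly exceed count.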

5. Annotation analysis

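  • With spring-boot-starter-aop on the classpath, Spring Boot enables AOP proxying automatically; you can also add @EnableAspectJAutoProxy to the main class explicitly.
  • The core code is as follows. It injects limitScript, a RedisScript<Long> bean that must be registered separately, for example as a DefaultRedisScript<Long> that loads the Lua script from section 4 and has its result type set to Long.class.
  • The exceptions in the code are custom exceptions; throw them as they are or replace them with your own exception classes.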
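import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
import javax.annotation.Resource; // jakarta.annotation.Resource on Spring Boot 3.x
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
// BusinessException, ErrorCode, RateLimiter, LimitType and IpUtils are this project's own classes

@Component
@Aspect
@Slf4j
public class RateLimiterAspect {

    @Resource
    private StringRedisTemplate stringRedisTemplate;

    @Resource
    private RedisScript<Long> limitScript;

    @Before("@annotation(rateLimiter)")
    public void doBefore(JoinPoint point, RateLimiter rateLimiter) {
        int time = rateLimiter.time();
        int count = rateLimiter.count();

        String combineKey = getCombineKey(rateLimiter, point);
        List<String> keys = Collections.singletonList(combineKey);
        try {
            // StringRedisTemplate only accepts String arguments, so convert the numbers first
            Long number = stringRedisTemplate.execute(limitScript, keys, String.valueOf(count), String.valueOf(time));
            if (number == null || number.intValue() > count) {
                throw new BusinessException(ErrorCode.PARAMS_ERROR, "Access too frequent, please try again later");
            }
            log.info("Limit request '{}', current request '{}', cache key '{}'", count, number.intValue(), keys.get(0));
        } catch (BusinessException e) {
            // Rethrow the rate limit exception as is, so that it is not masked as a system error below
            throw e;
        } catch (Exception e) {
            throw new BusinessException(ErrorCode.SYSTEM_ERROR, "The system is busy, please try again later");
        }
    }

    /**
     * Build the Redis key: prefix + [client IP + "-"] + class name + "-" + method name
     * @param rateLimiter the annotation on the intercepted method
     * @param point the join point of the intercepted method
     * @return the complete rate limiting key
     */
    public String getCombineKey(RateLimiter rateLimiter, JoinPoint point) {
        StringBuilder key = new StringBuilder(rateLimiter.key());
        if (rateLimiter.limitType() == LimitType.IP) {
            key.append(
                    IpUtils.getIpAddr(((ServletRequestAttributes) RequestContextHolder.currentRequestAttributes())
                            .getRequest()))
                    .append("-");
        }
        MethodSignature signature = (MethodSignature) point.getSignature();
        Method method = signature.getMethod();
        Class<?> targetClass = method.getDeclaringClass();
        key.append(targetClass.getName()).append("-").append(method.getName());
        return key.toString();
    }

}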
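
To see what the resulting Redis key looks like, the key composition in getCombineKey can be reduced to a small standalone sketch. The method below is a simplified stand-in that takes the client IP as a plain argument instead of reading it from the request; its name and parameters are assumptions for illustration.

```java
final class RateLimitKeySketch {

    /** Builds prefix + [clientIp + "-"] + className + "-" + methodName. */
    static String combineKey(String prefix, boolean perIp, String clientIp,
                             Class<?> declaringClass, String methodName) {
        StringBuilder key = new StringBuilder(prefix);
        if (perIp) {
            key.append(clientIp).append("-");
        }
        key.append(declaringClass.getName()).append("-").append(methodName);
        return key.toString();
    }

    public static void main(String[] args) {
        // String.class and "length" are placeholders for a controller class and its method.
        System.out.println(combineKey("rate_limit:", true, "203.0.113.7", String.class, "length"));
        // rate_limit:203.0.113.7-java.lang.String-length
        System.out.println(combineKey("rate_limit:", false, null, String.class, "length"));
        // rate_limit:java.lang.String-length
    }
}
```

Because only the method name is part of the key, overloaded methods in the same class share one counter. If that matters, the parameter types could be appended to the key as well.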

Utility class for obtaining the user IP from the HttpServletRequest:

import java.net.InetAddress;
import java.net.UnknownHostException;
import javax.servlet.http.HttpServletRequest; // jakarta.servlet.http.HttpServletRequest on Spring Boot 3.x

public class IpUtils {
    public static String getIpAddr(HttpServletRequest request) {
        String ipAddress = null;
        try {
            // Try the headers that proxies commonly set, in order
            ipAddress = request.getHeader("x-forwarded-for");
            if (isUnknown(ipAddress)) {
                ipAddress = request.getHeader("Proxy-Client-IP");
            }
            if (isUnknown(ipAddress)) {
                ipAddress = request.getHeader("WL-Proxy-Client-IP");
            }
            if (isUnknown(ipAddress)) {
                ipAddress = request.getRemoteAddr();
                if ("127.0.0.1".equals(ipAddress)) {
                    // Local request: use the IP configured on this machine's network card
                    try {
                        ipAddress = InetAddress.getLocalHost().getHostAddress();
                    } catch (UnknownHostException e) {
                        e.printStackTrace();
                    }
                }
            }
            if (ipAddress == null) {
                return "";
            }
            // With multiple proxies the value holds several IPs separated by ','; the first one is the client's
            return ipAddress.contains(",") ? ipAddress.split(",")[0].trim() : ipAddress;
        } catch (Exception e) {
            e.printStackTrace();
            return "";
        }
    }

    /** True if the header value is missing, empty, or the literal "unknown". */
    private static boolean isUnknown(String ip) {
        return ip == null || ip.length() == 0 || "unknown".equalsIgnoreCase(ip);
    }
}
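
Since IpUtils depends on HttpServletRequest, it is awkward to test without a servlet container. A minimal sketch of the same header fallback order over a plain map makes the behaviour easy to check. It assumes the caller passes the headers under exactly the names used above (real HTTP header names are case-insensitive), and it leaves out the 127.0.0.1 lookup; the class and method names are illustrative.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class ClientIpResolverSketch {

    // Same lookup order as IpUtils.
    private static final List<String> PROXY_HEADERS =
            Arrays.asList("x-forwarded-for", "Proxy-Client-IP", "WL-Proxy-Client-IP");

    /** Returns the first usable proxy header value, else remoteAddr, else an empty string. */
    static String resolve(Map<String, String> headers, String remoteAddr) {
        for (String name : PROXY_HEADERS) {
            String value = headers.get(name);
            if (value != null && !value.trim().isEmpty() && !"unknown".equalsIgnoreCase(value.trim())) {
                // "client, proxy1, proxy2": the first entry is the original client.
                return value.split(",")[0].trim();
            }
        }
        return remoteAddr == null ? "" : remoteAddr;
    }

    public static void main(String[] args) {
        Map<String, String> headers = new HashMap<>();
        headers.put("x-forwarded-for", "203.0.113.7, 10.0.0.1");
        System.out.println(resolve(headers, "10.0.0.1")); // 203.0.113.7
    }
}
```

Keep in mind that these headers are supplied by the client or by proxies. Unless a trusted reverse proxy overwrites them, a caller can send a different x-forwarded-for value on every request and slip past an IP-based limit.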

6. Interface test

The example below limits each user IP to one call of the interface every 60 seconds.

@GetMapping("/message")
@RateLimiter(time = 60, count = 1, limitType = LimitType.IP)
public BaseResponse<String> sendMessage(String phone, HttpServletRequest request) {
    if (StringUtils.isBlank(phone)) {
        throw new BusinessException(ErrorCode.PARAMS_ERROR);
    }
    boolean result = userVenueReservationService.sendMessage(phone, request);
    return ResultUtils.success(result ? "Send successfully" : "Send failed");
}

Corrections are welcome if you spot any mistakes.