Spring WebFlux: intercepting requests to obtain the parameters and HTTP method for API signature and anti-replay verification

When implementing API signature and anti-replay verification in Spring WebFlux, you usually need the request parameters, the request method, and so on. In Spring WebFlux these are not as easy to obtain as in Spring MVC, so the approach below is written up from previous practice:

General idea:
1. Use a filter to read the information from the original request, cache it in a context object, then construct a new request and pass it to the subsequent filters. This is necessary because the original request body is a stream: once it has been consumed, it cannot be read again.
2. Pass the context object through the attributes of the exchange so that it can be used in different filters.

1. Context object

@Getter
@Setter
@ToString
public class GatewayContext {<!-- -->

    public static final String CACHE_GATEWAY_CONTEXT = "cacheGatewayContext";

    /**
     * cache requestMethod
     */
    private String requestMethod;

    /**
     * cache queryParams
     */
    private MultiValueMap<String, String> queryParams;

    /**
     * cache json body
     */
    private String requestBody;
    /**
     * cache Response Body
     */
    private Object responseBody;
    /**
     * request headers
     */
    private HttpHeaders requestHeaders;
    /**
     * cache form data
     */
    private MultiValueMap<String, String> formData;
    /**
     * cache all request data include:form data and query param
     */
    private MultiValueMap<String, String> allRequestData = new LinkedMultiValueMap<>(0);

    private byte[] requestBodyBytes;

}

2. Get the request parameters and request method in the filter.
Here we only intercept body parameters for application/json and application/x-www-form-urlencoded requests. For other requests, the parameters can be obtained directly from the URL query string.

@Slf4j
@Component
public class GatewayContextFilter implements WebFilter, Ordered {<!-- -->

    /**
     * default HttpMessageReader
     */
    private static final List<HttpMessageReader<?>> MESSAGE_READERS = HandlerStrategies.withDefaults().messageReaders();


    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {<!-- -->
        ServerHttpRequest request = exchange.getRequest();
        GatewayContext gatewayContext = new GatewayContext();
        HttpHeaders headers = request.getHeaders();
        gatewayContext.setRequestHeaders(headers);
        gatewayContext.getAllRequestData().addAll(request.getQueryParams());
        gatewayContext.setRequestMethod(request.getMethodValue().toUpperCase());
        gatewayContext.setQueryParams(request.getQueryParams());
        /*
         * save gateway context into exchange
         */
        exchange.getAttributes().put(GatewayContext.CACHE_GATEWAY_CONTEXT, gatewayContext);
        MediaType contentType = headers.getContentType();
        if (headers.getContentLength() > 0) {<!-- -->
            if (MediaType.APPLICATION_JSON.equals(contentType)) {<!-- -->
                return readBody(exchange, chain, gatewayContext);

            }
            if (MediaType.APPLICATION_FORM_URLENCODED.equalsTypeAndSubtype(contentType)) {<!-- -->
                return readFormData(exchange, chain, gatewayContext);
            }
        }

        String path = request.getPath().value();
        if (!"/".equals(path)) {<!-- -->
            log.info("{} Gateway context is set with {}-{}", path, contentType, gatewayContext);
        }
        return chain.filter(exchange);
    }


    @Override
    public int getOrder() {<!-- -->
        return Integer.MIN_VALUE + 1;
    }


    /**
     * ReadFormData
     */
    private Mono<Void> readFormData(ServerWebExchange exchange, WebFilterChain chain, GatewayContext gatewayContext) {<!-- -->
        HttpHeaders headers = exchange.getRequest().getHeaders();
        return exchange.getFormData()
                .doOnNext(multiValueMap -> {<!-- -->
                    gatewayContext.setFormData(multiValueMap);
                    gatewayContext.getAllRequestData().addAll(multiValueMap);
                    log.debug("[GatewayContext]Read FormData Success");
                })
                .then(Mono.defer(() -> {<!-- -->
                    Charset charset = headers.getContentType().getCharset();
                    charset = charset == null ? StandardCharsets.UTF_8 : charset;
                    String charsetName = charset.name();
                    MultiValueMap<String, String> formData = gatewayContext.getFormData();
                    /*
                     * formData is empty just return
                     */
                    if (null == formData || formData.isEmpty()) {<!-- -->
                        return chain.filter(exchange);
                    }
                    log.info("1. Gateway Context formData: {}", formData);
                    StringBuilder formDataBodyBuilder = new StringBuilder();
                    String entryKey;
                    List<String> entryValue;
                    try {<!-- -->
                        /*
                         * repackage form data
                         */
                        for (Map.Entry<String, List<String>> entry : formData.entrySet()) {<!-- -->
                            entryKey = entry.getKey();
                            entryValue = entry.getValue();
                            if (entryValue.size() > 1) {<!-- -->
                                for (String value : entryValue) {<!-- -->
                                    formDataBodyBuilder
                                            .append(URLEncoder.encode(entryKey, charsetName).replace(" + ", " ").replace("*", "*").replace("~", \ "~"))
                                            .append("=")
                                            .append(URLEncoder.encode(value, charsetName).replace(" + ", " ").replace("*", "*").replace("~", \ "~"))
                                            .append(" & amp;");
                                }
                            } else {<!-- -->
                                formDataBodyBuilder
                                        .append(URLEncoder.encode(entryKey, charsetName).replace(" + ", " ").replace("*", "*").replace("~", \ "~"))
                                        .append("=")
                                        .append(URLEncoder.encode(entryValue.get(0), charsetName).replace(" + ", " ").replace("*", "*").replace(" ~", "~"))
                                        .append(" & amp;");
                            }
                        }
                    } catch (UnsupportedEncodingException e) {<!-- -->
                        log.error("GatewayContext readFormData error {}", e.getMessage(), e);
                    }
                    /*
                     * 1. substring with the last char ' & amp;'
                     * 2. if the current request is encrypted, substring with the start chat 'secFormData'
                     */
                    String formDataBodyString = "";
                    String originalFormDataBodyString = "";
                    if (formDataBodyBuilder.length() > 0) {<!-- -->
                        formDataBodyString = formDataBodyBuilder.substring(0, formDataBodyBuilder.length() - 1);
                        originalFormDataBodyString = formDataBodyString;
                    }
                    /*
                     * get data bytes
                     */
                    byte[] bodyBytes = formDataBodyString.getBytes(charset);
                    int contentLength = bodyBytes.length;
                    gatewayContext.setRequestBodyBytes(originalFormDataBodyString.getBytes(charset));
                    HttpHeaders httpHeaders = new HttpHeaders();
                    httpHeaders.putAll(exchange.getRequest().getHeaders());
                    httpHeaders.remove(HttpHeaders.CONTENT_LENGTH);
                    /*
                     * in case of content-length not matched
                     */
                    httpHeaders.setContentLength(contentLength);
                    /*
                     * use BodyInserter to InsertFormData Body
                     */
                    BodyInserter<String, ReactiveHttpOutputMessage> bodyInserter = BodyInserters.fromObject(formDataBodyString);
                    CachedBodyOutputMessage cachedBodyOutputMessage = new CachedBodyOutputMessage(exchange, httpHeaders);
                    log.info("2. GatewayContext Rewrite Form Data :{}", formDataBodyString);
                    return bodyInserter.insert(cachedBodyOutputMessage, new BodyInserterContext())
                            .then(Mono.defer(() -> {<!-- -->
                                ServerHttpRequestDecorator decorator = new ServerHttpRequestDecorator(
                                        exchange.getRequest()) {<!-- -->
                                    @Override
                                    public HttpHeaders getHeaders() {<!-- -->
                                        return httpHeaders;
                                    }

                                    @Override
                                    public Flux<DataBuffer> getBody() {<!-- -->
                                        return cachedBodyOutputMessage.getBody();
                                    }
                                };
                                return chain.filter(exchange.mutate().request(decorator).build());
                            }));
                }));
    }


    /**
     * ReadJsonBody
     */
    private Mono<Void> readBody(ServerWebExchange exchange, WebFilterChain chain, GatewayContext gatewayContext) {<!-- -->
        return DataBufferUtils.join(exchange.getRequest().getBody())
                .flatMap(dataBuffer -> {<!-- -->
                    /*
                     * read the body Flux<DataBuffer>, and release the buffer
                     * when SpringCloudGateway Version Release To G.SR2, this can be updated with the new version's feature
                     * see PR https://github.com/spring-cloud/spring-cloud-gateway/pull/1095
                     */
                    byte[] bytes = new byte[dataBuffer.readableByteCount()];
                    dataBuffer.read(bytes);
                    DataBufferUtils.release(dataBuffer);
                    gatewayContext.setRequestBodyBytes(bytes);
                    Flux<DataBuffer> cachedFlux = Flux.defer(() -> {<!-- -->
                        DataBuffer buffer = exchange.getResponse().bufferFactory().wrap(bytes);
                        DataBufferUtils.retain(buffer);
                        return Mono.just(buffer);
                    });
                    /*
                     * repackage ServerHttpRequest
                     */
                    ServerHttpRequest mutatedRequest = new ServerHttpRequestDecorator(exchange.getRequest()) {<!-- -->
                        @Override
                        public Flux<DataBuffer> getBody() {<!-- -->
                            return cachedFlux;
                        }
                    };
                    ServerWebExchange mutatedExchange = exchange.mutate().request(mutatedRequest).build();
                    return ServerRequest.create(mutatedExchange, MESSAGE_READERS)
                            .bodyToMono(String.class)
                            .doOnNext(objectValue -> {<!-- -->
                                gatewayContext.setRequestBody(objectValue);
                                if (objectValue != null & amp; & amp; !objectValue.trim().startsWith("{")) {<!-- -->
                                    return;
                                }
                                try {<!-- -->
                                    gatewayContext.getAllRequestData().setAll(JsonUtil.fromJson(objectValue, Map.class));
                                } catch (Exception e) {<!-- -->
                                    log.warn("Gateway context Read JsonBody error:{}", e.getMessage(), e);
                                }
                            }).then(chain.filter(mutatedExchange));
                });
    }

}

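3. Signature and anti-replay verification
Here the parameters are read back from the context object stored in the exchange attributes.
Signature algorithm: the server concatenates the request method, the client id, the header string (`client-id`, `nonce` and `timestamp` as `key=value` pairs joined by `&`), the URI including its query string, and the MD5 of the request body, each followed by a newline, and computes an HMAC-SHA256 of the result with the shared secret. The request is rejected if this value differs from the `sign` header, if the timestamp is more than five minutes away from the server time, or if the nonce has already been used.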

@Slf4j
@Component
public class GatewaySignCheckFilter implements WebFilter, Ordered {<!-- -->


    @Value("${api.rest.prefix}")
    private String apiPrefix;

    @Autowired
    private RedisUtil redisUtil;

    //The front-end and back-end agree on the signature key
    private static final String API_SECRET = "secret-xxx";

    @Override
    public int getOrder() {<!-- -->
        return Integer.MIN_VALUE + 2;
    }

    @NotNull
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, @NotNull WebFilterChain chain) {<!-- -->
        ServerHttpRequest request = exchange.getRequest();
        String uri = request.getURI().getPath();
        GatewayContext gatewayContext = (GatewayContext) exchange.getAttributes().get(GatewayContext.CACHE_GATEWAY_CONTEXT);
        HttpHeaders headers = gatewayContext.getRequestHeaders();
        MediaType contentType = headers.getContentType();
        log.info("check url:{},method:{},contentType:{}", uri, gatewayContext.getRequestMethod(), contentType == null ? "" : contentType.toString());
        //If contentType is empty, it can only be a get request
        if (contentType == null || StringUtils.isBlank(contentType.toString())) {<!-- -->
            if (request.getMethod() != HttpMethod.GET) {<!-- -->
                throw new RuntimeException("illegal access");
            }
            checkSign(uri, gatewayContext, exchange);
        } else {<!-- -->
            if (MediaType.APPLICATION_JSON.equals(contentType) || MediaType.APPLICATION_FORM_URLENCODED.equalsTypeAndSubtype(contentType)) {<!-- -->
                checkSign(uri, gatewayContext, exchange);
            }
        }

        return chain.filter(exchange);
    }


    private void checkSign(String uri, GatewayContext gatewayContext, ServerWebExchange exchange) {<!-- -->
        //Ignore the request
        List<String> ignores = Lists.newArrayList("/open/**", "/open/login/params", "/open/image");
        for (String ignore : ignores) {<!-- -->
            ignore = apiPrefix + ignore;
            if (uri.equals(ignore) || uri.startsWith(ignore.replace("/**", "/"))) {<!-- -->
                log.info("check sign ignore:{}", uri);
                return;
            }
        }
        String method = gatewayContext.getRequestMethod();
        log.info("start check sign {}-{}", method, uri);
        HttpHeaders headers = gatewayContext.getRequestHeaders();
        log.info("headers:{}", JsonUtils.objectToJson(headers));
        String clientId = getHeaderAttr(headers, SystemSign.CLIENT_ID);
        String timestamp = getHeaderAttr(headers, SystemSign.TIMESTAMP);
        String nonce = getHeaderAttr(headers, SystemSign.NONCE);
        String sign = getHeaderAttr(headers, SystemSign.SIGN);
        checkTime(timestamp);
        checkOnce(nonce);
        String headerStr = String.format("%s=%s & amp;%s=%s & amp;%s=%s", SystemSign.CLIENT_ID, clientId,
                SystemSign.NONCE, nonce, SystemSign.TIMESTAMP, timestamp);
        String signSecret = API_SECRET;
        String queryUri = uri + getQueryParam(gatewayContext.getQueryParams());
        log.info("headerStr:{},signSecret:{},queryUri:{}", headerStr, signSecret, queryUri);
        String realSign = calculatorSign(clientId, queryUri, gatewayContext, headerStr, signSecret);
        log.info("sign:{}, realSign:{}", sign, realSign);
        if (!realSign.equals(sign)) {<!-- -->
            log.warn("wrong sign");
            throw new RuntimeException("Illegal sign");
        }
    }

    private String getQueryParam(MultiValueMap<String, String> queryParams) {<!-- -->
        if (queryParams == null || queryParams.size() == 0) {<!-- -->
            return StringUtils.EMPTY;
        }
        StringBuilder builder = new StringBuilder("?");
        for (Map.Entry<String, List<String>> entry : queryParams.entrySet()) {<!-- -->
            String key = entry.getKey();
            List<String> value = entry.getValue();
            builder.append(key).append("=").append(value.get(0)).append(" & amp;");
        }
        builder.deleteCharAt(builder.length() - 1);
        return builder.toString();
    }

    private String getHeaderAttr(HttpHeaders headers, String key) {<!-- -->
        List<String> values = headers.get(key);
        if (CollectionUtils.isEmpty(values)) {<!-- -->
            log.warn("GatewaySignCheckFilter empty header:{}", key);
            throw new RuntimeException("GatewaySignCheckFilter empty header:" + key);
        }
        String value = values.get(0);
        if (StringUtils.isBlank(value)) {<!-- -->
            log.warn("GatewaySignCheckFilter empty header:{}", key);
            throw new RuntimeException("GatewaySignCheckFilter empty header:" + key);
        }
        return value;
    }


    private String calculatorSign(String clientId, String queryUri, GatewayContext gatewayContext, String headerStr, String signSecret) {<!-- -->
        String method = gatewayContext.getRequestMethod();
        byte[] bodyBytes = gatewayContext.getRequestBodyBytes();
        if (bodyBytes == null) {<!-- -->
            //The blank md5 is fixed to: d41d8cd98f00b204e9800998ecf8427e
            bodyBytes = new byte[]{<!-- -->};
        }
        String bodyMd5 = SignUtils.getMd5(bodyBytes);
        String ori = String.format("%s\\
%s\\
%s\\
%s\\
%s\\
", method, clientId, headerStr, queryUri, bodyMd5);
        log.info("clientId:{},signSecret:{},headerStr:{},bodyMd5:{},queryUri:{},ori:{}", clientId, signSecret, headerStr, bodyMd5, queryUri, ori) ;
        return SignUtils.sha256HMAC(ori, signSecret);
    }

    private void checkOnce(String nonce) {<!-- -->
        if (StringUtils.isBlank(nonce)) {<!-- -->
            log.warn("GatewaySignCheckFilter checkOnce Illegal");
        }
        String key = "api:auth:" + nonce;
        int fifteenMin = 60 * 15 * 1000;
        Boolean succ = redisUtil.setNxWithExpire(key, "1", fifteenMin);
        if (succ == null || !succ) {<!-- -->
            log.warn("GatewaySignCheckFilter checkOnce Repeat");
            throw new RuntimeException("checkOnce Repeat");
        }
    }


    private void checkTime(String timestamp) {<!-- -->
        long time;
        try {<!-- -->
            time = Long.parseLong(timestamp);
        } catch (Exception ex) {<!-- -->
            log.error("GatewaySignCheckFilter checkTime error:{}", ex.getMessage(), ex);
            throw new RuntimeException("checkTime error");
        }
        long now = DateTimeUtil.now();
        log.info("now: {}, time: {}", DateTimeUtil.millsToStr(now), DateTimeUtil.millsToStr(time));
        int fiveMinutes = 60 * 5 * 1000;
        long duration = now - time;
        if (duration > fiveMinutes || (-duration) > fiveMinutes) {<!-- -->
            log.warn("GatewaySignCheckFilter checkTime Late");
            throw new RuntimeException("checkTime Late");
        }
    }

    public interface SystemSign {<!-- -->
        /**
         * Client ID: fixed value, the backend issues the agreement to the frontend
         */
        String CLIENT_ID = "client-id";

        /**
         * Signature calculated by client
         */
        String SIGN = "sign";

        /**
         * timestamp
         */
        String TIMESTAMP = "timestamp";

        /**
         * unique value
         */
        String NONCE = "nonce";
    }

}

4. Signature tools

@Slf4j
public class SignUtils {<!-- -->
    /**
     * sha256_HMAC encryption
     *
     * @param message message
     * @param secret secret key
     * @return encrypted string
     */
    public static String sha256HMAC(String message, String secret) {<!-- -->
        try {<!-- -->
            Mac sha256_HMAC = Mac.getInstance("HmacSHA256");
            SecretKeySpec secret_key = new SecretKeySpec(secret.getBytes(StandardCharsets.UTF_8), "HmacSHA256");
            sha256_HMAC.init(secret_key);
            byte[] bytes = sha256_HMAC.doFinal(message.getBytes(StandardCharsets.UTF_8));
            return byteArrayToHexString(bytes);
        } catch (Exception e) {<!-- -->
            log.error("error in sha256HMAC:" + e.getMessage(), e);
            return StringUtils.EMPTY;
        }
    }

    /**
     * Convert the encrypted byte array into a string
     *
     * @param b byte array
     * @return string
     */
    public static String byteArrayToHexString(byte[] b) {<!-- -->
        StringBuilder hs = new StringBuilder();
        String stmp;
        for (int n = 0; b != null & amp; & amp; n < b.length; n + + ) {<!-- -->
            stmp = Integer.toHexString(b[n] & amp; 0XFF);
            if (stmp.length() == 1) {<!-- -->
                hs.append('0');
            }
            hs.append(stmp);
        }
        return hs.toString().toLowerCase();
    }

    public static String getMd5(String requestBody) {<!-- -->
        requestBody = Optional.ofNullable(requestBody).orElse(StringUtils.EMPTY);
        return DigestUtils.md5DigestAsHex(requestBody.getBytes(StandardCharsets.UTF_8));
    }

    public static String getMd5(byte[] requestBody) {<!-- -->
        return DigestUtils.md5DigestAsHex(requestBody).toLowerCase();
    }

}