Using WebSocket with STOMP under Spring Security

Comparison with conventional use of WebSocket

When we use WebSocket in a Spring Boot project, a few simple annotations are enough to set it up. However, once Spring Security is integrated into the system, the traditional annotation-based approach cannot properly control the connection permissions of WebSocket.

Use with Stomp
Import dependencies
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-websocket</artifactId>
</dependency>

<dependency>
    <groupId>org.springframework.security</groupId>
    <artifactId>spring-security-messaging</artifactId>
</dependency>
Configuration class

The configuration class exposes the WebSocket endpoints to the outside world and controls which roles are allowed to use the connection.

@Configuration
@EnableWebSocketMessageBroker
public class WebSocketBrokerConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {
    private static final String[] PERMISSIONS_LIST = {"ROOT", "USER"};

    private static final String ENDPOINTS = "/stomp/webSocket";

    //Project-specific configuration holder used below; assumed here to be a Spring bean of the same Constant type that the interceptor receives
    @Resource
    private Constant constant;

    @Resource(name = "webSocketInboundThreadPool")
    private ThreadPoolTaskExecutor webSocketInboundThreadPool;

    @Resource
    private TokenService tokenService;

    @Resource
    private RedisTemplate<String, Object> redisTemplate;


    /**
     * Create WebSocket message error handler
     */
    @Bean
    public StompSubProtocolErrorHandler stompSubProtocolErrorHandler() {
        return new WebSocketErrorHandler();
    }

    /**
     * Create message inbound interceptor
     */
    @Bean
    public ChannelInterceptor inboundInterceptor() {
        return new WebSocketInboundInterceptor(constant, tokenService, redisTemplate);
    }

    @Override
    public void registerStompEndpoints(StompEndpointRegistry registry) {
        registry
                //Configure error handler
                .setErrorHandler(stompSubProtocolErrorHandler())
                //Configure the endpoint that WebSocket clients connect to
                .addEndpoint(constant.getSocketEndpoints())
                //Allow cross-origin requests
                .setAllowedOrigins("*")
                //Use SockJS
                .withSockJS();
    }

    @Override
    public void configureMessageBroker(MessageBrokerRegistry registry) {
        //Configure the destination prefixes that clients subscribe to
        registry.enableSimpleBroker(constant.getSocketClientSubscribePrefix(), constant.getSocketClientResponseSubscribePrefix(), constant.getSocketCommonSubPre());
        //Configure the prefix for messages that clients send to the application
        registry.setApplicationDestinationPrefixes(constant.getSocketClientSendPrefix());
        //Configure the prefix for point-to-point (per-user) destinations
        registry.setUserDestinationPrefix(constant.getSocketClientUserPrefix());
    }

    @Override
    protected void customizeClientInboundChannel(ChannelRegistration registration) {
        registration
                //Configure message inbound interceptor
                .interceptors(inboundInterceptor())
                //Configure the inbound channel thread pool
                .taskExecutor(webSocketInboundThreadPool);
    }

    /**
     * Disable the same-origin check (the CSRF token requirement on CONNECT) so that clients from other origins can connect
     */
    @Override
    protected boolean sameOriginDisabled() {
        return true;
    }

    @Override
    protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
        messages
                //Messages without a destination are allowed access
                .nullDestMatcher()
                .permitAll()
                //Anyone can connect; the interceptor decides whether the connection is allowed. If access were restricted here, the request would never reach the interceptor.
                .simpTypeMatchers(SimpMessageType.CONNECT)
                .permitAll()
                .simpDestMatchers(constant.getSocketEndpoints())
                .permitAll()
                //Anyone can disconnect
                .simpTypeMatchers(SimpMessageType.DISCONNECT)
                .permitAll()
                //Only users with the specified roles can subscribe
                .simpTypeMatchers(SimpMessageType.SUBSCRIBE)
                .hasAnyRole(PERMISSIONS_LIST)
                //Only users with the specified roles can send messages
                .simpTypeMatchers(SimpMessageType.MESSAGE)
                .hasAnyRole(PERMISSIONS_LIST);
    }
}
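
One detail of the rules above is easy to miss: hasAnyRole("ROOT", "USER") is matched against authorities named ROLE_ROOT and ROLE_USER, because Spring Security prepends the ROLE_ prefix by default. The Authentication that the interceptor stores on the session therefore needs to carry authorities with that prefix. The following is a simplified, stand-alone model of that comparison; RoleCheckSketch is a hypothetical name and this is not Spring Security's own code.

```java
import java.util.Arrays;
import java.util.Set;

/** Simplified model of the comparison behind hasAnyRole(...); illustrative only. */
final class RoleCheckSketch {

    // Spring Security's default role prefix.
    static final String ROLE_PREFIX = "ROLE_";

    /** Returns true if any role name, once prefixed, is among the granted authorities. */
    static boolean hasAnyRole(Set<String> grantedAuthorities, String... roles) {
        return Arrays.stream(roles)
                .map(role -> ROLE_PREFIX + role)
                .anyMatch(grantedAuthorities::contains);
    }

    public static void main(String[] args) {
        // Authority carries the prefix, so the rule matches.
        System.out.println(hasAnyRole(Set.of("ROLE_USER"), "ROOT", "USER")); // true
        // Authority lacks the prefix, so the rule does not match.
        System.out.println(hasAnyRole(Set.of("USER"), "ROOT", "USER"));      // false
    }
}
```

If the project stores authorities without the prefix, hasAnyAuthority is the variant that compares the names as given.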

Inbound message interceptor

The inbound message interceptor authenticates the user's WebSocket connection request. Here a token is used for verification, and Redis is used to rate-limit socket messages.

public class WebSocketInboundInterceptor implements ChannelInterceptor {
    private static final Logger log = LoggerFactory.getLogger(WebSocketInboundInterceptor.class);

    private static final long EXPIRE = 5;
    private static final int MAX_ACCESS = 100;
    private final Constant constant;
    private final TokenService tokenService;
    private final RedisTemplate<String, Object> redisTemplate;

    public WebSocketInboundInterceptor(Constant constant, TokenService tokenService, RedisTemplate<String, Object> redisTemplate) {
        this.constant = constant;
        this.tokenService = tokenService;
        this.redisTemplate = redisTemplate;
    }


    @Override
    public Message<?> preSend(Message<?> message, MessageChannel channel) {
        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
        assert accessor != null;
        //Get the token header passed during connection
        String token = accessor.getFirstNativeHeader(constant.getAccessTokenKey());
        //Check the user's permissions when the client connects
        if (Objects.equals(accessor.getCommand(), StompCommand.CONNECT)) {
            if (StringUtils.isBlank(token)) {
                throw new CustomException(ExceptionEnum.TOKEN_IS_BLANK);
            }
            TokenAuthentication tokenAuthentication = createAuthentication(token);
            //Store the user's information in the WebSocket session so that it is available in the later connected and disconnect events
            accessor.setUser(tokenAuthentication);
        }
        //If it is a subscription request
        else if (Objects.equals(accessor.getCommand(), StompCommand.SUBSCRIBE)) {
            //Check whether the last segment of the subscribed address is the user's own ID. If not, do not allow the subscription
            String subscribePath = accessor.getDestination();
            assert subscribePath != null;
            subscribePath = subscribePath.substring(subscribePath.lastIndexOf("/") + 1);
            String userId = String.valueOf(SecurityUtil.getCurrentUserId());
            if (!Objects.equals(userId, subscribePath)) {
                throw new CustomException(ExceptionEnum.NOT_ALLOW_SUBSCRIBE);
            }
        }
        //If it is a request to send a message
        else if (Objects.equals(accessor.getCommand(), StompCommand.SEND)) {
            Principal user = accessor.getUser();
            assert user != null;
            String name = user.getName();
            allowUserAccess(name);
        }
        return message;
    }


    /**
     * Construct a security object from the token passed by the client. The concrete token verification is up to your own implementation
     *
     * @param token client token
     * @return verified security object
     */
    private TokenAuthentication createAuthentication(String token) {
        TokenAuthentication authentication = new TokenAuthentication(token, constant.getAccessTokenPrefix());
        //If token verification fails here, an exception is thrown. It is reported back to the front end and the connection is not established.
        return (TokenAuthentication) tokenService.authenticate(authentication);
    }

    /**
     * WebSocket rate limiting. If the maximum threshold configured by the system is exceeded, the message is not allowed to be sent.
     * The limit is no more than MAX_ACCESS messages within EXPIRE seconds (100 messages in 5 seconds with the constants above).
     */
    private void allowUserAccess(String username) {
        String key = constant.getSocketFlowLimitPrefix() + username;
        long l = execLua(key);
        if (l > MAX_ACCESS) {
            log.error("WebSocket message sending has reached the maximum flow limit, user: {} messages are prohibited from being sent", username);
            throw new CustomException(ExceptionEnum.SYSTEM_ERROR);
        }
    }

    /**
     * Use a Lua script to perform the Redis operations atomically.
     * If the key does not exist, set the value to 1 with an expiration time of EXPIRE seconds.
     * If the key exists, increment it. Doing both steps in one script avoids a race between concurrent threads
     * in which the expiration time is never set and the key therefore never expires.
     *
     * @return 1 if the key did not exist, otherwise the incremented value
     */
    private long execLua(String key) {
        RedisScript<Long> script = new DefaultRedisScript<>("if redis.call('exists', KEYS[1]) == 0 then\n" +
                " redis.call('set', KEYS[1], 1, 'ex', " + EXPIRE + ")\n" +
                " return 1\n" +
                "else\n" +
                " return redis.call('incr', KEYS[1])\n" +
                "end", Long.class);
        Long result = redisTemplate.execute(script, Collections.singletonList(key));

        return Objects.isNull(result) ? 0 : result;
    }
}
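
To make the counting rule of the Lua script easier to follow, here is a minimal in-memory sketch of the same fixed-window logic. It is only a model: FixedWindowLimiterSketch is a hypothetical class, a HashMap stands in for Redis, and the current time is passed in explicitly so that the behavior can be checked without waiting for a real expiry.

```java
import java.util.HashMap;
import java.util.Map;

/** In-memory stand-in for the Redis counter used by execLua; illustrative only. */
final class FixedWindowLimiterSketch {
    private final long windowMillis;
    private final int maxAccess;
    // key -> {count, expiresAtMillis}
    private final Map<String, long[]> counters = new HashMap<>();

    FixedWindowLimiterSketch(long windowMillis, int maxAccess) {
        this.windowMillis = windowMillis;
        this.maxAccess = maxAccess;
    }

    /** Mirrors the script: create the key with count 1 and an expiry, or increment an existing key. */
    synchronized long hit(String key, long nowMillis) {
        long[] entry = counters.get(key);
        if (entry == null || nowMillis >= entry[1]) {
            // Key is missing or expired: start a new window.
            counters.put(key, new long[] {1, nowMillis + windowMillis});
            return 1;
        }
        return ++entry[0];
    }

    /** Same comparison as allowUserAccess: the message is rejected once the count exceeds the maximum. */
    synchronized boolean tryAcquire(String key, long nowMillis) {
        return hit(key, nowMillis) <= maxAccess;
    }

    public static void main(String[] args) {
        FixedWindowLimiterSketch limiter = new FixedWindowLimiterSketch(5000, 3);
        for (int i = 0; i < 4; i++) {
            System.out.println(limiter.tryAcquire("alice", i * 100L)); // true, true, true, false
        }
        System.out.println(limiter.tryAcquire("alice", 5000L)); // true, a new window has started
    }
}
```

Note that the window starts at the first message and is not extended by later ones, which is also how the script behaves: incr does not touch the expiry set when the key was created.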
Error handler

The error handler deals with errors that occur while processing messages sent from the client to the server. Spring provides a default error handler, but the error information it returns did not meet my needs, so I modified the response by imitating the default handler.

public class WebSocketErrorHandler extends StompSubProtocolErrorHandler {

    private static final byte[] EMPTY_PAYLOAD = new byte[0];

    /**
     * Handle client-to-server message errors
     *
     * @param clientMessage the client message related to the error, possibly
     * {@code null} if error occurred while parsing a WebSocket message
     * @param ex the cause for the error, never {@code null}
     */
    @Override
    public Message<byte[]> handleClientMessageProcessingError(Message<byte[]> clientMessage, Throwable ex) {
        StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.ERROR);
        accessor.setMessage(ex.getMessage());
        accessor.setLeaveMutable(true);

        StompHeaderAccessor clientHeaderAccessor;
        if (clientMessage != null) {
            clientHeaderAccessor = MessageHeaderAccessor.getAccessor(clientMessage, StompHeaderAccessor.class);
            if (clientHeaderAccessor != null) {
                String receiptId = clientHeaderAccessor.getReceipt();
                if (receiptId != null) {
                    accessor.setReceiptId(receiptId);
                }
            }
        }

        return handleInternal(accessor, EMPTY_PAYLOAD, null, null);
    }

    /**
     * Handle errors that occur when the server sends messages to the client
     *
     * @param errorMessage the error message, never {@code null}
     */
    @Override
    public Message<byte[]> handleErrorMessageToClient(Message<byte[]> errorMessage) {
        StompHeaderAccessor accessor = MessageHeaderAccessor.getAccessor(errorMessage, StompHeaderAccessor.class);
        Assert.notNull(accessor, "No StompHeaderAccessor");
        if (!accessor.isMutable()) {
            accessor = StompHeaderAccessor.wrap(errorMessage);
        }
        return handleInternal(accessor, EMPTY_PAYLOAD, null, null);
    }

}
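
For orientation, the handler above produces a STOMP ERROR frame with a message header, an optional receipt-id header and an empty body. The sketch below builds the text form of such a frame by hand, following the STOMP 1.2 frame layout, to show roughly what the client receives. StompErrorFrameSketch is a hypothetical helper, not part of Spring; Spring does its own encoding.

```java
/** Hand-built text form of a STOMP 1.2 ERROR frame; illustrative only. */
final class StompErrorFrameSketch {

    /** Escapes a header value as STOMP 1.2 requires (backslash, CR, LF, colon). */
    static String escape(String value) {
        return value.replace("\\", "\\\\")
                .replace("\r", "\\r")
                .replace("\n", "\\n")
                .replace(":", "\\c");
    }

    /** Builds an ERROR frame with an empty body; receiptId may be null. */
    static String errorFrame(String message, String receiptId) {
        StringBuilder frame = new StringBuilder("ERROR\n");
        frame.append("message:").append(escape(message)).append('\n');
        if (receiptId != null) {
            frame.append("receipt-id:").append(escape(receiptId)).append('\n');
        }
        // A blank line ends the headers, the body is empty, and NUL terminates the frame.
        return frame.append('\n').append('\0').toString();
    }

    public static void main(String[] args) {
        // Replace the NUL terminator so the frame is readable on a console.
        System.out.println(errorFrame("Token is blank", "r-1").replace("\0", "^@"));
    }
}
```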
Connect event handler

After a client connects successfully we usually need to do some follow-up work, such as updating the user's online status or counting the number of people online, and likewise after the user goes offline. The event handler class generally needs the user information, which is the user information we set in the inbound message interceptor earlier.

@Service
public class WebSocketEventHandler {
    private static final Logger log = LoggerFactory.getLogger(WebSocketEventHandler.class);

    /**
     * Fired after the connection succeeds. If an exception is thrown in the verification interceptor, this listener is not reached.
     *
     * @param sessionConnectEvent connection event
     */
    @EventListener(SessionConnectedEvent.class)
    public void connectHandler(SessionConnectedEvent sessionConnectEvent) {
        Principal user = sessionConnectEvent.getUser();
        assert user != null;
        log.info("User WebSocket connection successful, user information: {}", user.getName());
    }

    /**
     * WebSocket disconnect event listener
     *
     * @param sessionDisconnectEvent disconnect event
     */
    @EventListener(SessionDisconnectEvent.class)
    public void disconnectHandler(SessionDisconnectEvent sessionDisconnectEvent) {
        Principal user = sessionDisconnectEvent.getUser();
        //The user may be null if the session was never authenticated
        String username = user == null ? "unknown" : user.getName();
        log.info("User disconnected WebSocket connection, user information: {}", username);
    }

    /**
     * WebSocket subscription event listener
     *
     * @param event subscription event
     */
    @EventListener(SessionSubscribeEvent.class)
    public void subscribeHandler(SessionSubscribeEvent event) {
        Principal user = event.getUser();
        assert user != null;
        String name = user.getName();
        log.info("User {} established a subscription", name);
    }

    /**
     * WebSocket unsubscribe event listener
     *
     * @param event unsubscribe event
     */
    @EventListener(SessionUnsubscribeEvent.class)
    public void unSubscribeHandler(SessionUnsubscribeEvent event) {
        Principal user = event.getUser();
        assert user != null;
        String name = user.getName();
        log.info("User {} cancelled a subscription", name);
    }
}
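
As an example of the online-status bookkeeping mentioned above, here is a minimal sketch of a registry that the connect and disconnect listeners could call. OnlineUserRegistrySketch and its method names are hypothetical. It assumes the listeners pass in the session ID, which can be read with StompHeaderAccessor.wrap(event.getMessage()).getSessionId(), and it keeps state in memory on one server only.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Tracks which users are online on this server; illustrative only. */
final class OnlineUserRegistrySketch {
    // sessionId -> user name; one user may hold several sessions (tabs, devices).
    private final Map<String, String> sessions = new ConcurrentHashMap<>();

    /** Call from the connected-event listener. */
    void onConnected(String sessionId, String userName) {
        sessions.put(sessionId, userName);
    }

    /** Call from the disconnect-event listener; returns the owner of the session, or null if it was unknown. */
    String onDisconnected(String sessionId) {
        return sessions.remove(sessionId);
    }

    /** A user is online while at least one of their sessions is open. */
    boolean isOnline(String userName) {
        return sessions.containsValue(userName);
    }

    /** Number of distinct users with at least one open session. */
    long onlineUserCount() {
        return sessions.values().stream().distinct().count();
    }

    public static void main(String[] args) {
        OnlineUserRegistrySketch registry = new OnlineUserRegistrySketch();
        registry.onConnected("s1", "alice");
        registry.onConnected("s2", "alice");
        registry.onConnected("s3", "bob");
        System.out.println(registry.onlineUserCount()); // 2
        registry.onDisconnected("s1");
        System.out.println(registry.isOnline("alice")); // true, session s2 is still open
    }
}
```

Keying by session rather than by user keeps a user online while any of their sessions remains open.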
Socket message sending

Next, write the business class that pushes messages. A push only reaches the client if the client has already subscribed to the corresponding address. The configuration class above defines the prefixes that users subscribe to, and in this example each user subscribes to a destination that ends in their own ID. (Using the user ID is not mandatory; any value that uniquely identifies the user will do. The point is that a message pushed to a single user reaches only that user and cannot be observed by other users.)

@Service
public class SocketService {

    private static final Logger log = LoggerFactory.getLogger(SocketService.class);

    //This class is officially provided by Spring for message sending and can be injected directly
    @Resource
    private SimpMessagingTemplate simpMessagingTemplate;

    @Resource(name = "socketExecutor")
    private Executor executor;



    /**
     * Send a message to a single user.
     *
     * @param socketMessageDTO the carrier of the message data, which can be customized to the data type you want
     * @param userId           the user to send the message to
     * @param subscribePath    the user subscription address prefix; it is concatenated directly with the user ID,
     *                         so it must end with "/" for the complete path to be /subscription-prefix/user-id
     */
    @Async(value = "socketExecutor")
    public void single(SocketMessageDTO<Object> socketMessageDTO, String userId, String subscribePath) {
        String destination = subscribePath + userId;
        Message<Object> message = buildMessage(JSON.toJSONBytes(socketMessageDTO));
        simpMessagingTemplate.send(destination, message);
    }

    /**
     * Build WebSocket message entity
     *
     * @param payload data carrier
     * @return message entity
     */
    public Message<Object> buildMessage(Object payload) {
        return MessageBuilder
                .withPayload(payload)
                .setHeader("content-type", "application/json").build();
    }
}
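
The sending side and the interceptor's subscription check rely on the same convention: the destination is the subscription prefix followed by the user ID, and a user may subscribe only when the last path segment equals their own ID. A minimal stand-alone sketch of that convention follows. UserDestinationSketch and the /queue/user prefix are hypothetical; the real prefix comes from the project's Constant class.

```java
/** Models the prefix + user-id destination convention; illustrative only. */
final class UserDestinationSketch {

    /** Builds the destination, tolerating a prefix with or without a trailing slash. */
    static String destinationFor(String subscribePrefix, String userId) {
        return subscribePrefix.endsWith("/")
                ? subscribePrefix + userId
                : subscribePrefix + "/" + userId;
    }

    /** Same rule as the interceptor: the last path segment must be the current user's ID. */
    static boolean maySubscribe(String destination, String currentUserId) {
        if (destination == null) {
            return false;
        }
        String lastSegment = destination.substring(destination.lastIndexOf('/') + 1);
        return lastSegment.equals(currentUserId);
    }

    public static void main(String[] args) {
        String destination = destinationFor("/queue/user", "42");
        System.out.println(destination);                     // /queue/user/42
        System.out.println(maySubscribe(destination, "42")); // true
        System.out.println(maySubscribe(destination, "7"));  // false
    }
}
```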

At this point, the setup for using STOMP over WebSocket with Spring Security is complete. Spring's official documentation has more detailed instructions. The official website address: Web on Servlet Stack (spring.io)

Distributed WebSocket implementation ideas

Whether you use plain WebSocket or wrap it with STOMP, there is no way to persist the connected Session to Redis. This means that in a distributed deployment, if user A is connected to Server1 but the business logic runs on Server2, a push from Server2 to user A will not succeed. My ideas for solving this problem are as follows (if you have a better idea, I hope you will share it, because I would like to know too).

1. Split out the Socket service

Separate the WebSocket service into its own microservice and route to it through Nginx: any WebSocket-related request is forwarded to this server, so all WebSocket connections in the distributed system actually live on one server. The service also exposes an interface to the outside, so that when other microservices finish their business logic they can call that interface to push the corresponding information to the user. The disadvantage of this approach is also obvious: the WebSocket service is a single application.

2. Request forwarding

After each user connects successfully, record which server holds the connection (for example, store it in Redis using the user's ID and the resource name of the microservice in the registration center). When the business logic completes and a message needs to be pushed to the user, first check Redis to see whether the user's WebSocket connection is on the current server. If so, send the message directly. If not, obtain from Redis the server the user is connected to and forward the request to that server (so each microservice needs to expose an interface that delivers a forwarded message to the user connected to it).

The drawback of this method is the following. Suppose the user's socket connection is on Server1 and the business logic completes on Server2, which finds that the connection is not local and forwards the message. While the message is being forwarded, the user disconnects from Server1 and quickly reconnects. The load balancer will not necessarily route the user back to Server1; the user may land on Server2. The message then arrives at Server1, which finds that the user is no longer connected there and forwards it back to Server2. Under heavy network fluctuation (assuming the connection alternates between Server1 and Server2), a message may be forwarded back and forth, which causes delays or duplicate delivery. For example, if messages are also stored in the database and the client fetches missed messages immediately after reconnecting, it receives the stored message that has not been delivered yet, and a moment later the forwarded message arrives as well, so the user receives it twice. Of course, this is an extreme case.
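
The lookup step of this approach can be sketched as follows. This is a simplified model under stated assumptions: ConnectionRouterSketch is a hypothetical class, a shared in-memory map stands in for Redis, and the actual forwarding call between services is left out.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/** Decides whether a push is delivered locally or forwarded; illustrative only. */
final class ConnectionRouterSketch {
    enum Route { LOCAL, FORWARD, OFFLINE }

    private final String localServerId;
    // userId -> id of the server holding the connection; stands in for the Redis lookup.
    private final Map<String, String> userToServer;

    ConnectionRouterSketch(String localServerId, Map<String, String> sharedRegistry) {
        this.localServerId = localServerId;
        this.userToServer = sharedRegistry;
    }

    /** Record that the user is now connected to this server. */
    void onConnected(String userId) {
        userToServer.put(userId, localServerId);
    }

    /** Remove the entry only if it still points here, so a late disconnect cannot erase a newer connection. */
    void onDisconnected(String userId) {
        userToServer.remove(userId, localServerId);
    }

    /** Decide how a message for the user should be handled on this server. */
    Route routeFor(String userId) {
        String server = userToServer.get(userId);
        if (server == null) {
            return Route.OFFLINE;
        }
        return server.equals(localServerId) ? Route.LOCAL : Route.FORWARD;
    }

    /** The server a forwarded message should be sent to, or null if the user is offline. */
    String targetServer(String userId) {
        return userToServer.get(userId);
    }

    public static void main(String[] args) {
        Map<String, String> registry = new ConcurrentHashMap<>();
        ConnectionRouterSketch server1 = new ConnectionRouterSketch("server1", registry);
        ConnectionRouterSketch server2 = new ConnectionRouterSketch("server2", registry);
        server1.onConnected("42");
        System.out.println(server1.routeFor("42")); // LOCAL
        System.out.println(server2.routeFor("42")); // FORWARD
    }
}
```

The conditional remove handles the late-disconnect half of the reconnect race described above. To limit the back-and-forth forwarding itself, one option would be to carry a hop counter with each forwarded message and stop after a fixed number of hops; that is a suggestion, not part of the original design.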

3. Use message queue to solve

RabbitMQ has a plug-in that integrates STOMP and can be used to push WebSocket messages, but the client needs to connect to RabbitMQ directly. This method will be covered together with RabbitMQ. If you are interested, you can check the support for and application of the STOMP plug-in on RabbitMQ's official website.

WebSocket cross-domain problem solving

Even though cross-origin handling was already configured for HTTP requests, in practice I found that it seemed to have no effect on WebSocket requests, so I added separate cross-origin handling for WebSocket requests in Nginx. The specific configuration is as follows:

if ($request_uri ~ "/webSocket/") {
    add_header Access-Control-Allow-Origin '';
    add_header Access-Control-Allow-Methods 'GET,POST,DELETE,OPTIONS,PUT';
    add_header Access-Control-Allow-Headers 'DNT,X-Mx-ReqToken,Keep-Alive,User-Agent,X-Requested-With,If-Modified-Since,Cache-Control,Content-Type,Authorization,accessToken,refreshToken';
    add_header Access-Control-Allow-Credentials '';
    add_header Access-Control-Max-Age 86400;
}