aboutsummaryrefslogblamecommitdiff
path: root/juick-ws/src/main/java/com/juick/ws/WebsocketComponent.java
blob: bdf6dba6b7d2da6630d62acc94af42c937318dfb (plain) (tree)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17















                                                                           
                     
                         
                      
                                              
                                               
                                                  
                                         
                                             
                                     
                                            
                                                 
                               
                                            
                                                           
                                                  
                                                  
                                                  
                                                                   
                                                         
 
                                      
                           
                           
                                                   
                    
                             
                            
                      


                                       
                                                                                          
                                                                                           
 
                                                                                                    
 
           
                                    
           
                                            
           
                                                    
           
                                   
 



                                   

                                                                                       
                        





                                                                                         





                                                                                      
         
                                                                                                        

                                           
                                              
                                                               
                                                                                 
                                                                                                        

                                                                                
                                                                                                            

                                                                                
                                                                         
                          
                                                                           
                                                                                                                    




                                                                                        
                                                   



                                 
                                        
                                                                    
             
         
 

                                                                                                      
                                                                                                   
                                                                             
                                                                
         
 
     
 









                                                            





                                                                                                      
                                                                                           





                                                                                                                      


                                                








                                                                                           





                                                                                                   








                                                                                           
                            

















                                                                                                            
/*
 * Copyright (C) 2008-2017, Juick
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

package com.juick.ws;

import com.juick.Message;
import com.juick.User;
import com.juick.server.helpers.AnonymousUser;
import com.juick.server.protocol.JuickProtocol;
import com.juick.server.protocol.ProtocolListener;
import com.juick.service.MessagesService;
import com.juick.service.SubscriptionService;
import com.juick.service.UserService;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

import javax.annotation.PostConstruct;
import javax.inject.Inject;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;

/**
 * WebSocket handler that tracks connected clients and pushes Juick notifications to them.
 *
 * <p>Connections to {@code /} are regular (non-legacy) clients: text frames they send are
 * answered through {@link JuickProtocol#getReply}, and they receive the notifications produced
 * by the {@link ProtocolListener} callbacks. Connections to the legacy endpoints
 * {@code /_all}, {@code /_replies} and {@code /<mid>} are only registered here with the
 * matching flag set; apart from the periodic ping this class sends them nothing itself
 * (NOTE(review): presumably other components deliver to them via {@link #getClients()}).
 *
 * <p>Created by vitalyster on 28.06.2016.
 */
public class WebsocketComponent extends TextWebSocketHandler implements ProtocolListener {
    private static final Logger logger = LoggerFactory.getLogger(WebsocketComponent.class);

    /**
     * Registered clients. The list is synchronized, but traversal (iterator/stream) must still
     * hold its monitor, so internal code iterates over {@link #snapshotClients()} instead.
     */
    private final List<SocketSubscribed> clients = Collections.synchronizedList(new LinkedList<>());

    @Inject
    private UserService userService;
    @Inject
    private MessagesService messagesService;
    @Inject
    private SubscriptionService subscriptionService;
    @Inject
    private JuickProtocol protocol;

    /** Registers this handler as the protocol's listener once dependencies are injected. */
    @PostConstruct
    public void init() {
        protocol.setListener(this);
    }

    /**
     * Authenticates a new connection and registers it according to its request path.
     *
     * <p>Sessions whose path is not recognised are left open but are not registered.
     * Sessions asking for a thread the visitor may not view are closed.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        URI location = session.getUri();
        String realIp = resolveClientAddress(session);
        if (location == null) {
            // Defensive: without a URI there is no path to route the connection on.
            logger.warn("connection from {} has no URI, closing", realIp);
            closeQuietly(session, CloseStatus.BAD_DATA);
            return;
        }

        User visitor = authenticate(location, realIp);
        String path = location.getPath();
        logger.info("user {} connected to {} from {}", visitor.getUid(), path, realIp);

        SocketSubscribed subscription = subscribe(session, path, realIp, visitor);
        if (subscription != null) {
            synchronized (clients) {
                clients.add(subscription);
                logger.info("{} clients connected", clients.size());
            }
        }
    }

    /**
     * Returns the {@code X-Real-IP} handshake header if present, otherwise the socket's remote
     * address. The fallback is computed only when the header is missing, and
     * {@code String.valueOf} avoids an NPE if the remote address is unavailable.
     */
    private static String resolveClientAddress(WebSocketSession session) {
        HttpHeaders headers = session.getHandshakeHeaders();
        String realIp = headers.getFirst("X-Real-IP");
        return realIp != null ? realIp : String.valueOf(session.getRemoteAddress());
    }

    /**
     * Resolves the visitor from the {@code hash} query parameter.
     *
     * @return the matching user, or {@link AnonymousUser#INSTANCE} when no hash was given or it
     *         was rejected (wrong length or no user returned)
     */
    private User authenticate(URI location, String realIp) {
        UriComponents uriComponents = UriComponentsBuilder.fromUri(location).build();
        // getFirst tolerates a valueless "?hash" (null value) instead of throwing an NPE.
        String hash = uriComponents.getQueryParams().getFirst("hash");
        if (hash == null) {
            return AnonymousUser.INSTANCE;
        }
        User user = hash.length() == 16 ? userService.getUserByHash(hash) : null;
        if (user == null) {
            // Defensive: a null lookup result is treated as anonymous rather than crashing later.
            logger.info("wrong hash for {} from {}", AnonymousUser.INSTANCE.getUid(), realIp);
            return AnonymousUser.INSTANCE;
        }
        return user;
    }

    /**
     * Builds the subscription for {@code path}.
     *
     * @return the new subscription, or {@code null} if the path is not served or access was
     *         denied (in which case the session has been closed)
     */
    private SocketSubscribed subscribe(WebSocketSession session, String path, String realIp, User visitor) {
        switch (path) {
            case "/":
                logger.info("user {} connected", visitor.getUid());
                return new SocketSubscribed(session, realIp, visitor, false);
            case "/_all": {
                logger.info("user {} connected to legacy _all ({})", visitor.getUid(), path);
                SocketSubscribed all = new SocketSubscribed(session, realIp, visitor, true);
                all.allMessages = true;
                return all;
            }
            case "/_replies": {
                logger.info("user {} connected to legacy _replies ({})", visitor.getUid(), path);
                SocketSubscribed replies = new SocketSubscribed(session, realIp, visitor, true);
                replies.allReplies = true;
                return replies;
            }
            default:
                return subscribeToThread(session, path, realIp, visitor);
        }
    }

    /** Handles the legacy {@code /<mid>} thread endpoint; returns {@code null} if not applicable. */
    private SocketSubscribed subscribeToThread(WebSocketSession session, String path, String realIp, User visitor) {
        if (!path.matches("/\\d+$")) {
            return null;
        }
        int mid = NumberUtils.toInt(path.substring(1), 0);
        if (mid <= 0) {
            // "/0", or a number too large for int (toInt falls back to 0).
            return null;
        }
        if (!messagesService.canViewThread(mid, visitor.getUid())) {
            // 403 is not a valid WebSocket close code (RFC 6455 uses 1000-4999) and Spring's
            // CloseStatus constructor rejects it, so close with 1008 Policy Violation instead.
            closeQuietly(session, CloseStatus.POLICY_VIOLATION.withReason("Forbidden"));
            return null;
        }
        logger.info("user {} connected to legacy thread ({}) from {}", visitor.getUid(), mid, realIp);
        SocketSubscribed thread = new SocketSubscribed(session, realIp, visitor, true);
        thread.MID = mid;
        return thread;
    }

    /** Closes {@code session} with {@code status}, logging rather than propagating I/O failures. */
    private static void closeQuietly(WebSocketSession session, CloseStatus status) {
        try {
            session.close(status);
        } catch (IOException e) {
            logger.warn("ws error", e);
        }
    }

    /** Unregisters every client entry belonging to the closed session. */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        synchronized (clients) {
            logger.info("session closed with status {}: {}", status.getCode(), status.getReason());
            clients.removeIf(c -> c.session.getId().equals(session.getId()));
            logger.info("{} clients connected", clients.size());
        }
    }

    /**
     * Sends a ping frame to every registered, still-open session every 30 seconds.
     *
     * <p>NOTE(review): Spring's WebSocketSession does not support concurrent sends; this runs on
     * the scheduler thread while notifications may be sent from others. Consider wrapping
     * sessions in ConcurrentWebSocketSessionDecorator.
     */
    @Scheduled(fixedRate = 30000)
    public void ping() {
        for (SocketSubscribed c : snapshotClients()) {
            if (!c.session.isOpen()) {
                // Already closed; afterConnectionClosed removes it from the list.
                continue;
            }
            try {
                c.session.sendMessage(new PingMessage());
            } catch (IOException e) {
                logger.error("WebSocket PING exception", e);
            }
        }
    }

    /** Answers a non-blank text frame from a non-legacy client with the protocol's reply. */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        String input = message.getPayload();
        if (StringUtils.isBlank(input)) {
            return;
        }
        for (SocketSubscribed s : snapshotClients()) {
            if (s.legacy || !s.session.getId().equals(session.getId())) {
                continue;
            }
            try {
                s.session.sendMessage(new TextMessage(protocol.getReply(s.visitor, input)));
            } catch (IOException | InvocationTargetException | IllegalAccessException | NoSuchMethodException e) {
                logger.error("protocol exception", e);
            }
        }
    }

    /**
     * Returns the live, synchronized client list. Callers must hold {@code synchronized (list)}
     * while iterating or streaming it, as required by {@link Collections#synchronizedList}.
     */
    public List<SocketSubscribed> getClients() {
        return clients;
    }

    @Override
    public void privateMessage(User from, User to, String body) {
        notifyUser(from, to, "Private message from @" + from.getName() + ":\n" + body);
    }

    @Override
    public void userSubscribed(User from, User to) {
        notifyUser(from, to, String.format("@%s subscribed to your blog", from.getName()));
    }

    /** Forwards the posted message's text to each user returned by getSubscribedUsers. */
    @Override
    public void messagePosted(Message msg) {
        subscriptionService.getSubscribedUsers(msg.getUser().getUid(), msg.getMid())
                .forEach(u -> notifyUser(msg.getUser(), u, msg.getText()));
    }

    /** Sends {@code body} to every non-legacy session of {@code to}; {@code from} is unused. */
    private void notifyUser(User from, User to, String body) {
        for (SocketSubscribed s : snapshotClients()) {
            if (s.legacy || !s.visitor.equals(to)) {
                continue;
            }
            try {
                s.session.sendMessage(new TextMessage(body));
            } catch (IOException e) {
                logger.error("protocol exception", e);
            }
        }
    }

    /**
     * Copies the client list while holding its lock, so callers can iterate and do network I/O
     * without holding the lock and without risking ConcurrentModificationException.
     */
    private List<SocketSubscribed> snapshotClients() {
        synchronized (clients) {
            return new ArrayList<>(clients);
        }
    }

    /** Per-connection state for one registered WebSocket session. */
    static class SocketSubscribed {
        /** The underlying WebSocket session. */
        WebSocketSession session;
        /** Client address: the X-Real-IP header value, or the socket's remote address. */
        String clientName;
        /** The authenticated user, or AnonymousUser.INSTANCE. */
        User visitor;
        /** Thread id for legacy /&lt;mid&gt; connections; 0 otherwise. */
        int MID;
        /** Set for legacy /_all connections. */
        boolean allMessages;
        /** Set for legacy /_replies connections. */
        boolean allReplies;
        /** Creation time in epoch milliseconds. */
        long tsConnected;
        /** Initialised to the creation time; not updated within this class. */
        long tsLastData;
        /** True for legacy endpoints; skipped by handleTextMessage and notifyUser. */
        boolean legacy;

        public SocketSubscribed(WebSocketSession session, String clientName, User visitor, boolean legacy) {
            this.session = session;
            this.clientName = clientName;
            this.visitor = visitor;
            tsConnected = tsLastData = System.currentTimeMillis();
            this.legacy = legacy;
        }
    }
}