package com.juick.ws; import com.juick.Message; import com.juick.User; import com.juick.server.protocol.JuickProtocol; import com.juick.server.protocol.ProtocolListener; import com.juick.service.MessagesService; import com.juick.service.SubscriptionService; import com.juick.service.UserService; import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.math.NumberUtils; import org.apache.http.NameValuePair; import org.apache.http.client.utils.URLEncodedUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpHeaders; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PingMessage; import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler; import javax.annotation.PostConstruct; import javax.inject.Inject; import java.io.IOException; import java.lang.reflect.InvocationTargetException; import java.net.URI; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.List; /** * Created by vitalyster on 28.06.2016. */ public class WebsocketComponent extends TextWebSocketHandler implements ProtocolListener { private static final Logger logger = LoggerFactory.getLogger(WebsocketComponent.class); private final List clients = Collections.synchronizedList(new ArrayList()); @Inject UserService userService; @Inject MessagesService messagesService; @Inject SubscriptionService subscriptionService; @Inject JuickProtocol protocol; @PostConstruct public void init() { protocol.setListener(this); } @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { URI hLocation; String hXRealIP; hLocation = session.getUri(); HttpHeaders headers = session.getHandshakeHeaders(); hXRealIP = headers.getOrDefault("X-Real-IP", Collections.singletonList(session.getRemoteAddress().toString())).get(0); // Auth User visitor = new User(); List params = URLEncodedUtils.parse(hLocation, StandardCharsets.UTF_8); for (NameValuePair param : params) { if (param.getName().equals("hash")) { String hash = param.getValue(); if (hash.length() == 16) { visitor = userService.getUserByHash(hash); } else { logger.info("wrong hash for {} from {}", visitor.getUid(), hXRealIP); } break; } } logger.info("user {} connected to {} from {}", visitor.getUid(), hLocation.getPath(), hXRealIP); int MID = 0; SocketSubscribed sockSubscr = null; if (hLocation.getPath().equals("/")) { logger.info("user {} connected", visitor.getUid()); sockSubscr = new SocketSubscribed(session, hXRealIP, visitor, false); } else if (hLocation.getPath().equals("/_all")) { logger.info("user {} connected to legacy _all ({})", visitor.getUid(), hLocation.getPath()); sockSubscr = new SocketSubscribed(session, hXRealIP, visitor, true); sockSubscr.allMessages = true; } else if (hLocation.getPath().equals("/_replies")) { logger.info("user {} connected to legacy _replies ({})", visitor.getUid(), hLocation.getPath()); sockSubscr = new SocketSubscribed(session, hXRealIP, visitor, true); sockSubscr.allReplies = true; } else if (hLocation.getPath().matches("/\\d+$")) { MID = NumberUtils.toInt(hLocation.getPath().substring(1), 0); if (MID > 0) { if (messagesService.canViewThread(MID, visitor.getUid())) { logger.info("user {} connected to legacy thread ({}) from {}", visitor.getUid(), MID, hXRealIP); sockSubscr = new SocketSubscribed(session, hXRealIP, visitor, true); sockSubscr.MID = MID; } else { try { session.close(new CloseStatus(403, "Forbidden")); } catch (IOException e) { logger.warn("ws error", e); } } } } if (sockSubscr != null) { synchronized (clients) { clients.add(sockSubscr); logger.info("{} clients connected", clients.size()); } } } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { synchronized (clients) { logger.info("session closed with status {}: {}", status.getCode(), status.getReason()); clients.removeIf(c -> c.session.getId().equals(session.getId())); logger.info("{} clients connected", clients.size()); } } @Scheduled(fixedRate = 30000) public void ping() { clients.forEach(c -> { try { c.session.sendMessage(new PingMessage()); } catch (IOException e) { logger.error("WebSocket PING exception", e); } }); } @Override protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { getClients().stream().filter(s -> !s.legacy && s.session.equals(session)).forEach(s -> { User user = s.visitor; String input = message.getPayload(); if (StringUtils.isNotBlank(input)) { try { s.session.sendMessage(new TextMessage(protocol.getReply(user, input))); } catch (IOException | InvocationTargetException | IllegalAccessException | NoSuchMethodException e) { logger.error("protocol exception", e); } } }); } public List getClients() { return clients; } @Override public void privateMessage(User from, User to, String body) { notifyUser(from, to, "Private message from @" + from.getName() + ":\n" + body); } @Override public void userSubscribed(User from, User to) { notifyUser(from, to, String.format("@%s subscribed to your blog", from.getName())); } @Override public void messagePosted(Message msg) { subscriptionService.getSubscribedUsers(msg.getUser().getUid(), msg.getMid()).forEach(u -> { notifyUser(msg.getUser(), u, msg.getText()); }); } private void notifyUser(User from, User to, String body) { getClients().stream().filter(s -> !s.legacy && s.visitor.equals(to)).forEach(s -> { try { s.session.sendMessage(new TextMessage(body)); } catch (IOException e) { logger.error("protocol exception", e); } }); } class SocketSubscribed { WebSocketSession session; String clientName; User visitor; int MID; boolean allMessages; boolean allReplies; long tsConnected; long tsLastData; boolean legacy; public SocketSubscribed(WebSocketSession session, String clientName, User visitor, boolean legacy) { this.session = session; this.clientName = clientName; this.visitor = visitor; tsConnected = tsLastData = System.currentTimeMillis(); this.legacy = legacy; } } }