/*
* Copyright (C) 2008-2017, Juick
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package com.juick.ws;
import com.juick.Message;
import com.juick.User;
import com.juick.server.helpers.AnonymousUser;
import com.juick.server.protocol.JuickProtocol;
import com.juick.server.protocol.ProtocolListener;
import com.juick.service.MessagesService;
import com.juick.service.SubscriptionService;
import com.juick.service.UserService;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;

import javax.annotation.PostConstruct;
import javax.inject.Inject;

import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

/**
* Created by vitalyster on 28.06.2016.
*/
public class WebsocketComponent extends TextWebSocketHandler implements ProtocolListener {
private static final Logger logger = LoggerFactory.getLogger(WebsocketComponent.class);
private final List<SocketSubscribed> clients = Collections.synchronizedList(new LinkedList<>());
@Inject
private UserService userService;
@Inject
private MessagesService messagesService;
@Inject
private SubscriptionService subscriptionService;
@Inject
private JuickProtocol protocol;
@PostConstruct
public void init() {
protocol.setListener(this);
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
URI hLocation;
String hXRealIP;
hLocation = session.getUri();
HttpHeaders headers = session.getHandshakeHeaders();
hXRealIP = headers.getOrDefault("X-Real-IP",
Collections.singletonList(session.getRemoteAddress().toString())).get(0);
// Auth
User visitor = AnonymousUser.INSTANCE;
UriComponents uriComponents = UriComponentsBuilder.fromUri(hLocation).build();
List<String> hash = uriComponents.getQueryParams().get("hash");
if (hash != null && hash.get(0).length() == 16) {
visitor = userService.getUserByHash(hash.get(0));
} else {
logger.info("wrong hash for {} from {}", visitor.getUid(), hXRealIP);
}
logger.info("user {} connected to {} from {}", visitor.getUid(), hLocation.getPath(), hXRealIP);
int MID = 0;
SocketSubscribed sockSubscr = null;
if (hLocation.getPath().equals("/")) {
logger.info("user {} connected", visitor.getUid());
sockSubscr = new SocketSubscribed(session, hXRealIP, visitor, false);
} else if (hLocation.getPath().equals("/_all")) {
logger.info("user {} connected to legacy _all ({})", visitor.getUid(), hLocation.getPath());
sockSubscr = new SocketSubscribed(session, hXRealIP, visitor, true);
sockSubscr.allMessages = true;
} else if (hLocation.getPath().equals("/_replies")) {
logger.info("user {} connected to legacy _replies ({})", visitor.getUid(), hLocation.getPath());
sockSubscr = new SocketSubscribed(session, hXRealIP, visitor, true);
sockSubscr.allReplies = true;
} else if (hLocation.getPath().matches("/\\d+$")) {
MID = NumberUtils.toInt(hLocation.getPath().substring(1), 0);
if (MID > 0) {
if (messagesService.canViewThread(MID, visitor.getUid())) {
logger.info("user {} connected to legacy thread ({}) from {}", visitor.getUid(), MID, hXRealIP);
sockSubscr = new SocketSubscribed(session, hXRealIP, visitor, true);
sockSubscr.MID = MID;
} else {
try {
session.close(new CloseStatus(403, "Forbidden"));
} catch (IOException e) {
logger.warn("ws error", e);
}
}
}
}
if (sockSubscr != null) {
synchronized (clients) {
clients.add(sockSubscr);
logger.info("{} clients connected", clients.size());
}
}
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
synchronized (clients) {
logger.info("session closed with status {}: {}", status.getCode(), status.getReason());
clients.removeIf(c -> c.session.getId().equals(session.getId()));
logger.info("{} clients connected", clients.size());
}
}
@Scheduled(fixedRate = 30000)
public void ping() {
clients.forEach(c -> {
try {
c.session.sendMessage(new PingMessage());
} catch (IOException e) {
logger.error("WebSocket PING exception", e);
}
});
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
getClients().stream().filter(s -> !s.legacy && s.session.equals(session)).forEach(s -> {
User user = s.visitor;
String input = message.getPayload();
if (StringUtils.isNotBlank(input)) {
try {
s.session.sendMessage(new TextMessage(protocol.getReply(user, input)));
} catch (IOException | InvocationTargetException | IllegalAccessException | NoSuchMethodException e) {
logger.error("protocol exception", e);
}
}
});
}
public List<SocketSubscribed> getClients() {
return clients;
}
@Override
public void privateMessage(User from, User to, String body) {
notifyUser(from, to, "Private message from @" + from.getName() + ":\n" + body);
}
@Override
public void userSubscribed(User from, User to) {
notifyUser(from, to, String.format("@%s subscribed to your blog", from.getName()));
}
@Override
public void messagePosted(Message msg) {
subscriptionService.getSubscribedUsers(msg.getUser().getUid(), msg.getMid()).forEach(u -> {
notifyUser(msg.getUser(), u, msg.getText());
});
}
private void notifyUser(User from, User to, String body) {
getClients().stream().filter(s -> !s.legacy && s.visitor.equals(to)).forEach(s -> {
try {
s.session.sendMessage(new TextMessage(body));
} catch (IOException e) {
logger.error("protocol exception", e);
}
});
}
class SocketSubscribed {
WebSocketSession session;
String clientName;
User visitor;
int MID;
boolean allMessages;
boolean allReplies;
long tsConnected;
long tsLastData;
boolean legacy;
public SocketSubscribed(WebSocketSession session, String clientName, User visitor, boolean legacy) {
this.session = session;
this.clientName = clientName;
this.visitor = visitor;
tsConnected = tsLastData = System.currentTimeMillis();
this.legacy = legacy;
}
}
}