/*
* Copyright (C) 2008-2017, Juick
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package com.juick.server;
import com.juick.User;
import com.juick.model.AnonymousUser;
import com.juick.model.CommandResult;
import com.juick.server.util.HttpForbiddenException;
import com.juick.server.util.HttpNotFoundException;
import com.juick.service.MessagesService;
import com.juick.service.UserService;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.math.NumberUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.PingMessage;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
import javax.annotation.Nonnull;
import javax.inject.Inject;
import java.io.IOException;
import java.net.URI;
import java.time.Instant;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.regex.Pattern;
/**
 * Spring WebSocket handler that tracks connected clients and executes text commands sent by
 * authenticated users.
 *
 * <p>A connection is authenticated through an optional {@code hash} query parameter. It is then
 * routed by request path:
 * <ul>
 *   <li>{@code /ws/} (default endpoint)</li>
 *   <li>the legacy {@code /ws/_all} and {@code /ws/_replies} endpoints</li>
 *   <li>legacy per-thread endpoints such as {@code /ws/123}</li>
 * </ul>
 * Registered sessions are kept in a copy-on-write list and pinged every 30 seconds.
 *
 * <p>Created by vitalyster on 28.06.2016.
 */
@Component
public class WebsocketManager extends TextWebSocketHandler {
    private static final Logger logger = LoggerFactory.getLogger(WebsocketManager.class);

    /** Legacy per-thread endpoint path, e.g. {@code /ws/123}; compiled once, not per connection. */
    private static final Pattern THREAD_PATH = Pattern.compile("^/ws/\\d+$");

    /** Required length of the {@code hash} query parameter for an authentication attempt. */
    private static final int HASH_LENGTH = 16;

    /** Registered clients; copy-on-write so {@link #ping()} and lookups iterate a stable snapshot. */
    private final List<UserSession> clients = new CopyOnWriteArrayList<>();

    @Inject
    private UserService userService;
    @Inject
    private MessagesService messagesService;
    @Inject
    private CommandsManager commandsManager;

    /**
     * Authenticates and routes a newly opened connection, then registers it as a client.
     *
     * @param session the freshly established WebSocket session
     * @throws HttpForbiddenException if the visitor may not view the requested legacy thread
     * @throws HttpNotFoundException if the path is not a known endpoint or names an invalid thread id
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) {
        UserSession userSession = new UserSession(session);
        URI hLocation = session.getUri();
        authenticate(userSession, hLocation);

        String path = hLocation.getPath();
        if (path.equals("/ws/")) {
            logger.debug("user {} connected", userSession.visitor.getUid());
        } else if (path.equals("/ws/_all")) {
            logger.debug("user {} connected to legacy _all ({})", userSession.visitor.getUid(), path);
            userSession.legacy = true;
            userSession.allMessages = true;
        } else if (path.equals("/ws/_replies")) {
            logger.debug("user {} connected to legacy _replies ({})", userSession.visitor.getUid(), path);
            userSession.legacy = true;
            userSession.allReplies = true;
        } else if (THREAD_PATH.matcher(path).matches()) {
            // NumberUtils.toInt falls back to 0 when the digits overflow an int. Neither 0 nor an
            // overflowed id names a real thread, so reject instead of silently registering the
            // client as a plain "/ws/" subscriber.
            int mid = NumberUtils.toInt(path.substring(4), 0);
            if (mid <= 0) {
                throw new HttpNotFoundException();
            }
            if (!messagesService.canViewThread(mid, userSession.visitor.getUid())) {
                throw new HttpForbiddenException();
            }
            logger.debug("user {} connected to legacy thread ({}) from {}", userSession.visitor.getUid(), mid, userSession);
            userSession.legacy = true;
            userSession.MID = mid;
        } else {
            throw new HttpNotFoundException();
        }
        clients.add(userSession);
        logger.debug("{} clients connected", clients.size());
    }

    /**
     * Replaces the anonymous visitor of {@code userSession} with the user owning the supplied
     * {@code hash} query parameter, when that parameter is present and well-formed.
     *
     * @param userSession session whose visitor may be replaced
     * @param location    handshake URI carrying the optional {@code hash} query parameter
     */
    private void authenticate(UserSession userSession, URI location) {
        UriComponents uriComponents = UriComponentsBuilder.fromUri(location).build();
        // getFirst yields null when the parameter is absent or given without a value (a bare
        // "?hash"). The previous hash.get(0).length() threw a NullPointerException in that case.
        String hash = uriComponents.getQueryParams().getFirst("hash");
        if (hash != null && hash.length() == HASH_LENGTH) {
            User user = userService.getUserByHash(hash);
            // NOTE(review): keep the anonymous visitor if the lookup yields null, so later
            // getUid() calls cannot NPE. Confirm the getUserByHash contract.
            if (user != null) {
                userSession.visitor = user;
            }
        } else {
            logger.debug("wrong hash for {} from {}", userSession.visitor.getUid(), userSession);
        }
    }

    /**
     * Unregisters the client whose underlying session was closed.
     *
     * @param session the closed session
     * @param status  close status reported by the container
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) {
        logger.debug("session closed with status {}: {}", status.getCode(), status.getReason());
        clients.removeIf(c -> isSameSession(c, session));
        logger.debug("{} clients connected", clients.size());
    }

    /**
     * Sends a WebSocket ping to every open client every 30 seconds.
     * A failure on one client is logged and does not stop the remaining clients from being pinged.
     */
    @Scheduled(fixedRate = 30000)
    public void ping() {
        for (UserSession client : clients) {
            try {
                if (client.isOpen()) {
                    client.sendMessage(new PingMessage());
                }
            } catch (IOException | RuntimeException e) {
                // ConcurrentWebSocketSessionDecorator signals exceeded send limits with an
                // unchecked exception. Catching it per client keeps the loop going.
                logger.error("WebSocket PING exception", e);
            }
        }
    }

    /**
     * Executes a text command for an authenticated client and replies with the command result.
     * Anonymous clients get {@code "Authorization required"}. Blank payloads are ignored.
     *
     * @param session the session that sent the message
     * @param message the incoming text frame
     * @throws IllegalStateException if {@code session} is not a registered client
     * @throws Exception             propagated from command processing or sending the reply
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        UserSession ws = clients.stream()
                .filter(c -> isSameSession(c, session))
                .findFirst()
                .orElseThrow(() -> new IllegalStateException("no client registered for session " + session.getId()));
        if (!ws.visitor.isAnonymous()) {
            String command = message.getPayload().trim();
            if (StringUtils.isNotEmpty(command)) {
                CommandResult result = commandsManager.processCommand(ws.visitor, command, URI.create(""));
                ws.sendMessage(new TextMessage(result.getText()));
            }
        } else {
            ws.sendMessage(new TextMessage("Authorization required"));
        }
    }

    /**
     * Matches a registered client to a Spring session by session id. Every callback uses this
     * so lookups agree even if the container hands over a different wrapper object.
     */
    private static boolean isSameSession(UserSession client, WebSocketSession session) {
        return client.getDelegate().getId().equals(session.getId());
    }

    /**
     * Returns the live list of registered clients.
     *
     * @return the internal copy-on-write client list (not a copy)
     */
    public List<UserSession> getClients() {
        return clients;
    }

    /** Per-connection state layered over a thread-safe, send-limited session decorator. */
    class UserSession extends ConcurrentWebSocketSessionDecorator {
        /** Authenticated user, or {@code AnonymousUser.INSTANCE} until a valid hash is resolved. */
        User visitor;
        /** Thread id for legacy per-thread connections; 0 otherwise. */
        int MID;
        /** Set for connections to the legacy {@code /ws/_all} endpoint. */
        boolean allMessages;
        /** Set for connections to the legacy {@code /ws/_replies} endpoint. */
        boolean allReplies;
        /** Creation time of this wrapper. */
        Instant tsConnected;
        /** Initialised at creation. NOTE(review): not updated in this class; confirm usage elsewhere. */
        Instant tsLastData;
        /** True for any of the legacy endpoints. */
        boolean legacy;

        UserSession(WebSocketSession session) {
            // send time limit 60000 ms, outgoing buffer limit 65536 bytes
            super(session, 60000, 65536);
            this.visitor = AnonymousUser.INSTANCE;
            tsConnected = tsLastData = Instant.now();
        }

        /**
         * Describes the client by its proxy-supplied {@code X-Real-IP} header, falling back to the
         * remote address. The fallback is computed lazily and null-safely, because the remote
         * address may be unavailable and this method is used in log messages.
         */
        @Nonnull
        @Override
        public String toString() {
            String realIp = getHandshakeHeaders().getFirst("X-Real-IP");
            if (realIp != null) {
                return realIp;
            }
            Object remoteAddress = getRemoteAddress();
            return remoteAddress != null ? remoteAddress.toString() : "unknown";
        }
    }
}