/*
* Copyright (C) 2008-2017, Juick
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package com.juick.server;
import com.juick.server.xmpp.s2s.*;
import com.juick.service.UserService;
import com.juick.xmpp.extensions.StreamError;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;
import org.xmlpull.v1.XmlPullParserException;
import rocks.xmpp.addr.Jid;
import rocks.xmpp.core.stanza.model.Stanza;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import javax.inject.Inject;
import javax.net.ssl.*;
import java.security.InvalidAlgorithmParameterException;
import java.security.cert.*;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Unmarshaller;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.SecureRandom;
import java.time.Duration;
import java.time.Instant;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* @author ugnich
*/
@Component
public class XMPPServer implements ConnectionListener, AutoCloseable {
private static final Logger logger = LoggerFactory.getLogger("com.juick.server.xmpp");
private static final int TIMEOUT_MINUTES = 15;
@Inject
public ExecutorService service;
@Value("${hostname:localhost}")
private Jid jid;
@Value("${s2s_port:5269}")
private int s2sPort;
@Value("${keystore:juick.p12}")
public String keystore;
@Value("${keystore_password:secret}")
public String keystorePassword;
@Value("${broken_ssl_hosts:}")
public String[] brokenSSLhosts;
@Value("${banned_hosts:}")
public String[] bannedHosts;
private final List inConnections = new CopyOnWriteArrayList<>();
private final Map> outConnections = new ConcurrentHashMap<>();
private final List outCache = new CopyOnWriteArrayList<>();
private final List stanzaListeners = new CopyOnWriteArrayList<>();
private final AtomicBoolean closeFlag = new AtomicBoolean(false);
SSLContext sc;
CertificateFactory cf;
CertPathValidator cpv;
PKIXParameters params;
private TrustManager[] trustAllCerts = new TrustManager[]{
new X509TrustManager() {
public void checkClientTrusted(java.security.cert.X509Certificate[] certs, String authType) {
}
public void checkServerTrusted(java.security.cert.X509Certificate[] certs, String authType) {
}
public java.security.cert.X509Certificate[] getAcceptedIssuers() {
return new X509Certificate[0];
}
}
};
private boolean tlsConfigured = false;
private ServerSocket listener;
@Inject
private BasicXmppSession session;
@Inject
private UserService userService;
@PostConstruct
public void init() throws KeyStoreException {
closeFlag.set(false);
KeyStore ks = KeyStore.getInstance("PKCS12");
try (InputStream ksIs = new FileInputStream(keystore)) {
ks.load(ksIs, keystorePassword.toCharArray());
KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory
.getDefaultAlgorithm());
kmf.init(ks, keystorePassword.toCharArray());
sc = SSLContext.getInstance("TLSv1.2");
sc.init(kmf.getKeyManagers(), trustAllCerts, new SecureRandom());
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
Set ca = new HashSet<>();
trustManagerFactory.init((KeyStore)null);
Arrays.stream(trustManagerFactory.getTrustManagers()).forEach(t -> Arrays.stream(((X509TrustManager)t).getAcceptedIssuers()).forEach(cert -> ca.add(new TrustAnchor(cert, null))));
params = new PKIXParameters(ca);
params.setRevocationEnabled(false);
cpv = CertPathValidator.getInstance("PKIX");
cf = CertificateFactory.getInstance( "X.509" );
tlsConfigured = true;
} catch (Exception e) {
logger.warn("tls unavailable");
}
service.submit(() -> {
try {
listener = new ServerSocket(s2sPort);
logger.info("s2s listener ready");
while (!listener.isClosed()) {
if (Thread.currentThread().isInterrupted()) break;
Socket socket = listener.accept();
ConnectionIn client = new ConnectionIn(this, socket);
addConnectionIn(client);
service.submit(client);
}
} catch (SocketException e) {
// shutdown
} catch (IOException | XmlPullParserException e) {
logger.warn("xmpp exception", e);
}
});
}
@Override
public void close() throws Exception {
if (listener != null && !listener.isClosed()) {
listener.close();
}
outConnections.forEach((c, s) -> {
c.logoff();
outConnections.remove(c);
});
inConnections.forEach(c -> {
c.closeConnection();
inConnections.remove(c);
});
service.shutdown();
logger.info("XMPP server destroyed");
}
public void addConnectionIn(ConnectionIn c) {
c.setListener(this);
inConnections.add(c);
}
public void addConnectionOut(ConnectionOut c, Optional socket) {
c.setListener(this);
outConnections.put(c, socket);
}
public void removeConnectionIn(ConnectionIn c) {
inConnections.remove(c);
}
public void removeConnectionOut(ConnectionOut c) {
outConnections.remove(c);
}
public String getFromCache(Jid to) {
final String[] cache = new String[1];
outCache.stream().filter(c -> c.hostname != null && c.hostname.equals(to)).findFirst().ifPresent(c -> {
cache[0] = c.xml;
outCache.remove(c);
});
return cache[0];
}
public Optional getConnectionOut(Jid hostname, boolean needReady) {
return outConnections.keySet().stream().filter(c -> c.to != null &&
c.to.equals(hostname) && (!needReady || c.streamReady)).findFirst();
}
public Optional getConnectionIn(String streamID) {
return inConnections.stream().filter(c -> c.streamID != null && c.streamID.equals(streamID)).findFirst();
}
public void sendOut(Jid hostname, String xml) {
boolean haveAnyConn = false;
ConnectionOut connOut = null;
for (ConnectionOut c : outConnections.keySet()) {
if (c.to != null && c.to.equals(hostname)) {
if (c.streamReady) {
connOut = c;
break;
} else {
haveAnyConn = true;
break;
}
}
}
if (connOut != null) {
connOut.send(xml);
return;
}
boolean haveCache = false;
for (CacheEntry c : outCache) {
if (c.hostname != null && c.hostname.equals(hostname)) {
c.xml += xml;
c.updated = Instant.now();
haveCache = true;
break;
}
}
if (!haveCache) {
outCache.add(new CacheEntry(hostname, xml));
}
if (!haveAnyConn && !closeFlag.get()) {
try {
createDialbackConnection(hostname.toEscapedString(), null, null);
} catch (Exception e) {
logger.warn("dialback error", e);
}
}
}
void createDialbackConnection(String to, String checkSID, String dbKey) throws Exception {
ConnectionOut connectionOut = new ConnectionOut(getJid(), Jid.of(to), null, null, checkSID, dbKey);
addConnectionOut(connectionOut, Optional.empty());
service.submit(() -> {
try {
Socket socket = new Socket();
socket.connect(DNSQueries.getServerAddress(to));
connectionOut.setInputStream(socket.getInputStream());
connectionOut.setOutputStream(socket.getOutputStream());
addConnectionOut(connectionOut, Optional.of(socket));
connectionOut.connect();
} catch (IOException e) {
logger.info("dialback to " + to + " exception", e);
}
});
}
public void startDialback(Jid from, String streamId, String dbKey) throws Exception {
Optional c = getConnectionOut(from, false);
if (c.isPresent()) {
c.get().sendDialbackVerify(streamId, dbKey);
} else {
createDialbackConnection(from.toEscapedString(), streamId, dbKey);
}
}
public void addStanzaListener(StanzaListener listener) {
stanzaListeners.add(listener);
}
public void onStanzaReceived(String xmlValue) {
logger.info("S2S: {}", xmlValue);
Stanza stanza = parse(xmlValue);
stanzaListeners.forEach(l -> l.stanzaReceived(stanza));
}
public BasicXmppSession getSession() {
return session;
}
public List getInConnections() {
return inConnections;
}
public Map> getOutConnections() {
return outConnections;
}
@Override
public boolean isTlsAvailable() {
return tlsConfigured;
}
@Override
public void starttls(ConnectionIn connection) {
logger.debug("stream {} securing", connection.streamID);
connection.sendStanza("");
try {
connection.setSocket(sc.getSocketFactory().createSocket(connection.getSocket(), connection.getSocket().getInetAddress().getHostAddress(),
connection.getSocket().getPort(), false));
SSLSocket sslSocket = (SSLSocket) connection.getSocket();
sslSocket.addHandshakeCompletedListener(handshakeCompletedEvent -> {
try {
CertPath certPath = cf.generateCertPath(Arrays.asList(handshakeCompletedEvent.getPeerCertificates()));
cpv.validate(certPath, params);
connection.setTrusted(true);
logger.info("connection from {} is trusted", connection.from);
} catch (SSLPeerUnverifiedException | CertificateException | CertPathValidatorException | InvalidAlgorithmParameterException e) {
logger.info("connection from {} is NOT trusted, falling back to dialback", connection.from);
}
});
sslSocket.setUseClientMode(false);
sslSocket.setNeedClientAuth(true);
sslSocket.startHandshake();
connection.setSecured(true);
logger.debug("stream from {} secured", connection.streamID);
connection.restartParser();
} catch (XmlPullParserException | IOException sex) {
logger.warn("stream {} ssl error {}", connection.streamID, sex);
connection.sendStanza("");
removeConnectionIn(connection);
connection.closeConnection();
}
}
@Override
public void proceed(ConnectionOut connection) {
try {
Socket socket = outConnections.get(connection).get();
socket = sc.getSocketFactory().createSocket(socket, socket.getInetAddress().getHostAddress(),
socket.getPort(), false);
SSLSocket sslSocket = (SSLSocket) socket;
sslSocket.addHandshakeCompletedListener(handshakeCompletedEvent -> {
try {
CertPath certPath = cf.generateCertPath(Arrays.asList(handshakeCompletedEvent.getPeerCertificates()));
cpv.validate(certPath, params);
connection.setTrusted(true);
logger.info("connection to {} is trusted", connection.to);
} catch (SSLPeerUnverifiedException | CertificateException | CertPathValidatorException | InvalidAlgorithmParameterException e) {
logger.info("connection to {} is NOT trusted, falling back to dialback", connection.to);
}
});
sslSocket.setNeedClientAuth(true);
sslSocket.startHandshake();
connection.setSecured(true);
logger.debug("stream to {} secured", connection.getStreamID());
connection.setInputStream(socket.getInputStream());
connection.setOutputStream(socket.getOutputStream());
connection.restartStream();
connection.sendOpenStream();
} catch (NoSuchElementException | XmlPullParserException | IOException sex) {
logger.error("s2s ssl error: {} {}, error {}", connection.to, connection.getStreamID(), sex);
connection.send("");
removeConnectionOut(connection);
connection.logoff();
}
}
@Override
public void verify(ConnectionOut connection, String from, String type, String sid) {
if (from != null && from.equals(connection.to.toEscapedString()) && sid != null && !sid.isEmpty() && type != null) {
getConnectionIn(sid).ifPresent(c -> c.sendDialbackResult(Jid.of(from), type));
}
}
@Override
public void dialbackError(ConnectionOut connection, StreamError error) {
logger.warn("Stream error from {}: {}", connection.getStreamID(), error.getCondition());
removeConnectionOut(connection);
connection.logoff();
}
@Override
public void finished(ConnectionOut connection, boolean dirty) {
logger.warn("stream to {} {} finished, dirty={}", connection.to, connection.getStreamID(), dirty);
removeConnectionOut(connection);
connection.logoff();
}
@Override
public void exception(ConnectionOut connection, Exception ex) {
logger.error("s2s out exception: {} {}, exception {}", connection.to, connection.getStreamID(), ex);
removeConnectionOut(connection);
connection.logoff();
}
@Override
public void ready(ConnectionOut connection) {
logger.debug("stream to {} {} ready", connection.to, connection.getStreamID());
String cache = getFromCache(connection.to);
if (cache != null) {
logger.debug("stream to {} {} sending cache", connection.to, connection.getStreamID());
connection.send(cache);
}
}
@Override
public boolean securing(ConnectionOut connection) {
return tlsConfigured && !Arrays.asList(brokenSSLhosts).contains(connection.to.toEscapedString());
}
public Stanza parse(String xml) {
try {
Unmarshaller unmarshaller = session.createUnmarshaller();
return (Stanza)unmarshaller.unmarshal(new StringReader(xml));
} catch (JAXBException e) {
logger.error("JAXB exception", e);
}
return null;
}
public Jid getJid() {
return jid;
}
@Scheduled(fixedDelay = 10000)
public void cleanUp() {
Instant now = Instant.now();
outConnections.keySet().stream().filter(c -> Duration.between(c.getUpdated(), now).toMinutes() > TIMEOUT_MINUTES)
.forEach(c -> {
logger.info("closing idle outgoing connection to {}", c.to);
c.logoff();
outConnections.remove(c);
});
inConnections.stream().filter(c -> Duration.between(c.updated, now).toMinutes() > TIMEOUT_MINUTES)
.forEach(c -> {
logger.info("closing idle incoming connection from {}", c.from);
c.closeConnection();
inConnections.remove(c);
});
}
@PreDestroy
public void preDestroy() {
closeFlag.set(true);
}
}