aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/com/juick/server/XMPPServer.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/com/juick/server/XMPPServer.java')
-rw-r--r--src/main/java/com/juick/server/XMPPServer.java429
1 files changed, 429 insertions, 0 deletions
diff --git a/src/main/java/com/juick/server/XMPPServer.java b/src/main/java/com/juick/server/XMPPServer.java
new file mode 100644
index 00000000..86ab6a78
--- /dev/null
+++ b/src/main/java/com/juick/server/XMPPServer.java
@@ -0,0 +1,429 @@
+/*
+ * Copyright (C) 2008-2017, Juick
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as
+ * published by the Free Software Foundation, either version 3 of the
+ * License, or (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+package com.juick.server;
+
+import com.juick.server.xmpp.router.StreamError;
+import com.juick.server.xmpp.s2s.*;
+import com.juick.service.UserService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Value;
+import org.springframework.scheduling.annotation.Scheduled;
+import org.xmlpull.v1.XmlPullParserException;
+import rocks.xmpp.addr.Jid;
+import rocks.xmpp.core.stanza.model.Stanza;
+
+import javax.annotation.PostConstruct;
+import javax.annotation.PreDestroy;
+import javax.inject.Inject;
+import javax.net.ssl.*;
+import javax.xml.bind.JAXBException;
+import javax.xml.bind.Unmarshaller;
+import java.io.IOException;
+import java.io.StringReader;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.net.SocketException;
+import java.security.InvalidAlgorithmParameterException;
+import java.security.KeyStore;
+import java.security.KeyStoreException;
+import java.security.SecureRandom;
+import java.security.cert.*;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CopyOnWriteArrayList;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * @author ugnich
+ */
+public class XMPPServer implements ConnectionListener {
+ private static final Logger logger = LoggerFactory.getLogger("com.juick.server.xmpp");
+
+ private static final int TIMEOUT_MINUTES = 15;
+
+ @Inject
+ public ExecutorService service;
+ @Value("${hostname:localhost}")
+ private Jid jid;
+ @Value("${s2s_port:5269}")
+ private int s2sPort;
+ @Value("${broken_ssl_hosts:}")
+ public String[] brokenSSLhosts;
+ @Value("${banned_hosts:}")
+ public String[] bannedHosts;
+
+ private final List<ConnectionIn> inConnections = new CopyOnWriteArrayList<>();
+ private final Map<ConnectionOut, Optional<Socket>> outConnections = new ConcurrentHashMap<>();
+ private final List<CacheEntry> outCache = new CopyOnWriteArrayList<>();
+ private final List<StanzaListener> stanzaListeners = new CopyOnWriteArrayList<>();
+ private final AtomicBoolean closeFlag = new AtomicBoolean(false);
+
+ SSLContext sc;
+ CertificateFactory cf;
+ CertPathValidator cpv;
+ PKIXParameters params;
+ private TrustManager[] trustAllCerts = new TrustManager[]{
+ new X509TrustManager() {
+ public void checkClientTrusted(java.security.cert.X509Certificate[] certs, String authType) {
+ }
+
+ public void checkServerTrusted(java.security.cert.X509Certificate[] certs, String authType) {
+ }
+
+ public java.security.cert.X509Certificate[] getAcceptedIssuers() {
+ return new X509Certificate[0];
+ }
+ }
+ };
+ private boolean tlsConfigured = false;
+
+
+ private ServerSocket listener;
+
+ @Inject
+ private BasicXmppSession session;
+ @Inject
+ private UserService userService;
+ @Inject
+ private KeystoreManager keystoreManager;
+
+ @PostConstruct
+ public void init() throws KeyStoreException {
+ closeFlag.set(false);
+ try {
+ sc = SSLContext.getInstance("TLSv1.2");
+ sc.init(keystoreManager.getKeymanagerFactory().getKeyManagers(), trustAllCerts, new SecureRandom());
+ TrustManagerFactory trustManagerFactory =
+ TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+ Set<TrustAnchor> ca = new HashSet<>();
+ trustManagerFactory.init((KeyStore)null);
+ Arrays.stream(trustManagerFactory.getTrustManagers()).forEach(t -> Arrays.stream(((X509TrustManager)t).getAcceptedIssuers()).forEach(cert -> ca.add(new TrustAnchor(cert, null))));
+ params = new PKIXParameters(ca);
+ params.setRevocationEnabled(false);
+ cpv = CertPathValidator.getInstance("PKIX");
+ cf = CertificateFactory.getInstance( "X.509" );
+ tlsConfigured = true;
+ } catch (Exception e) {
+ logger.warn("tls unavailable");
+ }
+ service.submit(() -> {
+ try {
+ listener = new ServerSocket(s2sPort);
+ logger.info("s2s listener ready");
+ while (!listener.isClosed()) {
+ if (Thread.currentThread().isInterrupted()) break;
+ Socket socket = listener.accept();
+ ConnectionIn client = new ConnectionIn(this, socket);
+ addConnectionIn(client);
+ service.submit(client);
+ }
+ } catch (SocketException e) {
+ // shutdown
+ } catch (IOException | XmlPullParserException e) {
+ logger.warn("xmpp exception", e);
+ }
+ });
+ }
+
+ public void addConnectionIn(ConnectionIn c) {
+ c.setListener(this);
+ inConnections.add(c);
+ }
+
+ public void addConnectionOut(ConnectionOut c, Optional<Socket> socket) {
+ c.setListener(this);
+ outConnections.put(c, socket);
+ }
+
+ public void removeConnectionIn(ConnectionIn c) {
+ inConnections.remove(c);
+ }
+
+ public void removeConnectionOut(ConnectionOut c) {
+ outConnections.remove(c);
+ }
+
+ public String getFromCache(Jid to) {
+ final String[] cache = new String[1];
+ outCache.stream().filter(c -> c.hostname != null && c.hostname.equals(to)).findFirst().ifPresent(c -> {
+ cache[0] = c.xml;
+ outCache.remove(c);
+ });
+ return cache[0];
+ }
+
+ public Optional<ConnectionOut> getConnectionOut(Jid hostname, boolean needReady) {
+ return outConnections.keySet().stream().filter(c -> c.to != null &&
+ c.to.equals(hostname) && (!needReady || c.streamReady)).findFirst();
+ }
+
+ public Optional<ConnectionIn> getConnectionIn(String streamID) {
+ return inConnections.stream().filter(c -> c.streamID != null && c.streamID.equals(streamID)).findFirst();
+ }
+
+ public void sendOut(Jid hostname, String xml) {
+ boolean haveAnyConn = false;
+
+ ConnectionOut connOut = null;
+ for (ConnectionOut c : outConnections.keySet()) {
+ if (c.to != null && c.to.equals(hostname)) {
+ if (c.streamReady) {
+ connOut = c;
+ break;
+ } else {
+ haveAnyConn = true;
+ break;
+ }
+ }
+ }
+ if (connOut != null) {
+ connOut.send(xml);
+ return;
+ }
+
+ boolean haveCache = false;
+ for (CacheEntry c : outCache) {
+ if (c.hostname != null && c.hostname.equals(hostname)) {
+ c.xml += xml;
+ c.updated = Instant.now();
+ haveCache = true;
+ break;
+ }
+ }
+ if (!haveCache) {
+ outCache.add(new CacheEntry(hostname, xml));
+ }
+
+ if (!haveAnyConn && !closeFlag.get()) {
+ try {
+ createDialbackConnection(hostname.toEscapedString(), null, null);
+ } catch (Exception e) {
+ logger.warn("dialback error", e);
+ }
+ }
+ }
+
+ void createDialbackConnection(String to, String checkSID, String dbKey) throws Exception {
+ ConnectionOut connectionOut = new ConnectionOut(getJid(), Jid.of(to), null, null, checkSID, dbKey);
+ addConnectionOut(connectionOut, Optional.empty());
+ service.submit(() -> {
+ try {
+ Socket socket = new Socket();
+ socket.connect(DNSQueries.getServerAddress(to));
+ connectionOut.setInputStream(socket.getInputStream());
+ connectionOut.setOutputStream(socket.getOutputStream());
+ addConnectionOut(connectionOut, Optional.of(socket));
+ connectionOut.connect();
+ } catch (IOException e) {
+ logger.info("dialback to " + to + " exception", e);
+ }
+ });
+ }
+
+ public void startDialback(Jid from, String streamId, String dbKey) throws Exception {
+ Optional<ConnectionOut> c = getConnectionOut(from, false);
+ if (c.isPresent()) {
+ c.get().sendDialbackVerify(streamId, dbKey);
+ } else {
+ createDialbackConnection(from.toEscapedString(), streamId, dbKey);
+ }
+ }
+
+ public void addStanzaListener(StanzaListener listener) {
+ stanzaListeners.add(listener);
+ }
+
+ public void onStanzaReceived(String xmlValue) {
+ logger.info("S2S: {}", xmlValue);
+ Stanza stanza = parse(xmlValue);
+ stanzaListeners.forEach(l -> l.stanzaReceived(stanza));
+ }
+
+ public BasicXmppSession getSession() {
+ return session;
+ }
+
+ public List<ConnectionIn> getInConnections() {
+ return inConnections;
+ }
+
+ public Map<ConnectionOut, Optional<Socket>> getOutConnections() {
+ return outConnections;
+ }
+
+ @Override
+ public boolean isTlsAvailable() {
+ return tlsConfigured;
+ }
+
+ @Override
+ public void starttls(ConnectionIn connection) {
+ logger.debug("stream {} securing", connection.streamID);
+ connection.sendStanza("<proceed xmlns=\"" + Connection.NS_TLS + "\" />");
+ try {
+ connection.setSocket(sc.getSocketFactory().createSocket(connection.getSocket(), connection.getSocket().getInetAddress().getHostAddress(),
+ connection.getSocket().getPort(), false));
+ SSLSocket sslSocket = (SSLSocket) connection.getSocket();
+ sslSocket.addHandshakeCompletedListener(handshakeCompletedEvent -> {
+ try {
+ CertPath certPath = cf.generateCertPath(Arrays.asList(handshakeCompletedEvent.getPeerCertificates()));
+ cpv.validate(certPath, params);
+ connection.setTrusted(true);
+ logger.info("connection from {} is trusted", connection.from);
+ } catch (SSLPeerUnverifiedException | CertificateException | CertPathValidatorException | InvalidAlgorithmParameterException e) {
+ logger.info("connection from {} is NOT trusted, falling back to dialback", connection.from);
+ }
+ });
+ sslSocket.setUseClientMode(false);
+ sslSocket.setNeedClientAuth(true);
+ sslSocket.startHandshake();
+ connection.setSecured(true);
+ logger.debug("stream from {} secured", connection.streamID);
+ connection.restartParser();
+ } catch (XmlPullParserException | IOException sex) {
+ logger.warn("stream {} ssl error {}", connection.streamID, sex);
+ connection.sendStanza("<failure xmlns=\"" + Connection.NS_TLS + "\" />");
+ removeConnectionIn(connection);
+ connection.closeConnection();
+ }
+ }
+
+ @Override
+ public void proceed(ConnectionOut connection) {
+ try {
+ Socket socket = outConnections.get(connection).get();
+ socket = sc.getSocketFactory().createSocket(socket, socket.getInetAddress().getHostAddress(),
+ socket.getPort(), false);
+ SSLSocket sslSocket = (SSLSocket) socket;
+ sslSocket.addHandshakeCompletedListener(handshakeCompletedEvent -> {
+ try {
+ CertPath certPath = cf.generateCertPath(Arrays.asList(handshakeCompletedEvent.getPeerCertificates()));
+ cpv.validate(certPath, params);
+ connection.setTrusted(true);
+ logger.info("connection to {} is trusted", connection.to);
+ } catch (SSLPeerUnverifiedException | CertificateException | CertPathValidatorException | InvalidAlgorithmParameterException e) {
+ logger.info("connection to {} is NOT trusted, falling back to dialback", connection.to);
+ }
+ });
+ sslSocket.setNeedClientAuth(true);
+ sslSocket.startHandshake();
+ connection.setSecured(true);
+ logger.debug("stream to {} secured", connection.getStreamID());
+ connection.setInputStream(socket.getInputStream());
+ connection.setOutputStream(socket.getOutputStream());
+ connection.restartStream();
+ connection.sendOpenStream();
+ } catch (NoSuchElementException | XmlPullParserException | IOException sex) {
+ logger.error("s2s ssl error: {} {}, error {}", connection.to, connection.getStreamID(), sex);
+ connection.send("<failure xmlns=\"" + Connection.NS_TLS + "\" />");
+ removeConnectionOut(connection);
+ connection.logoff();
+ }
+ }
+
+ @Override
+ public void verify(ConnectionOut connection, String from, String type, String sid) {
+ if (from != null && from.equals(connection.to.toEscapedString()) && sid != null && !sid.isEmpty() && type != null) {
+ getConnectionIn(sid).ifPresent(c -> c.sendDialbackResult(Jid.of(from), type));
+ }
+ }
+
+ @Override
+ public void dialbackError(ConnectionOut connection, StreamError error) {
+ logger.warn("Stream error from {}: {}", connection.getStreamID(), error.getCondition());
+ removeConnectionOut(connection);
+ connection.logoff();
+ }
+
+ @Override
+ public void finished(ConnectionOut connection, boolean dirty) {
+ logger.warn("stream to {} {} finished, dirty={}", connection.to, connection.getStreamID(), dirty);
+ removeConnectionOut(connection);
+ connection.logoff();
+ }
+
+ @Override
+ public void exception(ConnectionOut connection, Exception ex) {
+ logger.error("s2s out exception: {} {}, exception {}", connection.to, connection.getStreamID(), ex);
+ removeConnectionOut(connection);
+ connection.logoff();
+ }
+
+ @Override
+ public void ready(ConnectionOut connection) {
+ logger.debug("stream to {} {} ready", connection.to, connection.getStreamID());
+ String cache = getFromCache(connection.to);
+ if (cache != null) {
+ logger.debug("stream to {} {} sending cache", connection.to, connection.getStreamID());
+ connection.send(cache);
+ }
+ }
+
+ @Override
+ public boolean securing(ConnectionOut connection) {
+ return tlsConfigured && !Arrays.asList(brokenSSLhosts).contains(connection.to.toEscapedString());
+ }
+
+ public Stanza parse(String xml) {
+ try {
+ Unmarshaller unmarshaller = session.createUnmarshaller();
+ return (Stanza)unmarshaller.unmarshal(new StringReader(xml));
+ } catch (JAXBException e) {
+ logger.error("JAXB exception", e);
+ }
+ return null;
+ }
+
+ public Jid getJid() {
+ return jid;
+ }
+ @Scheduled(fixedDelay = 10000)
+ public void cleanUp() {
+ Instant now = Instant.now();
+ outConnections.keySet().stream().filter(c -> Duration.between(c.getUpdated(), now).toMinutes() > TIMEOUT_MINUTES)
+ .forEach(c -> {
+ logger.info("closing idle outgoing connection to {}", c.to);
+ c.logoff();
+ outConnections.remove(c);
+ });
+
+ inConnections.stream().filter(c -> Duration.between(c.updated, now).toMinutes() > TIMEOUT_MINUTES)
+ .forEach(c -> {
+ logger.info("closing idle incoming connection from {}", c.from);
+ c.closeConnection();
+ inConnections.remove(c);
+ });
+ }
+ @PreDestroy
+ public void preDestroy() throws IOException {
+ closeFlag.set(true);
+ if (listener != null && !listener.isClosed()) {
+ listener.close();
+ }
+ service.shutdown();
+ logger.info("XMPP server destroyed");
+ }
+
+ public int getServerPort() {
+ return s2sPort;
+ }
+}