/* * Copyright (C) 2008-2017, Juick * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ package com.juick.server; import com.juick.server.xmpp.s2s.*; import com.juick.service.UserService; import com.juick.xmpp.extensions.StreamError; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Value; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Component; import org.xmlpull.v1.XmlPullParserException; import rocks.xmpp.addr.Jid; import rocks.xmpp.core.stanza.model.Stanza; import javax.annotation.PostConstruct; import javax.annotation.PreDestroy; import javax.inject.Inject; import javax.net.ssl.*; import javax.xml.bind.JAXBException; import javax.xml.bind.Unmarshaller; import java.io.FileInputStream; import java.io.IOException; import java.io.InputStream; import java.io.StringReader; import java.net.ServerSocket; import java.net.Socket; import java.net.SocketException; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.SecureRandom; import java.time.Duration; import java.time.Instant; import java.util.*; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicBoolean; /** * @author ugnich */ @Component public class XMPPServer implements ConnectionListener, AutoCloseable { private static final Logger logger = LoggerFactory.getLogger("com.juick.server.xmpp"); private static final int TIMEOUT_MINUTES = 15; @Inject public ExecutorService service; @Value("${hostname:localhost}") private Jid jid; @Value("${s2s_port:5269}") private int s2sPort; @Value("${keystore:juick.p12}") public String keystore; @Value("${keystore_password:secret}") public String keystorePassword; @Value("${broken_ssl_hosts:}") public String[] brokenSSLhosts; @Value("${banned_hosts:}") public String[] bannedHosts; private final List inConnections = new CopyOnWriteArrayList<>(); private final Map> outConnections = new ConcurrentHashMap<>(); private final List outCache = new CopyOnWriteArrayList<>(); private final List stanzaListeners = new CopyOnWriteArrayList<>(); private final AtomicBoolean closeFlag = new AtomicBoolean(false); SSLContext sc; private TrustManager[] trustAllCerts = new TrustManager[]{ new X509TrustManager() { public void checkClientTrusted(java.security.cert.X509Certificate[] certs, String authType) { } public void checkServerTrusted(java.security.cert.X509Certificate[] certs, String authType) { } public java.security.cert.X509Certificate[] getAcceptedIssuers() { return null; } } }; private boolean tlsConfigured = false; private ServerSocket listener; @Inject private BasicXmppSession session; @Inject private UserService userService; @PostConstruct public void init() throws KeyStoreException { closeFlag.set(false); KeyStore ks = KeyStore.getInstance("PKCS12"); try (InputStream ksIs = new FileInputStream(keystore)) { ks.load(ksIs, keystorePassword.toCharArray()); KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory .getDefaultAlgorithm()); kmf.init(ks, keystorePassword.toCharArray()); sc = SSLContext.getInstance("TLSv1.2"); sc.init(kmf.getKeyManagers(), trustAllCerts, new SecureRandom()); tlsConfigured = true; } catch (Exception e) { logger.warn("tls unavailable"); } service.submit(() -> { try { listener = new ServerSocket(s2sPort); logger.info("s2s listener ready"); while (!listener.isClosed()) { if (Thread.currentThread().isInterrupted()) break; Socket socket = listener.accept(); ConnectionIn client = new ConnectionIn(this, socket); addConnectionIn(client); service.submit(client); } } catch (SocketException e) { // shutdown } catch (IOException | XmlPullParserException e) { logger.warn("xmpp exception", e); } }); } @Override public void close() throws Exception { if (listener != null && !listener.isClosed()) { listener.close(); } outConnections.forEach((c, s) -> { c.logoff(); outConnections.remove(c); }); inConnections.forEach(c -> { c.closeConnection(); inConnections.remove(c); }); service.shutdown(); logger.info("XMPP server destroyed"); } public void addConnectionIn(ConnectionIn c) { c.setListener(this); inConnections.add(c); } public void addConnectionOut(ConnectionOut c, Optional socket) { c.setListener(this); outConnections.put(c, socket); } public void removeConnectionIn(ConnectionIn c) { inConnections.remove(c); } public void removeConnectionOut(ConnectionOut c) { outConnections.remove(c); } public String getFromCache(Jid to) { final String[] cache = new String[1]; outCache.stream().filter(c -> c.hostname != null && c.hostname.equals(to)).findFirst().ifPresent(c -> { cache[0] = c.xml; outCache.remove(c); }); return cache[0]; } public Optional getConnectionOut(Jid hostname, boolean needReady) { return outConnections.keySet().stream().filter(c -> c.to != null && c.to.equals(hostname) && (!needReady || c.streamReady)).findFirst(); } public Optional getConnectionIn(String streamID) { return inConnections.stream().filter(c -> c.streamID != null && c.streamID.equals(streamID)).findFirst(); } public void sendOut(Jid hostname, String xml) { boolean haveAnyConn = false; ConnectionOut connOut = null; for (ConnectionOut c : outConnections.keySet()) { if (c.to != null && c.to.equals(hostname)) { if (c.streamReady) { connOut = c; break; } else { haveAnyConn = true; break; } } } if (connOut != null) { connOut.send(xml); return; } boolean haveCache = false; for (CacheEntry c : outCache) { if (c.hostname != null && c.hostname.equals(hostname)) { c.xml += xml; c.updated = Instant.now(); haveCache = true; break; } } if (!haveCache) { outCache.add(new CacheEntry(hostname, xml)); } if (!haveAnyConn && !closeFlag.get()) { try { createDialbackConnection(hostname.toEscapedString(), null, null); } catch (Exception e) { logger.warn("dialback error", e); } } } void createDialbackConnection(String to, String checkSID, String dbKey) throws Exception { ConnectionOut connectionOut = new ConnectionOut(getJid(), Jid.of(to), null, null, checkSID, dbKey); addConnectionOut(connectionOut, Optional.empty()); service.submit(() -> { try { Socket socket = new Socket(); socket.connect(DNSQueries.getServerAddress(to)); connectionOut.setInputStream(socket.getInputStream()); connectionOut.setOutputStream(socket.getOutputStream()); addConnectionOut(connectionOut, Optional.of(socket)); connectionOut.connect(); } catch (IOException e) { logger.info("dialback to " + to + " exception", e); } }); } public void startDialback(Jid from, String streamId, String dbKey) throws Exception { Optional c = getConnectionOut(from, false); if (c.isPresent()) { c.get().sendDialbackVerify(streamId, dbKey); } else { createDialbackConnection(from.toEscapedString(), streamId, dbKey); } } public void addStanzaListener(StanzaListener listener) { stanzaListeners.add(listener); } public void onStanzaReceived(String xmlValue) { logger.info("S2S: {}", xmlValue); Stanza stanza = parse(xmlValue); stanzaListeners.forEach(l -> l.stanzaReceived(stanza)); } public BasicXmppSession getSession() { return session; } public List getInConnections() { return inConnections; } public Map> getOutConnections() { return outConnections; } @Override public boolean isTlsAvailable() { return tlsConfigured; } @Override public void starttls(ConnectionIn connection) { logger.debug("stream {} securing", connection.streamID); connection.sendStanza(""); try { connection.setSocket(sc.getSocketFactory().createSocket(connection.getSocket(), connection.getSocket().getInetAddress().getHostAddress(), connection.getSocket().getPort(), true)); ((SSLSocket) connection.getSocket()).setUseClientMode(false); ((SSLSocket) connection.getSocket()).startHandshake(); connection.setSecured(true); logger.debug("stream {} secured", connection.streamID); connection.restartParser(); } catch (XmlPullParserException | IOException sex) { logger.warn("stream {} ssl error {}", connection.streamID, sex); connection.sendStanza(""); removeConnectionIn(connection); connection.closeConnection(); } } @Override public void proceed(ConnectionOut connection) { try { Socket socket = outConnections.get(connection).get(); socket = sc.getSocketFactory().createSocket(socket, socket.getInetAddress().getHostAddress(), socket.getPort(), true); ((SSLSocket) socket).startHandshake(); connection.setSecured(true); logger.debug("stream {} secured", connection.getStreamID()); connection.setInputStream(socket.getInputStream()); connection.setOutputStream(socket.getOutputStream()); connection.restartStream(); connection.sendOpenStream(); } catch (NoSuchElementException | XmlPullParserException | IOException sex) { logger.error("s2s ssl error: {} {}, error {}", connection.to, connection.getStreamID(), sex); connection.send(""); removeConnectionOut(connection); connection.logoff(); } } @Override public void verify(ConnectionOut connection, String from, String type, String sid) { if (from != null && from.equals(connection.to.toEscapedString()) && sid != null && !sid.isEmpty() && type != null) { getConnectionIn(sid).ifPresent(c -> c.sendDialbackResult(Jid.of(from), type)); } } @Override public void dialbackError(ConnectionOut connection, StreamError error) { logger.warn("Stream error from {}: {}", connection.getStreamID(), error.getCondition()); removeConnectionOut(connection); connection.logoff(); } @Override public void finished(ConnectionOut connection, boolean dirty) { logger.warn("stream to {} {} finished, dirty={}", connection.to, connection.getStreamID(), dirty); removeConnectionOut(connection); connection.logoff(); } @Override public void exception(ConnectionOut connection, Exception ex) { logger.error("s2s out exception: {} {}, exception {}", connection.to, connection.getStreamID(), ex); removeConnectionOut(connection); connection.logoff(); } @Override public void ready(ConnectionOut connection) { logger.debug("stream to {} {} ready", connection.to, connection.getStreamID()); String cache = getFromCache(connection.to); if (cache != null) { logger.debug("stream to {} {} sending cache", connection.to, connection.getStreamID()); connection.send(cache); } } @Override public boolean securing(ConnectionOut connection) { return tlsConfigured && !Arrays.asList(brokenSSLhosts).contains(connection.to.toEscapedString()); } public Stanza parse(String xml) { try { Unmarshaller unmarshaller = session.createUnmarshaller(); return (Stanza)unmarshaller.unmarshal(new StringReader(xml)); } catch (JAXBException e) { logger.error("JAXB exception", e); } return null; } public Jid getJid() { return jid; } @Scheduled(fixedDelay = 10000) public void cleanUp() { Instant now = Instant.now(); outConnections.keySet().stream().filter(c -> Duration.between(c.getUpdated(), now).toMinutes() > TIMEOUT_MINUTES) .forEach(c -> { logger.info("closing idle outgoing connection to {}", c.to); c.logoff(); outConnections.remove(c); }); inConnections.stream().filter(c -> Duration.between(c.updated, now).toMinutes() > TIMEOUT_MINUTES) .forEach(c -> { logger.info("closing idle incoming connection from {}", c.from); c.closeConnection(); inConnections.remove(c); }); } @PreDestroy public void preDestroy() { closeFlag.set(true); } }