/* * Juick * Copyright (C) 2008-2011, Ugnich Anton * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as * published by the Free Software Foundation, either version 3 of the * License, or (at your option) any later version. * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Affero General Public License for more details. * * You should have received a copy of the GNU Affero General Public License * along with this program. If not, see . */ package com.juick.jabber.ws; import com.juick.xmpp.JID; import com.juick.xmpp.MessageListener; import com.juick.xmpp.XmppConnection; import com.juick.xmpp.XmppConnectionComponent; import com.juick.xmpp.XmppListener; import com.juick.xmpp.extensions.JuickMessage; import java.io.FileInputStream; import java.io.IOException; import java.math.BigInteger; import java.net.InetSocketAddress; import java.nio.ByteBuffer; import java.nio.CharBuffer; import java.nio.channels.SelectionKey; import java.nio.channels.Selector; import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.nio.charset.Charset; import java.security.MessageDigest; import java.sql.Connection; import java.sql.DriverManager; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Statement; import java.util.Iterator; import java.util.Properties; import java.util.Vector; /** * * @author Ugnich Anton */ public class Main implements XmppListener, MessageListener { Connection sql; XmppConnection xmpp; Vector sockReplies = new Vector(); Vector sockMessages = new Vector(); Vector sockAll = new Vector(); Selector sel; public static void main(String[] args) { new Main().start(); } public void start() { try { Properties conf = new Properties(); conf.load(new FileInputStream("/etc/juick/ws.conf")); setupSql(conf.getProperty("mysql_username", ""), conf.getProperty("mysql_password", "")); setupXmppComponent(conf.getProperty("xmpp_password", "")); setupWsServer(); } catch (Exception e) { System.err.println(e); } } public void setupSql(String username, String password) { try { sql = DriverManager.getConnection("jdbc:mysql://localhost/juick?autoReconnect=true&user=" + username + "&password=" + password); } catch (SQLException e) { System.err.println(e); } } public void setupXmppComponent(String password) { xmpp = new XmppConnectionComponent(new JID("ws.juick.com"), password, "127.0.0.1", 5347, false); xmpp.addListener((XmppListener) this); xmpp.addListener((MessageListener) this); xmpp.start(); } @Override public void onConnectionFailed(String msg) { System.err.println("XMPP onConnFailed " + msg); } @Override public void onAuth(String resource) { System.err.println("XMPP onAuth " + resource); } @Override public void onAuthFailed(String message) { System.err.println("XMPP onAuthFailed " + message); } public void setupWsServer() { try { sel = Selector.open(); ServerSocketChannel listensock = ServerSocketChannel.open(); listensock.configureBlocking(false); listensock.socket().bind(new InetSocketAddress(8080)); listensock.register(sel, SelectionKey.OP_ACCEPT); while (true) { sel.select(); System.out.println("ONE"); Iterator it = sel.selectedKeys().iterator(); while (it.hasNext()) { System.out.println("TWO"); SelectionKey selKey = (SelectionKey) it.next(); it.remove(); if (selKey.isAcceptable()) { ServerSocketChannel ssChannel = (ServerSocketChannel) selKey.channel(); SocketChannel sChannel = ssChannel.accept(); System.out.println(sChannel.socket().getRemoteSocketAddress().toString() + " ACCEPTED"); sChannel.configureBlocking(false); sChannel.register(sel, SelectionKey.OP_READ); } else if (selKey.isReadable()) { System.out.println("THREE"); SocketChannel sChannel = (SocketChannel) selKey.channel(); ByteBuffer buf = ByteBuffer.allocate(10240); try { if (sChannel.read(buf) > 0) { buf.flip(); CharBuffer charbuf = Charset.forName("ISO-8859-1").decode(buf); if (charbuf.charAt(0) == 0 && charbuf.charAt(charbuf.length() - 1) == 0xFF) { wsTextFrame(sChannel, charbuf.subSequence(1, charbuf.length() - 2)); } else if (charbuf.charAt(0) == 'G' && charbuf.charAt(1) == 'E' && charbuf.charAt(2) == 'T' && charbuf.charAt(3) == ' ') { wsHandshake(sChannel, buf); } else { System.out.println(sChannel.socket().getRemoteSocketAddress().toString() + " INVALID FRAME"); System.out.println("FOUR"); sChannel.close(); selKey.cancel(); } } else { sChannel.close(); selKey.cancel(); System.out.println("SIX"); } } catch (IOException e) { System.err.println(e); sChannel.close(); System.out.println("FIVE"); selKey.cancel(); } } } } } catch (Exception e) { System.err.println(e); } } public void wsHandshake(SocketChannel sock, ByteBuffer buf) throws Exception { String hOrigin = null; String hHost = null; String hLocation = null; String hSecWebSocketKey1 = null; String hSecWebSocketKey2 = null; String hCookie = null; buf.rewind(); CharBuffer charbuf = Charset.forName("ISO-8859-1").decode(buf); String headers[] = charbuf.toString().split("\r\n"); for (int i = 0; i < headers.length; i++) { String h[] = headers[i].split(" ", 2); if (h.length == 2) { if (h[0].equals("GET")) { hLocation = headers[i].split(" ", 3)[1]; } else if (h[0].equals("Origin:")) { hOrigin = h[1]; } else if (h[0].equals("Host:")) { hHost = h[1]; } else if (h[0].equals("Sec-WebSocket-Key1:")) { hSecWebSocketKey1 = h[1]; } else if (h[0].equals("Sec-WebSocket-Key2:")) { hSecWebSocketKey2 = h[1]; } else if (h[0].equals("Cookie:")) { hCookie = h[1]; } } } if (hOrigin == null || hHost == null || hLocation == null || hSecWebSocketKey1 == null || hSecWebSocketKey2 == null) { System.err.println(sock.socket().getRemoteSocketAddress().toString() + " Invalid headers"); sock.close(); return; } // Cookies int UID = 0; if (hCookie != null) { String hash = null; String cookies[] = hCookie.split("; "); for (int i = 0; i < cookies.length; i++) { String cookie[] = cookies[i].split("=", 2); if (cookie[0].equals("hash")) { hash = cookie[1]; break; } } if (hash != null) { UID = com.juick.server.UserQueries.getUIDbyHash(sql, hash); } } // URL String loc[] = hLocation.split("/"); int MID = 0; if (hLocation.equals("/my") && UID > 0) { sockMessages.add(new SocketSubscribed(sock, UID, 0)); } else if (hLocation.equals("/all")) { sockAll.add(new SocketSubscribed(sock, UID, 0)); } else if ((loc.length == 2 || loc.length == 3) && loc[1].equals("replies")) { if (loc.length == 2) { sockReplies.add(new SocketSubscribed(sock, UID, 0)); } else { try { MID = Integer.parseInt(loc[2]); } catch (Exception e) { } if (MID > 0) { sockReplies.add(new SocketSubscribed(sock, UID, MID)); } else { System.err.println(sock.socket().getRemoteSocketAddress().toString() + " Invalid MID"); sock.close(); return; } } } else { System.err.println(sock.socket().getRemoteSocketAddress().toString() + " Invalid location"); sock.close(); return; } System.out.println(sock.socket().getRemoteSocketAddress().toString() + " HANDSHAKE (UID=" + UID + ", MID=" + MID + ")"); Long lSecNum1 = calcSecKeyNum(hSecWebSocketKey1); Long lSecNum2 = calcSecKeyNum(hSecWebSocketKey2); BigInteger sec1 = new BigInteger(lSecNum1.toString()); BigInteger sec2 = new BigInteger(lSecNum2.toString()); // concatenate 3 parts secNum1 + secNum2 + secKey (16 Bytes) byte[] l128Bit = new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; byte[] lTmp; lTmp = sec1.toByteArray(); int lIdx = lTmp.length; int lCnt = 0; while (lIdx > 0 && lCnt < 4) { lIdx--; lCnt++; l128Bit[4 - lCnt] = lTmp[lIdx]; } lTmp = sec2.toByteArray(); lIdx = lTmp.length; lCnt = 0; while (lIdx > 0 && lCnt < 4) { lIdx--; lCnt++; l128Bit[8 - lCnt] = lTmp[lIdx]; } buf.rewind(); for (int i = 0; i < 8; i++) { l128Bit[8 + i] = buf.get(buf.limit() - 8 + i); } String outstr = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + "Upgrade: WebSocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Origin: " + hOrigin + "\r\n" + "Sec-WebSocket-Location: ws://" + hHost + hLocation + "\r\n" + "Sec-WebSocket-Protocol: sample\r\n" + "\r\n"; ByteBuffer out = ByteBuffer.allocate(4096); out.put(Charset.forName("ISO-8859-1").encode(outstr)); out.put(MessageDigest.getInstance("MD5").digest(l128Bit)); out.flip(); sock.write(out); } private static long calcSecKeyNum(String aKey) { StringBuilder lSB = new StringBuilder(); // StringBuuffer lSB = new StringBuuffer(); int lSpaces = 0; for (int i = 0; i < aKey.length(); i++) { char lC = aKey.charAt(i); if (lC == ' ') { lSpaces++; } else if (lC >= '0' && lC <= '9') { lSB.append(lC); } } long lRes = -1; if (lSpaces > 0) { try { lRes = Long.parseLong(lSB.toString()) / lSpaces; // log.debug("Key: " + aKey + ", Numbers: " + lSB.toString() + // ", Spaces: " + lSpaces + ", Result: " + lRes); } catch (NumberFormatException ex) { // use default result } } return lRes; } public void wsTextFrame(SocketChannel sock, CharSequence csbuf) { String buf = csbuf.toString(); if (buf.equals(" ")) { ByteBuffer out = ByteBuffer.allocate(4); out.put((byte) 0x00); out.put((byte) 0x20); out.put((byte) 0xFF); out.flip(); out.rewind(); try { sock.write(out); } catch (IOException e) { } } else { System.out.println(sock.socket().getRemoteSocketAddress().toString() + " DATA '" + buf + "'"); } } @Override public void onMessage(com.juick.xmpp.Message msg) { JuickMessage jmsg = (JuickMessage) msg.getChild(JuickMessage.XMLNS); if (jmsg != null) { if (jmsg.RID == 0) { onJuickMessagePost(jmsg); } else { onJuickMessageReply(jmsg); } } } private void onJuickMessagePost(com.juick.Message jmsg) { String json = "{" + "\"mid\":" + jmsg.MID + "," + "\"user\":{" + "\"uid\":" + jmsg.User.UID + "," + "\"uname\":\"" + encloseJSON(jmsg.User.UName) + "\"}," + "\"timestamp\":\"" + jmsg.Timestamp + "\"," + "\"body\":\"" + encloseJSON(jmsg.Text) + "\""; if (jmsg.tags.size() > 0) { json += ",\"tags\":["; for (int i = 0; i < jmsg.tags.size(); i++) { if (i > 0) { json += ","; } json += "\"" + encloseJSON((String) jmsg.tags.get(i)) + "\""; } json += "]"; } json += "}"; ByteBuffer out = ByteBuffer.allocate(10240); out.put((byte) 0x00); out.put(Charset.forName("UTF-8").encode(json)); out.put((byte) 0xFF); out.flip(); String query = "SELECT suser_id FROM subscr_users WHERE user_id=" + jmsg.User.UID + " AND suser_id NOT IN (SELECT user_id FROM bl_tags INNER JOIN messages_tags USING(tag_id) WHERE message_id=" + jmsg.MID + ")"; if (jmsg.Privacy < 0) { query += " AND suser_id IN (SELECT wl_user_id FROM wl_users WHERE user_id=" + jmsg.User.UID + ")"; } Statement stmt = null; ResultSet rs = null; try { stmt = sql.createStatement(); rs = stmt.executeQuery(query); rs.beforeFirst(); while (rs.next()) { int UID = rs.getInt(1); for (int i = sockMessages.size() - 1; i >= 0; i--) { SocketSubscribed ss = sockMessages.get(i); if (ss.UID == UID) { try { out.rewind(); ss.sock.write(out); } catch (IOException e) { sockMessages.remove(i); try { ss.sock.close(); } catch (IOException ex) { } } } } if (jmsg.Privacy <= 0) { for (int i = sockAll.size() - 1; i >= 0; i--) { SocketSubscribed ss = sockAll.get(i); if (ss.UID == UID) { try { out.rewind(); ss.sock.write(out); } catch (IOException e) { sockAll.remove(i); try { ss.sock.close(); } catch (IOException ex) { } } } } } } } catch (SQLException e) { System.err.println(e); } finally { if (rs != null) { try { rs.close(); } catch (SQLException e) { } } if (stmt != null) { try { stmt.close(); } catch (SQLException e) { } } } // Send to all if (jmsg.Privacy > 0) { for (int i = sockAll.size() - 1; i >= 0; i--) { SocketSubscribed ss = sockAll.get(i); try { out.rewind(); ss.sock.write(out); } catch (IOException e) { sockAll.remove(i); try { ss.sock.close(); } catch (IOException ex) { } } } } } private void onJuickMessageReply(com.juick.Message jmsg) { String json = "{" + "\"mid\":" + jmsg.MID + "," + "\"rid\":" + jmsg.RID + "," + "\"replyto\":" + jmsg.ReplyTo + "," + "\"user\":{" + "\"uid\":" + jmsg.User.UID + "," + "\"uname\":\"" + encloseJSON(jmsg.User.UName) + "\"}," + "\"timestamp\":\"" + jmsg.Timestamp + "\"," + "\"body\":\"" + encloseJSON(jmsg.Text) + "\"" + "}"; ByteBuffer out = ByteBuffer.allocate(10240); out.put((byte) 0x00); out.put(Charset.forName("UTF-8").encode(json)); out.put((byte) 0xFF); out.flip(); for (int i = sockReplies.size() - 1; i >= 0; i--) { SocketSubscribed ss = sockReplies.get(i); if (ss.MID == 0 || ss.MID == jmsg.MID) { try { out.rewind(); ss.sock.write(out); } catch (IOException e) { sockReplies.remove(i); try { ss.sock.close(); } catch (IOException ex) { } } } } } public static String encloseJSON(String str) { return str.replace("\"", """).replace("\\", "\\\\").replace("\n", "\\n"); } } class SocketSubscribed { public SocketChannel sock; public int UID; public int MID; public SocketSubscribed(SocketChannel sock, int UID, int MID) { this.sock = sock; this.UID = UID; this.MID = MID; } }