diff options
Diffstat (limited to 'src/com/juick/jabber')
-rw-r--r-- | src/com/juick/jabber/ws/Main.java | 518 |
1 files changed, 518 insertions, 0 deletions
diff --git a/src/com/juick/jabber/ws/Main.java b/src/com/juick/jabber/ws/Main.java new file mode 100644 index 00000000..f8cad85b --- /dev/null +++ b/src/com/juick/jabber/ws/Main.java @@ -0,0 +1,518 @@ +/* + * Juick + * Copyright (C) 2008-2011, Ugnich Anton + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + */ +package com.juick.jabber.ws; + +import com.juick.xmpp.JID; +import com.juick.xmpp.MessageListener; +import com.juick.xmpp.XmppConnection; +import com.juick.xmpp.XmppConnectionComponent; +import com.juick.xmpp.XmppListener; +import com.juick.xmpp.extensions.JuickMessage; +import java.io.FileInputStream; +import java.io.IOException; +import java.math.BigInteger; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.charset.Charset; +import java.security.MessageDigest; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Iterator; +import java.util.Properties; +import java.util.Vector; + +/** + * + * @author Ugnich Anton + */ +public class Main implements XmppListener, MessageListener { + + Connection sql; + XmppConnection xmpp; + Vector<SocketSubscribed> sockReplies = new Vector<SocketSubscribed>(); + Vector<SocketSubscribed> sockMessages = new Vector<SocketSubscribed>(); + Vector<SocketSubscribed> sockAll = new Vector<SocketSubscribed>(); + Selector sel; + + public static void main(String[] args) { + new Main().start(); + } + + public void start() { + try { + Properties conf = new Properties(); + conf.load(new FileInputStream("/etc/juick/ws.conf")); + + setupSql(conf.getProperty("mysql_username", ""), conf.getProperty("mysql_password", "")); + setupXmppComponent(conf.getProperty("xmpp_password", "")); + setupWsServer(); + } catch (Exception e) { + System.err.println(e); + } + } + + public void setupSql(String username, String password) { + try { + sql = DriverManager.getConnection("jdbc:mysql://localhost/juick?autoReconnect=true&user=" + username + "&password=" + password); + } catch (SQLException e) { + System.err.println(e); + } + } + + public void setupXmppComponent(String password) { + xmpp = new XmppConnectionComponent(new JID("ws.juick.com"), password, "127.0.0.1", 5347, false); + xmpp.addListener((XmppListener) this); + xmpp.addListener((MessageListener) this); + xmpp.start(); + } + + @Override + public void onConnectionFailed(String msg) { + System.err.println("XMPP onConnFailed " + msg); + } + + @Override + public void onAuth(String resource) { + System.err.println("XMPP onAuth " + resource); + } + + @Override + public void onAuthFailed(String message) { + System.err.println("XMPP onAuthFailed " + message); + } + + public void setupWsServer() { + try { + sel = Selector.open(); + ServerSocketChannel listensock = ServerSocketChannel.open(); + listensock.configureBlocking(false); + listensock.socket().bind(new InetSocketAddress(8080)); + listensock.register(sel, SelectionKey.OP_ACCEPT); + + while (true) { + sel.select(); + System.out.println("ONE"); + Iterator it = sel.selectedKeys().iterator(); + while (it.hasNext()) { + System.out.println("TWO"); + SelectionKey selKey = (SelectionKey) it.next(); + it.remove(); + if (selKey.isAcceptable()) { + ServerSocketChannel ssChannel = (ServerSocketChannel) selKey.channel(); + SocketChannel sChannel = ssChannel.accept(); + System.out.println(sChannel.socket().getRemoteSocketAddress().toString() + " ACCEPTED"); + sChannel.configureBlocking(false); + sChannel.register(sel, SelectionKey.OP_READ); + } else if (selKey.isReadable()) { + System.out.println("THREE"); + SocketChannel sChannel = (SocketChannel) selKey.channel(); + ByteBuffer buf = ByteBuffer.allocate(10240); + try { + if (sChannel.read(buf) > 0) { + buf.flip(); + CharBuffer charbuf = Charset.forName("ISO-8859-1").decode(buf); + if (charbuf.charAt(0) == 0 && charbuf.charAt(charbuf.length() - 1) == 0xFF) { + wsTextFrame(sChannel, charbuf.subSequence(1, charbuf.length() - 2)); + } else if (charbuf.charAt(0) == 'G' && charbuf.charAt(1) == 'E' && charbuf.charAt(2) == 'T' && charbuf.charAt(3) == ' ') { + wsHandshake(sChannel, buf); + } else { + System.out.println(sChannel.socket().getRemoteSocketAddress().toString() + " INVALID FRAME"); + System.out.println("FOUR"); + sChannel.close(); + selKey.cancel(); + } + } else { + sChannel.close(); + selKey.cancel(); + System.out.println("SIX"); + } + } catch (IOException e) { + System.err.println(e); + sChannel.close(); + System.out.println("FIVE"); + selKey.cancel(); + } + } + } + } + } catch (Exception e) { + System.err.println(e); + } + } + + public void wsHandshake(SocketChannel sock, ByteBuffer buf) throws Exception { + String hOrigin = null; + String hHost = null; + String hLocation = null; + String hSecWebSocketKey1 = null; + String hSecWebSocketKey2 = null; + String hCookie = null; + + buf.rewind(); + CharBuffer charbuf = Charset.forName("ISO-8859-1").decode(buf); + String headers[] = charbuf.toString().split("\r\n"); + for (int i = 0; i < headers.length; i++) { + String h[] = headers[i].split(" ", 2); + if (h.length == 2) { + if (h[0].equals("GET")) { + hLocation = headers[i].split(" ", 3)[1]; + } else if (h[0].equals("Origin:")) { + hOrigin = h[1]; + } else if (h[0].equals("Host:")) { + hHost = h[1]; + } else if (h[0].equals("Sec-WebSocket-Key1:")) { + hSecWebSocketKey1 = h[1]; + } else if (h[0].equals("Sec-WebSocket-Key2:")) { + hSecWebSocketKey2 = h[1]; + } else if (h[0].equals("Cookie:")) { + hCookie = h[1]; + } + } + } + + if (hOrigin == null || hHost == null || hLocation == null || hSecWebSocketKey1 == null || hSecWebSocketKey2 == null) { + System.err.println(sock.socket().getRemoteSocketAddress().toString() + " Invalid headers"); + sock.close(); + return; + } + + // Cookies + int UID = 0; + + if (hCookie != null) { + String hash = null; + + String cookies[] = hCookie.split("; "); + for (int i = 0; i < cookies.length; i++) { + String cookie[] = cookies[i].split("=", 2); + if (cookie[0].equals("hash")) { + hash = cookie[1]; + break; + } + } + + if (hash != null) { + UID = com.juick.server.UserQueries.getUIDbyHash(sql, hash); + } + } + + + // URL + + String loc[] = hLocation.split("/"); + int MID = 0; + if (hLocation.equals("/my") && UID > 0) { + sockMessages.add(new SocketSubscribed(sock, UID, 0)); + } else if (hLocation.equals("/all")) { + sockAll.add(new SocketSubscribed(sock, UID, 0)); + } else if ((loc.length == 2 || loc.length == 3) && loc[1].equals("replies")) { + if (loc.length == 2) { + sockReplies.add(new SocketSubscribed(sock, UID, 0)); + } else { + try { + MID = Integer.parseInt(loc[2]); + } catch (Exception e) { + } + if (MID > 0) { + sockReplies.add(new SocketSubscribed(sock, UID, MID)); + } else { + System.err.println(sock.socket().getRemoteSocketAddress().toString() + " Invalid MID"); + sock.close(); + return; + } + } + } else { + System.err.println(sock.socket().getRemoteSocketAddress().toString() + " Invalid location"); + sock.close(); + return; + } + + System.out.println(sock.socket().getRemoteSocketAddress().toString() + " HANDSHAKE (UID=" + UID + ", MID=" + MID + ")"); + + Long lSecNum1 = calcSecKeyNum(hSecWebSocketKey1); + Long lSecNum2 = calcSecKeyNum(hSecWebSocketKey2); + + BigInteger sec1 = new BigInteger(lSecNum1.toString()); + BigInteger sec2 = new BigInteger(lSecNum2.toString()); + + // concatenate 3 parts secNum1 + secNum2 + secKey (16 Bytes) + byte[] l128Bit = new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + byte[] lTmp; + + lTmp = sec1.toByteArray(); + int lIdx = lTmp.length; + int lCnt = 0; + while (lIdx > 0 && lCnt < 4) { + lIdx--; + lCnt++; + l128Bit[4 - lCnt] = lTmp[lIdx]; + } + + lTmp = sec2.toByteArray(); + lIdx = lTmp.length; + lCnt = 0; + while (lIdx > 0 && lCnt < 4) { + lIdx--; + lCnt++; + l128Bit[8 - lCnt] = lTmp[lIdx]; + } + + buf.rewind(); + for (int i = 0; i < 8; i++) { + l128Bit[8 + i] = buf.get(buf.limit() - 8 + i); + } + + String outstr = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n" + + "Upgrade: WebSocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Origin: " + hOrigin + "\r\n" + + "Sec-WebSocket-Location: ws://" + hHost + hLocation + "\r\n" + + "Sec-WebSocket-Protocol: sample\r\n" + + "\r\n"; + ByteBuffer out = ByteBuffer.allocate(4096); + out.put(Charset.forName("ISO-8859-1").encode(outstr)); + out.put(MessageDigest.getInstance("MD5").digest(l128Bit)); + out.flip(); + + sock.write(out); + } + + private static long calcSecKeyNum(String aKey) { + StringBuilder lSB = new StringBuilder(); + // StringBuuffer lSB = new StringBuuffer(); + int lSpaces = 0; + for (int i = 0; i < aKey.length(); i++) { + char lC = aKey.charAt(i); + if (lC == ' ') { + lSpaces++; + } else if (lC >= '0' && lC <= '9') { + lSB.append(lC); + } + } + long lRes = -1; + if (lSpaces > 0) { + try { + lRes = Long.parseLong(lSB.toString()) / lSpaces; + // log.debug("Key: " + aKey + ", Numbers: " + lSB.toString() + + // ", Spaces: " + lSpaces + ", Result: " + lRes); + } catch (NumberFormatException ex) { + // use default result + } + } + return lRes; + } + + public void wsTextFrame(SocketChannel sock, CharSequence csbuf) { + String buf = csbuf.toString(); + if (buf.equals(" ")) { + ByteBuffer out = ByteBuffer.allocate(4); + out.put((byte) 0x00); + out.put((byte) 0x20); + out.put((byte) 0xFF); + out.flip(); + out.rewind(); + try { + sock.write(out); + } catch (IOException e) { + } + } else { + System.out.println(sock.socket().getRemoteSocketAddress().toString() + " DATA '" + buf + "'"); + } + } + + @Override + public void onMessage(com.juick.xmpp.Message msg) { + JuickMessage jmsg = (JuickMessage) msg.getChild(JuickMessage.XMLNS); + if (jmsg != null) { + if (jmsg.RID == 0) { + onJuickMessagePost(jmsg); + } else { + onJuickMessageReply(jmsg); + } + } + } + + private void onJuickMessagePost(com.juick.Message jmsg) { + String json = "{" + + "\"mid\":" + jmsg.MID + "," + + "\"user\":{" + "\"uid\":" + jmsg.User.UID + "," + "\"uname\":\"" + encloseJSON(jmsg.User.UName) + "\"}," + + "\"timestamp\":\"" + jmsg.Timestamp + "\"," + + "\"body\":\"" + encloseJSON(jmsg.Text) + "\""; + if (jmsg.tags.size() > 0) { + json += ",\"tags\":["; + for (int i = 0; i < jmsg.tags.size(); i++) { + if (i > 0) { + json += ","; + } + json += "\"" + encloseJSON((String) jmsg.tags.get(i)) + "\""; + } + json += "]"; + } + json += "}"; + + ByteBuffer out = ByteBuffer.allocate(10240); + out.put((byte) 0x00); + out.put(Charset.forName("UTF-8").encode(json)); + out.put((byte) 0xFF); + out.flip(); + + + String query = "SELECT suser_id FROM subscr_users WHERE user_id=" + jmsg.User.UID + " AND suser_id NOT IN (SELECT user_id FROM bl_tags INNER JOIN messages_tags USING(tag_id) WHERE message_id=" + jmsg.MID + ")"; + if (jmsg.Privacy < 0) { + query += " AND suser_id IN (SELECT wl_user_id FROM wl_users WHERE user_id=" + jmsg.User.UID + ")"; + } + + Statement stmt = null; + ResultSet rs = null; + try { + stmt = sql.createStatement(); + rs = stmt.executeQuery(query); + rs.beforeFirst(); + while (rs.next()) { + int UID = rs.getInt(1); + + for (int i = sockMessages.size() - 1; i >= 0; i--) { + SocketSubscribed ss = sockMessages.get(i); + if (ss.UID == UID) { + try { + out.rewind(); + ss.sock.write(out); + } catch (IOException e) { + sockMessages.remove(i); + try { + ss.sock.close(); + } catch (IOException ex) { + } + } + } + } + + if (jmsg.Privacy <= 0) { + for (int i = sockAll.size() - 1; i >= 0; i--) { + SocketSubscribed ss = sockAll.get(i); + if (ss.UID == UID) { + try { + out.rewind(); + ss.sock.write(out); + } catch (IOException e) { + sockAll.remove(i); + try { + ss.sock.close(); + } catch (IOException ex) { + } + } + } + } + } + + } + } catch (SQLException e) { + System.err.println(e); + } finally { + if (rs != null) { + try { + rs.close(); + } catch (SQLException e) { + } + } + if (stmt != null) { + try { + stmt.close(); + } catch (SQLException e) { + } + } + } + + // Send to all + if (jmsg.Privacy > 0) { + for (int i = sockAll.size() - 1; i >= 0; i--) { + SocketSubscribed ss = sockAll.get(i); + try { + out.rewind(); + ss.sock.write(out); + } catch (IOException e) { + sockAll.remove(i); + try { + ss.sock.close(); + } catch (IOException ex) { + } + } + } + } + + + + } + + private void onJuickMessageReply(com.juick.Message jmsg) { + String json = "{" + + "\"mid\":" + jmsg.MID + "," + + "\"rid\":" + jmsg.RID + "," + + "\"user\":{" + "\"uid\":" + jmsg.User.UID + "," + "\"uname\":\"" + encloseJSON(jmsg.User.UName) + "\"}," + + "\"timestamp\":\"" + jmsg.Timestamp + "\"," + + "\"body\":\"" + encloseJSON(jmsg.Text) + "\"" + + "}"; + + ByteBuffer out = ByteBuffer.allocate(10240); + out.put((byte) 0x00); + out.put(Charset.forName("UTF-8").encode(json)); + out.put((byte) 0xFF); + out.flip(); + + for (int i = sockReplies.size() - 1; i >= 0; i--) { + SocketSubscribed ss = sockReplies.get(i); + if (ss.MID == 0 || ss.MID == jmsg.MID) { + try { + out.rewind(); + ss.sock.write(out); + } catch (IOException e) { + sockReplies.remove(i); + try { + ss.sock.close(); + } catch (IOException ex) { + } + } + } + } + } + + public static String encloseJSON(String str) { + return str.replace("\"", """).replace("\\", "\\\\").replace("\n", "\\n"); + } +} + +class SocketSubscribed { + + public SocketChannel sock; + public int UID; + public int MID; + + public SocketSubscribed(SocketChannel sock, int UID, int MID) { + this.sock = sock; + this.UID = UID; + this.MID = MID; + } +} |