/*
* Juick
* Copyright (C) 2008-2011, Ugnich Anton
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
* published by the Free Software Foundation, either version 3 of the
* License, or (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package com.juick.jabber.ws;
import com.juick.xmpp.JID;
import com.juick.xmpp.MessageListener;
import com.juick.xmpp.XmppConnection;
import com.juick.xmpp.XmppConnectionComponent;
import com.juick.xmpp.XmppListener;
import com.juick.xmpp.extensions.JuickMessage;
import java.io.FileInputStream;
import java.io.IOException;
import java.math.BigInteger;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.Iterator;
import java.util.Properties;
import java.util.Vector;
/**
*
* @author Ugnich Anton
*/
public class Main implements XmppListener, MessageListener {
Connection sql;
XmppConnection xmpp;
Vector sockReplies = new Vector();
Vector sockMessages = new Vector();
Vector sockAll = new Vector();
Selector sel;
public static void main(String[] args) {
new Main().start();
}
public void start() {
try {
Properties conf = new Properties();
conf.load(new FileInputStream("/etc/juick/ws.conf"));
setupSql(conf.getProperty("mysql_username", ""), conf.getProperty("mysql_password", ""));
setupXmppComponent(conf.getProperty("xmpp_password", ""));
setupWsServer();
} catch (Exception e) {
System.err.println(e);
}
}
public void setupSql(String username, String password) {
try {
sql = DriverManager.getConnection("jdbc:mysql://localhost/juick?autoReconnect=true&user=" + username + "&password=" + password);
} catch (SQLException e) {
System.err.println(e);
}
}
public void setupXmppComponent(String password) {
xmpp = new XmppConnectionComponent(new JID("ws.juick.com"), password, "127.0.0.1", 5347, false);
xmpp.addListener((XmppListener) this);
xmpp.addListener((MessageListener) this);
xmpp.start();
}
@Override
public void onConnectionFailed(String msg) {
System.err.println("XMPP onConnFailed " + msg);
}
@Override
public void onAuth(String resource) {
System.err.println("XMPP onAuth " + resource);
}
@Override
public void onAuthFailed(String message) {
System.err.println("XMPP onAuthFailed " + message);
}
public void setupWsServer() {
try {
sel = Selector.open();
ServerSocketChannel listensock = ServerSocketChannel.open();
listensock.configureBlocking(false);
listensock.socket().bind(new InetSocketAddress(8080));
listensock.register(sel, SelectionKey.OP_ACCEPT);
while (true) {
sel.select();
System.out.println("ONE");
Iterator it = sel.selectedKeys().iterator();
while (it.hasNext()) {
System.out.println("TWO");
SelectionKey selKey = (SelectionKey) it.next();
it.remove();
if (selKey.isAcceptable()) {
ServerSocketChannel ssChannel = (ServerSocketChannel) selKey.channel();
SocketChannel sChannel = ssChannel.accept();
System.out.println(sChannel.socket().getRemoteSocketAddress().toString() + " ACCEPTED");
sChannel.configureBlocking(false);
sChannel.register(sel, SelectionKey.OP_READ);
} else if (selKey.isReadable()) {
System.out.println("THREE");
SocketChannel sChannel = (SocketChannel) selKey.channel();
ByteBuffer buf = ByteBuffer.allocate(10240);
try {
if (sChannel.read(buf) > 0) {
buf.flip();
CharBuffer charbuf = Charset.forName("ISO-8859-1").decode(buf);
if (charbuf.charAt(0) == 0 && charbuf.charAt(charbuf.length() - 1) == 0xFF) {
wsTextFrame(sChannel, charbuf.subSequence(1, charbuf.length() - 2));
} else if (charbuf.charAt(0) == 'G' && charbuf.charAt(1) == 'E' && charbuf.charAt(2) == 'T' && charbuf.charAt(3) == ' ') {
wsHandshake(sChannel, buf);
} else {
System.out.println(sChannel.socket().getRemoteSocketAddress().toString() + " INVALID FRAME");
System.out.println("FOUR");
sChannel.close();
selKey.cancel();
}
} else {
sChannel.close();
selKey.cancel();
System.out.println("SIX");
}
} catch (IOException e) {
System.err.println(e);
sChannel.close();
System.out.println("FIVE");
selKey.cancel();
}
}
}
}
} catch (Exception e) {
System.err.println(e);
}
}
public void wsHandshake(SocketChannel sock, ByteBuffer buf) throws Exception {
String hOrigin = null;
String hHost = null;
String hLocation = null;
String hSecWebSocketKey1 = null;
String hSecWebSocketKey2 = null;
String hCookie = null;
buf.rewind();
CharBuffer charbuf = Charset.forName("ISO-8859-1").decode(buf);
String headers[] = charbuf.toString().split("\r\n");
for (int i = 0; i < headers.length; i++) {
String h[] = headers[i].split(" ", 2);
if (h.length == 2) {
if (h[0].equals("GET")) {
hLocation = headers[i].split(" ", 3)[1];
} else if (h[0].equals("Origin:")) {
hOrigin = h[1];
} else if (h[0].equals("Host:")) {
hHost = h[1];
} else if (h[0].equals("Sec-WebSocket-Key1:")) {
hSecWebSocketKey1 = h[1];
} else if (h[0].equals("Sec-WebSocket-Key2:")) {
hSecWebSocketKey2 = h[1];
} else if (h[0].equals("Cookie:")) {
hCookie = h[1];
}
}
}
if (hOrigin == null || hHost == null || hLocation == null || hSecWebSocketKey1 == null || hSecWebSocketKey2 == null) {
System.err.println(sock.socket().getRemoteSocketAddress().toString() + " Invalid headers");
sock.close();
return;
}
// Cookies
int UID = 0;
if (hCookie != null) {
String hash = null;
String cookies[] = hCookie.split("; ");
for (int i = 0; i < cookies.length; i++) {
String cookie[] = cookies[i].split("=", 2);
if (cookie[0].equals("hash")) {
hash = cookie[1];
break;
}
}
if (hash != null) {
UID = com.juick.server.UserQueries.getUIDbyHash(sql, hash);
}
}
// URL
String loc[] = hLocation.split("/");
int MID = 0;
if (hLocation.equals("/my") && UID > 0) {
sockMessages.add(new SocketSubscribed(sock, UID, 0));
} else if (hLocation.equals("/all")) {
sockAll.add(new SocketSubscribed(sock, UID, 0));
} else if ((loc.length == 2 || loc.length == 3) && loc[1].equals("replies")) {
if (loc.length == 2) {
sockReplies.add(new SocketSubscribed(sock, UID, 0));
} else {
try {
MID = Integer.parseInt(loc[2]);
} catch (Exception e) {
}
if (MID > 0) {
sockReplies.add(new SocketSubscribed(sock, UID, MID));
} else {
System.err.println(sock.socket().getRemoteSocketAddress().toString() + " Invalid MID");
sock.close();
return;
}
}
} else {
System.err.println(sock.socket().getRemoteSocketAddress().toString() + " Invalid location");
sock.close();
return;
}
System.out.println(sock.socket().getRemoteSocketAddress().toString() + " HANDSHAKE (UID=" + UID + ", MID=" + MID + ")");
Long lSecNum1 = calcSecKeyNum(hSecWebSocketKey1);
Long lSecNum2 = calcSecKeyNum(hSecWebSocketKey2);
BigInteger sec1 = new BigInteger(lSecNum1.toString());
BigInteger sec2 = new BigInteger(lSecNum2.toString());
// concatenate 3 parts secNum1 + secNum2 + secKey (16 Bytes)
byte[] l128Bit = new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
byte[] lTmp;
lTmp = sec1.toByteArray();
int lIdx = lTmp.length;
int lCnt = 0;
while (lIdx > 0 && lCnt < 4) {
lIdx--;
lCnt++;
l128Bit[4 - lCnt] = lTmp[lIdx];
}
lTmp = sec2.toByteArray();
lIdx = lTmp.length;
lCnt = 0;
while (lIdx > 0 && lCnt < 4) {
lIdx--;
lCnt++;
l128Bit[8 - lCnt] = lTmp[lIdx];
}
buf.rewind();
for (int i = 0; i < 8; i++) {
l128Bit[8 + i] = buf.get(buf.limit() - 8 + i);
}
String outstr = "HTTP/1.1 101 Web Socket Protocol Handshake\r\n"
+ "Upgrade: WebSocket\r\n"
+ "Connection: Upgrade\r\n"
+ "Sec-WebSocket-Origin: " + hOrigin + "\r\n"
+ "Sec-WebSocket-Location: ws://" + hHost + hLocation + "\r\n"
+ "Sec-WebSocket-Protocol: sample\r\n"
+ "\r\n";
ByteBuffer out = ByteBuffer.allocate(4096);
out.put(Charset.forName("ISO-8859-1").encode(outstr));
out.put(MessageDigest.getInstance("MD5").digest(l128Bit));
out.flip();
sock.write(out);
}
private static long calcSecKeyNum(String aKey) {
StringBuilder lSB = new StringBuilder();
// StringBuuffer lSB = new StringBuuffer();
int lSpaces = 0;
for (int i = 0; i < aKey.length(); i++) {
char lC = aKey.charAt(i);
if (lC == ' ') {
lSpaces++;
} else if (lC >= '0' && lC <= '9') {
lSB.append(lC);
}
}
long lRes = -1;
if (lSpaces > 0) {
try {
lRes = Long.parseLong(lSB.toString()) / lSpaces;
// log.debug("Key: " + aKey + ", Numbers: " + lSB.toString() +
// ", Spaces: " + lSpaces + ", Result: " + lRes);
} catch (NumberFormatException ex) {
// use default result
}
}
return lRes;
}
public void wsTextFrame(SocketChannel sock, CharSequence csbuf) {
String buf = csbuf.toString();
if (buf.equals(" ")) {
ByteBuffer out = ByteBuffer.allocate(4);
out.put((byte) 0x00);
out.put((byte) 0x20);
out.put((byte) 0xFF);
out.flip();
out.rewind();
try {
sock.write(out);
} catch (IOException e) {
}
} else {
System.out.println(sock.socket().getRemoteSocketAddress().toString() + " DATA '" + buf + "'");
}
}
@Override
public void onMessage(com.juick.xmpp.Message msg) {
JuickMessage jmsg = (JuickMessage) msg.getChild(JuickMessage.XMLNS);
if (jmsg != null) {
if (jmsg.RID == 0) {
onJuickMessagePost(jmsg);
} else {
onJuickMessageReply(jmsg);
}
}
}
private void onJuickMessagePost(com.juick.Message jmsg) {
String json = "{"
+ "\"mid\":" + jmsg.MID + ","
+ "\"user\":{" + "\"uid\":" + jmsg.User.UID + "," + "\"uname\":\"" + encloseJSON(jmsg.User.UName) + "\"},"
+ "\"timestamp\":\"" + jmsg.Timestamp + "\","
+ "\"body\":\"" + encloseJSON(jmsg.Text) + "\"";
if (jmsg.tags.size() > 0) {
json += ",\"tags\":[";
for (int i = 0; i < jmsg.tags.size(); i++) {
if (i > 0) {
json += ",";
}
json += "\"" + encloseJSON((String) jmsg.tags.get(i)) + "\"";
}
json += "]";
}
json += "}";
ByteBuffer out = ByteBuffer.allocate(10240);
out.put((byte) 0x00);
out.put(Charset.forName("UTF-8").encode(json));
out.put((byte) 0xFF);
out.flip();
String query = "SELECT suser_id FROM subscr_users WHERE user_id=" + jmsg.User.UID + " AND suser_id NOT IN (SELECT user_id FROM bl_tags INNER JOIN messages_tags USING(tag_id) WHERE message_id=" + jmsg.MID + ")";
if (jmsg.Privacy < 0) {
query += " AND suser_id IN (SELECT wl_user_id FROM wl_users WHERE user_id=" + jmsg.User.UID + ")";
}
Statement stmt = null;
ResultSet rs = null;
try {
stmt = sql.createStatement();
rs = stmt.executeQuery(query);
rs.beforeFirst();
while (rs.next()) {
int UID = rs.getInt(1);
for (int i = sockMessages.size() - 1; i >= 0; i--) {
SocketSubscribed ss = sockMessages.get(i);
if (ss.UID == UID) {
try {
out.rewind();
ss.sock.write(out);
} catch (IOException e) {
sockMessages.remove(i);
try {
ss.sock.close();
} catch (IOException ex) {
}
}
}
}
if (jmsg.Privacy <= 0) {
for (int i = sockAll.size() - 1; i >= 0; i--) {
SocketSubscribed ss = sockAll.get(i);
if (ss.UID == UID) {
try {
out.rewind();
ss.sock.write(out);
} catch (IOException e) {
sockAll.remove(i);
try {
ss.sock.close();
} catch (IOException ex) {
}
}
}
}
}
}
} catch (SQLException e) {
System.err.println(e);
} finally {
if (rs != null) {
try {
rs.close();
} catch (SQLException e) {
}
}
if (stmt != null) {
try {
stmt.close();
} catch (SQLException e) {
}
}
}
// Send to all
if (jmsg.Privacy > 0) {
for (int i = sockAll.size() - 1; i >= 0; i--) {
SocketSubscribed ss = sockAll.get(i);
try {
out.rewind();
ss.sock.write(out);
} catch (IOException e) {
sockAll.remove(i);
try {
ss.sock.close();
} catch (IOException ex) {
}
}
}
}
}
private void onJuickMessageReply(com.juick.Message jmsg) {
String json = "{"
+ "\"mid\":" + jmsg.MID + ","
+ "\"rid\":" + jmsg.RID + ","
+ "\"replyto\":" + jmsg.ReplyTo + ","
+ "\"user\":{" + "\"uid\":" + jmsg.User.UID + "," + "\"uname\":\"" + encloseJSON(jmsg.User.UName) + "\"},"
+ "\"timestamp\":\"" + jmsg.Timestamp + "\","
+ "\"body\":\"" + encloseJSON(jmsg.Text) + "\""
+ "}";
ByteBuffer out = ByteBuffer.allocate(10240);
out.put((byte) 0x00);
out.put(Charset.forName("UTF-8").encode(json));
out.put((byte) 0xFF);
out.flip();
for (int i = sockReplies.size() - 1; i >= 0; i--) {
SocketSubscribed ss = sockReplies.get(i);
if (ss.MID == 0 || ss.MID == jmsg.MID) {
try {
out.rewind();
ss.sock.write(out);
} catch (IOException e) {
sockReplies.remove(i);
try {
ss.sock.close();
} catch (IOException ex) {
}
}
}
}
}
public static String encloseJSON(String str) {
return str.replace("\"", """).replace("\\", "\\\\").replace("\n", "\\n");
}
}
/**
 * One subscribed WebSocket client: its channel plus the filter values chosen
 * during the handshake.
 */
class SocketSubscribed {
    /** Channel that JSON frames are written to. */
    public SocketChannel sock;
    /** User id resolved from the client's cookie (0 when none was resolved). */
    public int UID;
    /** Message id whose replies are wanted; 0 means no per-message filter. */
    public int MID;

    /**
     * @param channel   the client's connection
     * @param userId    user id to store in {@link #UID}
     * @param messageId message id to store in {@link #MID}
     */
    public SocketSubscribed(SocketChannel channel, int userId, int messageId) {
        MID = messageId;
        UID = userId;
        sock = channel;
    }
}