package com.juick.jabber.ws;

import com.juick.server.MessagesQueries;
import com.juick.server.UserQueries;
import com.juick.xmpp.utils.Base64;
import org.springframework.jdbc.core.JdbcTemplate;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.CharBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.charset.Charset;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Iterator;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Selector-driven, non-blocking WebSocket listener.
 *
 * <p>{@link #run()} binds a server socket, accepts clients and dispatches on the first byte of
 * every chunk read from a client: an HTTP {@code GET} starts the WebSocket handshake, the other
 * recognised values are WebSocket frame headers (ping, pong, text, close). Clients that complete
 * the handshake are registered in {@code Main.clients}.
 *
 * @author ugnich
 */
public class WSData implements Runnable {

    private static final Logger logger = Logger.getLogger("Websockets");
    static final String WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /** Charset used for both the HTTP handshake text and the accept-key digest input. */
    private static final Charset LATIN1 = Charset.forName("ISO-8859-1");

    private static final int LISTEN_PORT = 8081;
    private static final int READ_BUFFER_SIZE = 10240;

    // First byte of a single-fragment (FIN set) WebSocket frame, by opcode.
    private static final byte FRAME_TEXT = (byte) 0x81;
    private static final byte FRAME_CLOSE = (byte) 0x88;
    private static final byte FRAME_PING = (byte) 0x89;
    private static final byte FRAME_PONG = (byte) 0x8A;
    /** First byte of an HTTP "GET ..." request line, i.e. the start of a handshake. */
    private static final byte HTTP_GET = (byte) 'G';

    JdbcTemplate sql;
    public Selector sel;

    public WSData(JdbcTemplate sql) {
        this.sql = sql;
    }

    /**
     * Opens the selector, binds the listening socket and serves clients until a failure outside
     * of per-client handling occurs; such a failure is logged and ends this method.
     */
    @Override
    public void run() {
        try {
            sel = Selector.open();
            ServerSocketChannel listensock = ServerSocketChannel.open();
            listensock.configureBlocking(false);
            listensock.socket().bind(new InetSocketAddress(LISTEN_PORT));
            listensock.register(sel, SelectionKey.OP_ACCEPT);

            while (true) {
                sel.select();
                Iterator<SelectionKey> it = sel.selectedKeys().iterator();
                while (it.hasNext()) {
                    SelectionKey selKey = it.next();
                    it.remove();
                    if (selKey.isAcceptable()) {
                        acceptClient(selKey);
                    } else if (selKey.isReadable()) {
                        readClient(selKey);
                    }
                }
            }
        } catch (Exception e) {
            logger.log(Level.SEVERE, "websocket exception", e);
        }
    }

    /** Accepts a pending connection and registers it with the selector for reading. */
    private void acceptClient(SelectionKey selKey) throws IOException {
        ServerSocketChannel ssChannel = (ServerSocketChannel) selKey.channel();
        SocketChannel sChannel = ssChannel.accept();
        if (sChannel == null) {
            // A non-blocking accept() returns null when no connection is pending.
            return;
        }
        sChannel.configureBlocking(false);
        sChannel.register(sel, SelectionKey.OP_READ);
        logger.info(remoteAddress(sChannel) + " ACCEPTED");
    }

    /**
     * Reads one chunk from a client and dispatches on its first byte. Any failure while serving
     * the client is logged and closes only that client, so the accept loop keeps running.
     */
    private void readClient(SelectionKey selKey) {
        SocketChannel sChannel = (SocketChannel) selKey.channel();
        ByteBuffer buf = ByteBuffer.allocate(READ_BUFFER_SIZE);
        try {
            int readbytes = sChannel.read(buf);
            if (readbytes > 0) {
                buf.flip();
                switch (buf.get(0)) {
                    case FRAME_PING:
                        updateSocketTS(sChannel);
                        wsPing(sChannel);
                        break;
                    case FRAME_PONG:
                        updateSocketTS(sChannel);
                        break;
                    case FRAME_TEXT:
                        updateSocketTS(sChannel);
                        wsTextFrame(sChannel, buf);
                        break;
                    case HTTP_GET:
                        updateSocketTS(sChannel);
                        wsHandshake(sChannel, buf);
                        break;
                    case FRAME_CLOSE:
                        throw new IOException(remoteAddress(sChannel) + " CONNECTION CLOSE");
                    default:
                        // Unrecognised leading byte: the chunk is ignored.
                        break;
                }
            } else if (readbytes < 0) {
                throw new IOException(remoteAddress(sChannel) + " END OF STREAM");
            }
        } catch (Exception e) {
            // Catch everything, not just IOException: the handlers declare "throws Exception"
            // and a runtime failure for one client must not terminate the whole listener.
            logger.log(Level.SEVERE, "websocket exception", e);
            closeClient(selKey, sChannel);
        }
    }

    /** Closes a client's socket and channel and always cancels its selection key. */
    private void closeClient(SelectionKey selKey, SocketChannel sChannel) {
        try {
            sChannel.socket().close();
            sChannel.close();
        } catch (IOException e) {
            logger.log(Level.WARNING, "websocket close failed", e);
        } finally {
            selKey.cancel();
        }
    }

    /** Returns the peer address as text; yields "null" instead of failing if it is unavailable. */
    private static String remoteAddress(SocketChannel sock) {
        return String.valueOf(sock.socket().getRemoteSocketAddress());
    }

    /**
     * Parses an HTTP upgrade request, decides what the client subscribes to from the request
     * path, registers the subscription in {@code Main.clients} and writes the HTTP response.
     *
     * <p>Recognised paths: {@code /} (requires a valid {@code hash=} query parameter),
     * {@code /_all}, {@code /_replies}, {@code /<message id>} and {@code /<user name>[/]}.
     *
     * @param sock client channel the response is written to
     * @param buf  buffer holding the raw request; it is rewound before decoding
     * @throws IOException if required headers are missing or invalid, or if the response status
     *                     is anything other than 101
     * @throws Exception   any failure raised by the lookups performed during the handshake
     */
    public void wsHandshake(SocketChannel sock, ByteBuffer buf) throws Exception {
        String hOrigin = null;
        String hHost = null;
        String hLocation = null;
        String hSecWebSocketKey = null;
        String hSecWebSocketVersion = null;
        String hXRealIP = null;

        buf.rewind();
        CharBuffer charbuf = LATIN1.decode(buf);
        String[] headers = charbuf.toString().split("\r\n");
        for (String header : headers) {
            String[] h = header.split(" ", 2);
            if (h.length != 2) {
                continue;
            }
            // The method token is case-sensitive; header field names are not.
            if (h[0].equals("GET")) {
                hLocation = header.split(" ", 3)[1];
            } else if (h[0].equalsIgnoreCase("Origin:")) {
                hOrigin = h[1];
            } else if (h[0].equalsIgnoreCase("Host:")) {
                hHost = h[1];
            } else if (h[0].equalsIgnoreCase("Sec-WebSocket-Key:")) {
                hSecWebSocketKey = h[1];
            } else if (h[0].equalsIgnoreCase("Sec-WebSocket-Version:")) {
                hSecWebSocketVersion = h[1];
            } else if (h[0].equalsIgnoreCase("X-Real-IP:")) {
                hXRealIP = h[1];
            }
        }

        if (hOrigin == null || hHost == null || hLocation == null || hSecWebSocketKey == null
                || hSecWebSocketVersion == null || !hSecWebSocketVersion.equals("13")) {
            throw new IOException(remoteAddress(sock) + " Invalid headers");
        }

        // Auth: resolve the viewer from a 16-character "hash=" query parameter, if present.
        int viewerUid = 0;
        int hashloc = hLocation.indexOf("hash=");
        if (hashloc > 0) {
            String hash = hLocation.substring(hashloc + 5);
            int amp = hash.indexOf('&');
            if (amp > 0) {
                hash = hash.substring(0, amp);
            }
            if (hash.length() == 16) {
                viewerUid = UserQueries.getUIDbyHash(sql, hash);
            }
        }

        // URL: strip the query string so only the path is matched below.
        int queryStart = hLocation.indexOf('?');
        if (queryStart > 0) {
            hLocation = hLocation.substring(0, queryStart);
        }

        int messageId = 0;
        int responseCode = 404;
        SocketSubscribed sockSubscr = null;
        if (hLocation.equals("/") && viewerUid > 0) {
            sockSubscr = new SocketSubscribed(sock, hXRealIP, viewerUid);
            responseCode = 101;
        } else if (hLocation.equals("/_all")) {
            sockSubscr = new SocketSubscribed(sock, hXRealIP, viewerUid);
            sockSubscr.allMessages = true;
            responseCode = 101;
        } else if (hLocation.equals("/_replies")) {
            sockSubscr = new SocketSubscribed(sock, hXRealIP, viewerUid);
            sockSubscr.allReplies = true;
            responseCode = 101;
        } else if (hLocation.matches("^/\\d+$")) {
            try {
                messageId = Integer.parseInt(hLocation.substring(1));
            } catch (NumberFormatException ignored) {
                // Digits that overflow an int: leave messageId at 0 so the request gets a 404.
            }
            if (messageId > 0) {
                if (MessagesQueries.canViewThread(sql, messageId, viewerUid)) {
                    sockSubscr = new SocketSubscribed(sock, hXRealIP, viewerUid);
                    sockSubscr.MID = messageId;
                    responseCode = 101;
                } else {
                    responseCode = 403;
                }
            }
        } else if (hLocation.matches("^/[a-zA-Z0-9\\-]{2,16}/?$")) {
            // Drop the leading slash and, if present, exactly one trailing slash.
            String uname = hLocation.endsWith("/")
                    ? hLocation.substring(1, hLocation.length() - 1)
                    : hLocation.substring(1);
            int uid = UserQueries.getUIDbyName(sql, uname);
            if (uid > 0) {
                // check access
                sockSubscr = new SocketSubscribed(sock, hXRealIP, viewerUid);
                sockSubscr.UID = uid;
                responseCode = 101;
            }
        }

        if (sockSubscr != null) {
            synchronized (Main.clients) {
                Main.clients.add(sockSubscr);
            }
        }

        // Response
        String outstr;
        if (responseCode == 101) {
            outstr = "HTTP/1.1 101 Switching Protocols\r\n"
                    + "Upgrade: websocket\r\n"
                    + "Connection: Upgrade\r\n"
                    + "Sec-WebSocket-Accept: " + calcHeaderAccept(hSecWebSocketKey) + "\r\n"
                    + "\r\n";
        } else if (responseCode == 403) {
            outstr = "HTTP/1.1 403 Forbidden\r\n\r\n";
        } else {
            outstr = "HTTP/1.1 404 Not Found\r\n\r\n";
        }

        // encode() returns a buffer that is already positioned for reading.
        sock.write(LATIN1.encode(outstr));

        if (responseCode == 101) {
            logger.info(remoteAddress(sock) + " HANDSHAKE (VUID = " + viewerUid + "; MID = " + messageId + ")");
        } else {
            // Non-101 outcomes are reported as an exception so the caller drops the connection.
            throw new IOException(remoteAddress(sock) + " " + responseCode);
        }
    }

    /**
     * Computes the {@code Sec-WebSocket-Accept} value: Base64 of the SHA-1 digest of the client
     * key concatenated with {@link #WEBSOCKET_GUID}.
     *
     * @param key value of the client's {@code Sec-WebSocket-Key} header
     * @return the accept token, or an empty string if SHA-1 is unavailable
     */
    private String calcHeaderAccept(String key) {
        String base = key + WEBSOCKET_GUID;
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-1");
            // Explicit charset: the digest input must not depend on the platform default.
            return Base64.encode(md.digest(base.getBytes(LATIN1)));
        } catch (NoSuchAlgorithmException e) {
            logger.severe("calcHeaderAccept: " + e);
        }
        return "";
    }

    /**
     * Answers a ping by writing a pong frame with an empty payload.
     *
     * @param sock client channel to write to
     * @throws Exception if the write fails
     */
    public void wsPing(SocketChannel sock) throws Exception {
        ByteBuffer out = ByteBuffer.wrap(new byte[] {
            FRAME_PONG, // PONG FRAME
            (byte) 0x00 // unmasked, zero-length payload
        });
        sock.write(out);
    }

    /**
     * Handles an incoming text frame. Currently a no-op: the frame content is not read and
     * nothing is sent back.
     *
     * @param sock client channel the frame arrived on
     * @param buf  buffer holding the received bytes
     * @throws Exception never thrown by the current implementation
     */
    public void wsTextFrame(SocketChannel sock, ByteBuffer buf) throws Exception {
        // Intentionally empty.
    }

    /**
     * Stamps the current time on the {@code Main.clients} entry that owns the given channel.
     * Does nothing if the channel has no entry (for example before the handshake completes).
     *
     * @param sock channel that just produced data
     */
    public void updateSocketTS(SocketChannel sock) {
        synchronized (Main.clients) {
            for (SocketSubscribed s : Main.clients) {
                if (s.sock == sock) {
                    s.tsLastData = System.currentTimeMillis();
                    break;
                }
            }
        }
    }
}