Code adapted from: https://github.com/codebutler/android-websockets

HybiParser.java

//
// HybiParser.java: draft-ietf-hybi-thewebsocketprotocol-13 parser
//
// Based on code from the faye project.
// https://github.com/faye/faye-websocket-node
// Copyright (c) 2009-2012 James Coglan
//
// Ported from Javascript to Java by Eric Butler <eric@codebutler.com>
//
// (The MIT License)
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
// LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
// WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

package com.codebutler.android_websockets;

import android.util.Log;

import java.io.*;
import java.util.Arrays;
import java.util.List;

/**
 * Parser and builder for WebSocket frames (draft-ietf-hybi-thewebsocketprotocol-13).
 *
 * <p>{@link #start} reads frames from the connection and forwards complete messages, close
 * notifications and pings to the owning {@link WebSocketClient}. {@link #frame(String)},
 * {@link #frame(byte[])}, {@link #ping} and {@link #close} build outgoing (masked) frames.
 * Parsing state is kept in instance fields between reads.
 */
public class HybiParser {
    private static final String TAG = "HybiParser";

    private WebSocketClient mClient;

    /** Whether outgoing frames are masked; never changed here, since a client must mask. */
    private boolean mMasking = true;

    /** Which part of the current inbound frame is read next (one of the STAGE_* constants). */
    private int mStage;

    private boolean mFinal;
    private boolean mMasked;
    private int mOpcode;
    private int mLengthSize;
    private int mLength;
    /** Type of the fragmented message being assembled (MODE_TEXT/MODE_BINARY), 0 if none. */
    private int mMode;

    private byte[] mMask = new byte[0];
    private byte[] mPayload = new byte[0];

    /** Set after a close frame has been sent; afterwards {@code frame(...)} returns null. */
    private boolean mClosed = false;

    /** Collects the payloads of a fragmented message until its final fragment arrives. */
    private ByteArrayOutputStream mBuffer = new ByteArrayOutputStream();

    /**
     * Source of outgoing masking keys. The protocol requires masking keys to be unpredictable,
     * which {@code Math.random()} does not guarantee.
     */
    private static final java.security.SecureRandom sRandom = new java.security.SecureRandom();

    private static final int BYTE = 255;
    private static final int FIN = 128;
    private static final int MASK = 128;
    private static final int RSV1 = 64;
    private static final int RSV2 = 32;
    private static final int RSV3 = 16;
    private static final int OPCODE = 15;
    private static final int LENGTH = 127;

    private static final int MODE_TEXT = 1;
    private static final int MODE_BINARY = 2;

    private static final int OP_CONTINUATION = 0;
    private static final int OP_TEXT = 1;
    private static final int OP_BINARY = 2;
    private static final int OP_CLOSE = 8;
    private static final int OP_PING = 9;
    private static final int OP_PONG = 10;

    // Parser stages, in the order the parts of a frame appear on the wire.
    private static final int STAGE_OPCODE = 0;
    private static final int STAGE_LENGTH = 1;
    private static final int STAGE_EXTENDED_LENGTH = 2;
    private static final int STAGE_MASK = 3;
    private static final int STAGE_PAYLOAD = 4;

    private static final List<Integer> OPCODES = Arrays.asList(
            OP_CONTINUATION,
            OP_TEXT,
            OP_BINARY,
            OP_CLOSE,
            OP_PING,
            OP_PONG
    );

    private static final List<Integer> FRAGMENTED_OPCODES = Arrays.asList(
            OP_CONTINUATION, OP_TEXT, OP_BINARY
    );

    public HybiParser(WebSocketClient client) {
        mClient = client;
    }

    /**
     * XORs {@code payload[offset..]} in place with the 4-byte {@code mask}, cycling the mask.
     * An empty mask leaves the payload untouched.
     *
     * @return the same {@code payload} array
     */
    private static byte[] mask(byte[] payload, byte[] mask, int offset) {
        if (mask.length == 0) return payload;

        for (int i = 0; i < payload.length - offset; i++) {
            payload[offset + i] = (byte) (payload[offset + i] ^ mask[i % 4]);
        }
        return payload;
    }

    /**
     * Reads frames from {@code stream} and dispatches them until reading fails.
     *
     * <p>The loop has no normal exit. It ends when {@code readByte}/{@code readFully} throw,
     * which for a connection closed by the peer is an {@link java.io.EOFException}. The caller
     * is responsible for reporting that as a disconnect.
     *
     * @throws IOException on end of stream, I/O failure, or a {@link ProtocolError}
     */
    public void start(HappyDataInputStream stream) throws IOException {
        while (true) {
            switch (mStage) {
                case STAGE_OPCODE:
                    parseOpcode(stream.readByte());
                    break;
                case STAGE_LENGTH:
                    parseLength(stream.readByte());
                    break;
                case STAGE_EXTENDED_LENGTH:
                    parseExtendedLength(stream.readBytes(mLengthSize));
                    break;
                case STAGE_MASK:
                    mMask = stream.readBytes(4);
                    mStage = STAGE_PAYLOAD;
                    break;
                case STAGE_PAYLOAD:
                    mPayload = stream.readBytes(mLength);
                    emitFrame();
                    mStage = STAGE_OPCODE;
                    break;
            }
        }
    }

    /** Parses the first header byte (FIN, RSV bits, opcode) and validates it. */
    private void parseOpcode(byte data) throws ProtocolError {
        boolean rsv1 = (data & RSV1) == RSV1;
        boolean rsv2 = (data & RSV2) == RSV2;
        boolean rsv3 = (data & RSV3) == RSV3;

        // No extensions are negotiated, so any reserved bit is a protocol violation.
        if (rsv1 || rsv2 || rsv3) {
            throw new ProtocolError("RSV not zero");
        }

        mFinal = (data & FIN) == FIN;
        mOpcode = (data & OPCODE);
        mMask = new byte[0];
        mPayload = new byte[0];

        if (!OPCODES.contains(mOpcode)) {
            throw new ProtocolError("Bad opcode");
        }

        // Control frames (close/ping/pong) must not be fragmented.
        if (!FRAGMENTED_OPCODES.contains(mOpcode) && !mFinal) {
            throw new ProtocolError("Expected non-final packet");
        }

        mStage = STAGE_LENGTH;
    }

    /**
     * Parses the second header byte (MASK bit and 7-bit length). Values 126 and 127 mean
     * a 2-byte or 8-byte extended length follows.
     */
    private void parseLength(byte data) {
        mMasked = (data & MASK) == MASK;
        mLength = (data & LENGTH);

        if (mLength >= 0 && mLength <= 125) {
            mStage = mMasked ? STAGE_MASK : STAGE_PAYLOAD;
        } else {
            mLengthSize = (mLength == 126) ? 2 : 8;
            mStage = STAGE_EXTENDED_LENGTH;
        }
    }

    /** Decodes the big-endian extended payload length; rejects values above Integer.MAX_VALUE. */
    private void parseExtendedLength(byte[] buffer) throws ProtocolError {
        mLength = getInteger(buffer);
        mStage = mMasked ? STAGE_MASK : STAGE_PAYLOAD;
    }

    /** Builds a final, masked text frame, or returns null once a close frame has been sent. */
    public byte[] frame(String data) {
        return frame(data, OP_TEXT, -1);
    }

    /** Builds a final, masked binary frame, or returns null once a close frame has been sent. */
    public byte[] frame(byte[] data) {
        return frame(data, OP_BINARY, -1);
    }

    private byte[] frame(byte[] data, int opcode, int errorCode) {
        return frame((Object) data, opcode, errorCode);
    }

    private byte[] frame(String data, int opcode, int errorCode) {
        return frame((Object) data, opcode, errorCode);
    }

    /**
     * Builds a complete single-frame message.
     *
     * @param data      a {@code String} (sent as UTF-8) or {@code byte[]} payload; {@code null}
     *                  is treated as an empty payload (e.g. a close frame without a reason)
     * @param opcode    frame opcode
     * @param errorCode if positive, a 2-byte big-endian status code prepended to the payload
     * @return the encoded frame, or {@code null} if a close frame was already sent
     */
    private byte[] frame(Object data, int opcode, int errorCode) {
        if (mClosed) return null;

        Log.d(TAG, "Creating frame for: " + data + " op: " + opcode + " err: " + errorCode);

        byte[] buffer;
        if (data instanceof String) {
            buffer = decode((String) data);
        } else if (data == null) {
            buffer = new byte[0];
        } else {
            buffer = (byte[]) data;
        }
        int insert = (errorCode > 0) ? 2 : 0;
        int length = buffer.length + insert;
        int header = (length <= 125) ? 2 : (length <= 65535 ? 4 : 10);
        int offset = header + (mMasking ? 4 : 0);
        int masked = mMasking ? MASK : 0;
        byte[] frame = new byte[length + offset];

        frame[0] = (byte) (FIN | opcode);

        if (length <= 125) {
            frame[1] = (byte) (masked | length);
        } else if (length <= 65535) {
            frame[1] = (byte) (masked | 126);
            frame[2] = (byte) ((length >>> 8) & BYTE);
            frame[3] = (byte) (length & BYTE);
        } else {
            frame[1] = (byte) (masked | 127);
            // Widen first: an int shifted by 32 or more would wrap the shift count.
            long longLength = length;
            for (int i = 0; i < 8; i++) {
                frame[2 + i] = (byte) ((longLength >>> (56 - 8 * i)) & BYTE);
            }
        }

        if (errorCode > 0) {
            frame[offset] = (byte) ((errorCode >>> 8) & BYTE);
            frame[offset + 1] = (byte) (errorCode & BYTE);
        }
        System.arraycopy(buffer, 0, frame, offset + insert, buffer.length);

        if (mMasking) {
            byte[] mask = new byte[4];
            sRandom.nextBytes(mask);
            System.arraycopy(mask, 0, frame, header, mask.length);
            mask(frame, mask, offset);
        }

        return frame;
    }

    /**
     * Sends a ping control frame carrying {@code message}. Does nothing once closed.
     * The frame goes through {@code sendFrame} because it is already encoded; {@code send}
     * would wrap it in a second, binary frame.
     */
    public void ping(String message) {
        byte[] frame = frame(message, OP_PING, -1);
        if (frame != null) {
            mClient.sendFrame(frame);
        }
    }

    /**
     * Sends a close control frame with {@code code} and an optional {@code reason}
     * (may be null), then stops producing further frames. Subsequent calls are ignored.
     */
    public void close(int code, String reason) {
        if (mClosed) return;
        mClient.sendFrame(frame(reason, OP_CLOSE, code));
        mClosed = true;
    }

    /** Unmasks the current payload and dispatches it according to the frame's opcode. */
    private void emitFrame() throws IOException {
        byte[] payload = mask(mPayload, mMask, 0);
        int opcode = mOpcode;

        if (opcode == OP_CONTINUATION) {
            if (mMode == 0) {
                throw new ProtocolError("Mode was not set.");
            }
            mBuffer.write(payload);
            if (mFinal) {
                byte[] message = mBuffer.toByteArray();
                if (mMode == MODE_TEXT) {
                    mClient.getListener().onMessage(encode(message));
                } else {
                    mClient.getListener().onMessage(message);
                }
                reset();
            }

        } else if (opcode == OP_TEXT) {
            if (mFinal) {
                String messageText = encode(payload);
                mClient.getListener().onMessage(messageText);
            } else {
                mMode = MODE_TEXT;
                mBuffer.write(payload);
            }

        } else if (opcode == OP_BINARY) {
            if (mFinal) {
                mClient.getListener().onMessage(payload);
            } else {
                mMode = MODE_BINARY;
                mBuffer.write(payload);
            }

        } else if (opcode == OP_CLOSE) {
            // Status code is an unsigned 16-bit big-endian value; mask to avoid sign extension.
            int code = (payload.length >= 2)
                    ? ((payload[0] & BYTE) << 8) | (payload[1] & BYTE)
                    : 0;
            String reason = (payload.length > 2) ? encode(slice(payload, 2)) : null;
            Log.d(TAG, "Got close op! " + code + " " + reason);
            mClient.getListener().onDisconnect(code, reason);

        } else if (opcode == OP_PING) {
            if (payload.length > 125) { throw new ProtocolError("Ping payload too large"); }
            Log.d(TAG, "Sending pong!!");
            byte[] pong = frame(payload, OP_PONG, -1);
            // frame() returns null after we sent a close; writing null would crash the sender.
            if (pong != null) {
                mClient.sendFrame(pong);
            }

        } else if (opcode == OP_PONG) {
            String message = encode(payload);
            // FIXME: Fire callback...
            Log.d(TAG, "Got pong! " + message);
        }
    }

    /** Discards any partially assembled fragmented message. */
    private void reset() {
        mMode = 0;
        mBuffer.reset();
    }

    /** Decodes UTF-8 bytes into a String (named "encode" historically). */
    private String encode(byte[] buffer) {
        try {
            return new String(buffer, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Encodes a String as UTF-8 bytes (named "decode" historically). */
    private byte[] decode(String string) {
        try {
            return (string).getBytes("UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new RuntimeException(e);
        }
    }

    /** Interprets {@code bytes} as a big-endian unsigned value that must fit in an int. */
    private int getInteger(byte[] bytes) throws ProtocolError {
        long i = byteArrayToLong(bytes, 0, bytes.length);
        if (i < 0 || i > Integer.MAX_VALUE) {
            throw new ProtocolError("Bad integer: " + i);
        }
        return (int) i;
    }

    private byte[] slice(byte[] array, int start) {
        return Arrays.copyOfRange(array, start, array.length);
    }

    /** Raised when the peer violates the framing protocol. */
    public static class ProtocolError extends IOException {
        public ProtocolError(String detailMessage) {
            super(detailMessage);
        }
    }

    /**
     * Reads {@code length} bytes starting at {@code offset} as a big-endian integer.
     * Each byte is widened to {@code long} before shifting; shifting an {@code int} by 32 or
     * more wraps the shift count and would silently corrupt 8-byte lengths.
     */
    private static long byteArrayToLong(byte[] b, int offset, int length) {
        if (b.length < length)
            throw new IllegalArgumentException("length must be less than or equal to b.length");

        long value = 0;
        for (int i = 0; i < length; i++) {
            int shift = (length - 1 - i) * 8;
            value |= ((long) (b[i + offset] & 0xFF)) << shift;
        }
        return value;
    }

    /** {@link DataInputStream} with a helper that reads an exact number of bytes. */
    public static class HappyDataInputStream extends DataInputStream {
        public HappyDataInputStream(InputStream in) {
            super(in);
        }

        /** Reads exactly {@code length} bytes, throwing {@link java.io.EOFException} if the stream ends first. */
        public byte[] readBytes(int length) throws IOException {
            byte[] buffer = new byte[length];
            readFully(buffer);
            return buffer;
        }
    }
}

WebSocketClient.java

package com.codebutler.android_websockets;

import android.os.Handler;
import android.os.HandlerThread;
import android.text.TextUtils;
import android.util.Base64;
import android.util.Log;
import org.apache.http.*;
import org.apache.http.client.HttpResponseException;
import org.apache.http.message.BasicLineParser;
import org.apache.http.message.BasicNameValuePair;

import javax.net.SocketFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509TrustManager;

import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintWriter;
import java.net.Socket;
import java.net.URI;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.List;

/**
 * Minimal WebSocket client for Android.
 *
 * <p>{@link #connect} performs the HTTP upgrade handshake and then reads frames on a dedicated
 * thread. Outgoing frames and socket shutdown are executed on a {@link HandlerThread}, so
 * {@link #send} and {@link #disconnect} do not block the caller. {@link Listener} callbacks
 * are invoked from those background threads.
 */
public class WebSocketClient {
    private static final String TAG = "WebSocketClient";

    /** GUID appended to the client key when computing Sec-WebSocket-Accept (RFC 6455). */
    private static final String ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /** Source of the handshake nonce; the protocol requires it to be randomly selected. */
    private static final java.security.SecureRandom sRandom = new java.security.SecureRandom();

    private URI mURI;
    private Listener mListener;
    /** Assigned on the connection thread and read/cleared on the handler thread. */
    private volatile Socket mSocket;
    private Thread mThread;
    private HandlerThread mHandlerThread;
    private Handler mHandler;
    private List<BasicNameValuePair> mExtraHeaders;
    private HybiParser mParser;

    /** Serializes socket writes with clearing of {@link #mSocket}. */
    private final Object mSendLock = new Object();

    private static TrustManager[] sTrustManagers;

    /**
     * Installs the trust managers used to validate server certificates for {@code wss://}
     * connections. Pass {@code null} (the default) to use the platform's trust store.
     */
    public static void setTrustManagers(TrustManager[] tm) {
        sTrustManagers = tm;
    }

    public WebSocketClient(URI uri, Listener listener, List<BasicNameValuePair> extraHeaders) {
        mURI = uri;
        mListener = listener;
        mExtraHeaders = extraHeaders;
        mParser = new HybiParser(this);

        mHandlerThread = new HandlerThread("websocket-thread");
        mHandlerThread.start();
        mHandler = new Handler(mHandlerThread.getLooper());
    }

    public Listener getListener() {
        return mListener;
    }

    /**
     * Opens the connection on a new thread, performs the handshake and then reads frames until
     * the connection ends. Does nothing if a connection thread is still running. Failures are
     * reported through the listener; the socket is released when the thread finishes.
     */
    public void connect() {
        if (mThread != null && mThread.isAlive()) {
            return;
        }

        mThread = new Thread(new Runnable() {
            @Override
            public void run() {
                try {
                    String secret = createSecret();
                    String host = mURI.getHost();
                    // URI schemes are case-insensitive.
                    boolean secure = "wss".equalsIgnoreCase(mURI.getScheme());
                    int defaultPort = secure ? 443 : 80;
                    int port = (mURI.getPort() != -1) ? mURI.getPort() : defaultPort;

                    String path = TextUtils.isEmpty(mURI.getPath()) ? "/" : mURI.getPath();
                    if (!TextUtils.isEmpty(mURI.getQuery())) {
                        path += "?" + mURI.getQuery();
                    }

                    String originScheme = secure ? "https" : "http";
                    URI origin = new URI(originScheme, "//" + host, null);

                    SocketFactory factory = secure ? getSSLSocketFactory() : SocketFactory.getDefault();
                    mSocket = factory.createSocket(host, port);
                    if (secure) {
                        // SSLSocketFactory validates the chain but not that it was issued for host.
                        verifyHostname(mSocket, host);
                    }

                    // The Host header carries the port only when it is not the scheme default.
                    String hostHeader = (port == defaultPort) ? host : host + ":" + port;

                    PrintWriter out = new PrintWriter(mSocket.getOutputStream());
                    out.print("GET " + path + " HTTP/1.1\r\n");
                    out.print("Upgrade: websocket\r\n");
                    out.print("Connection: Upgrade\r\n");
                    out.print("Host: " + hostHeader + "\r\n");
                    out.print("Origin: " + origin.toString() + "\r\n");
                    out.print("Sec-WebSocket-Key: " + secret + "\r\n");
                    out.print("Sec-WebSocket-Version: 13\r\n");
                    if (mExtraHeaders != null) {
                        for (NameValuePair pair : mExtraHeaders) {
                            out.print(String.format("%s: %s\r\n", pair.getName(), pair.getValue()));
                        }
                    }
                    out.print("\r\n");
                    out.flush();

                    HybiParser.HappyDataInputStream stream = new HybiParser.HappyDataInputStream(mSocket.getInputStream());

                    // Read HTTP response status line.
                    StatusLine statusLine = parseStatusLine(readLine(stream));
                    if (statusLine == null) {
                        throw new HttpException("Received no reply from server.");
                    } else if (statusLine.getStatusCode() != HttpStatus.SC_SWITCHING_PROTOCOLS) {
                        throw new HttpResponseException(statusLine.getStatusCode(), statusLine.getReasonPhrase());
                    }

                    // Read HTTP response headers until the blank line that ends them.
                    String expected = createSecretValidation(secret);
                    String line;
                    boolean validated = false;

                    while (!TextUtils.isEmpty(line = readLine(stream))) {
                        Header header = parseHeader(line);
                        // HTTP header field names are case-insensitive.
                        if (header.getName().equalsIgnoreCase("Sec-WebSocket-Accept")) {
                            String actual = header.getValue().trim();
                            if (!expected.equals(actual)) {
                                throw new HttpException("Bad Sec-WebSocket-Accept header value.");
                            }
                            validated = true;
                        }
                    }

                    if (!validated) {
                        throw new HttpException("No Sec-WebSocket-Accept header.");
                    }

                    mListener.onConnect();

                    // Now decode websocket frames; this only returns by throwing.
                    mParser.start(stream);

                } catch (EOFException ex) {
                    Log.d(TAG, "WebSocket EOF!", ex);
                    mListener.onDisconnect(0, "EOF");

                } catch (SSLException ex) {
                    // Connection reset by peer, TLS handshake failure or hostname mismatch.
                    Log.d(TAG, "Websocket SSL error!", ex);
                    mListener.onDisconnect(0, "SSL");

                } catch (Exception ex) {
                    mListener.onError(ex);

                } finally {
                    // Every exit from this thread ends the connection; don't leak the socket.
                    releaseSocket();
                }
            }
        });
        mThread.start();
    }

    /** Closes the socket on the handler thread, after any frames already queued for sending. */
    public void disconnect() {
        if (mSocket != null) {
            mHandler.post(new Runnable() {
                @Override
                public void run() {
                    Socket socket;
                    synchronized (mSendLock) {
                        socket = mSocket;
                        mSocket = null;
                    }
                    if (socket == null) {
                        // Already closed by an earlier disconnect or by the connection thread.
                        return;
                    }
                    try {
                        socket.close();
                    } catch (IOException ex) {
                        Log.d(TAG, "Error while disconnecting", ex);
                        mListener.onError(ex);
                    }
                }
            });
        }
    }

    public void send(String data) {
        sendFrame(mParser.frame(data));
    }

    public void send(byte[] data) {
        sendFrame(mParser.frame(data));
    }

    private StatusLine parseStatusLine(String line) {
        if (TextUtils.isEmpty(line)) {
            return null;
        }
        return BasicLineParser.parseStatusLine(line, new BasicLineParser());
    }

    private Header parseHeader(String line) {
        return BasicLineParser.parseHeader(line, new BasicLineParser());
    }

    // Can't use BufferedReader because it buffers past the HTTP data.
    /**
     * Reads one line terminated by '\n', dropping any '\r'.
     *
     * @return the line, or {@code null} if the stream ends before a line terminator
     */
    private String readLine(HybiParser.HappyDataInputStream reader) throws IOException {
        int readChar = reader.read();
        if (readChar == -1) {
            return null;
        }
        StringBuilder string = new StringBuilder("");
        while (readChar != '\n') {
            if (readChar != '\r') {
                string.append((char) readChar);
            }

            readChar = reader.read();
            if (readChar == -1) {
                return null;
            }
        }
        return string.toString();
    }

    /** Returns a base64-encoded 16-byte random nonce for the Sec-WebSocket-Key header. */
    private String createSecret() {
        byte[] nonce = new byte[16];
        sRandom.nextBytes(nonce);
        return Base64.encodeToString(nonce, Base64.DEFAULT).trim();
    }

    /** Computes the Sec-WebSocket-Accept value the server must return for {@code secret}. */
    private String createSecretValidation(String secret) {
        try {
            MessageDigest md = MessageDigest.getInstance("SHA-1");
            md.update((secret + ACCEPT_GUID).getBytes());
            return Base64.encodeToString(md.digest(), Base64.DEFAULT).trim();
        } catch (NoSuchAlgorithmException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes an already-encoded frame on the handler thread. A {@code null} frame (the parser's
     * result after a close frame was sent) is ignored. Failures, including a missing socket,
     * are reported through {@link Listener#onError} rather than thrown, because an exception
     * escaping a Handler runnable would crash the process.
     */
    void sendFrame(final byte[] frame) {
        if (frame == null) {
            return;
        }
        mHandler.post(new Runnable() {
            @Override
            public void run() {
                Exception error = null;
                synchronized (mSendLock) {
                    Socket socket = mSocket;
                    if (socket == null) {
                        error = new IllegalStateException("Socket not connected");
                    } else {
                        try {
                            OutputStream outputStream = socket.getOutputStream();
                            outputStream.write(frame);
                            outputStream.flush();
                        } catch (IOException e) {
                            error = e;
                        }
                    }
                }
                // Call listener code outside the lock.
                if (error != null) {
                    mListener.onError(error);
                }
            }
        });
    }

    /** Receives connection events; callbacks arrive on background threads. */
    public interface Listener {
        public void onConnect();
        public void onMessage(String message);
        public void onMessage(byte[] data);
        public void onDisconnect(int code, String reason);
        public void onError(Exception error);
    }

    /** Closes and clears the current socket, ignoring close failures. */
    private void releaseSocket() {
        Socket socket;
        synchronized (mSendLock) {
            socket = mSocket;
            mSocket = null;
        }
        if (socket != null) {
            try {
                socket.close();
            } catch (IOException ignored) {
                // Best effort: the connection is already being torn down.
            }
        }
    }

    /**
     * Completes the TLS handshake on {@code socket} and checks that the server certificate
     * matches {@code host}, using the platform's default hostname verifier.
     *
     * @throws javax.net.ssl.SSLPeerUnverifiedException if the certificate does not match
     */
    private static void verifyHostname(Socket socket, String host) throws IOException {
        if (!(socket instanceof javax.net.ssl.SSLSocket)) {
            return;
        }
        javax.net.ssl.SSLSocket sslSocket = (javax.net.ssl.SSLSocket) socket;
        // Handshake explicitly so failures throw instead of yielding an invalid session.
        sslSocket.startHandshake();
        javax.net.ssl.HostnameVerifier verifier =
                javax.net.ssl.HttpsURLConnection.getDefaultHostnameVerifier();
        if (!verifier.verify(host, sslSocket.getSession())) {
            throw new javax.net.ssl.SSLPeerUnverifiedException("Hostname " + host + " not verified");
        }
    }

    /**
     * Builds the socket factory for {@code wss://} connections, using the trust managers given to
     * {@link #setTrustManagers} or, when none were set, the platform default trust store.
     *
     * <p>Certificates are no longer accepted blindly: doing so allows man-in-the-middle attacks.
     * To reach a development server with a self-signed certificate, install a trust manager that
     * accepts it via {@link #setTrustManagers}, and never ship that in a release build.
     */
    private SSLSocketFactory getSSLSocketFactory() throws Exception {
        SSLContext sslContext = SSLContext.getInstance("TLS");
        sslContext.init(null, sTrustManagers, null);
        return sslContext.getSocketFactory();
    }
}

demo:

// Example usage: connect, send one text message, then close the connection.
List<BasicNameValuePair> extraHeaders = Arrays.asList(
        new BasicNameValuePair("Cookie", "session=abcd")
);
final String TAG = "WebSocketClient";
WebSocketClient client = new WebSocketClient(URI.create("wss://echo.websocket.org:443/"), new WebSocketClient.Listener() {
    @Override
    public void onConnect() {
        Log.d(TAG, "Connected!");
    }

    @Override
    public void onMessage(String message) {
        Log.d(TAG, String.format("Got string message! %s", message));
    }

    @Override
    public void onMessage(byte[] data) {
        Log.d(TAG, String.format("Got binary message! %d bytes", data.length));
    }

    @Override
    public void onDisconnect(int code, String reason) {
        Log.d(TAG, String.format("Disconnected! Code: %d Reason: %s", code, reason));
    }

    @Override
    public void onError(Exception error) {
        Log.e(TAG, "Error!", error);
    }

}, extraHeaders);

client.connect();
// Crude wait for the handshake; a real app should start sending from onConnect() instead.
try {
    Thread.sleep(500);
} catch (InterruptedException e) {
    // Restore the interrupt flag so the surrounding code can still observe the interruption.
    Thread.currentThread().interrupt();
}
// Later…
client.send("SHAKEHAND");
// Binary data is sent the same way, e.g.:
// client.send(new byte[] { (byte) 0xDE, (byte) 0xAD, (byte) 0xBE, (byte) 0xEF });
client.disconnect();