Wednesday 5 March 2014

Java Socket Example: Proxy server

This tool intercepts service requests: connections from the specified source addresses receive an error message, while all other connections are forwarded to the target web service. In essence, it shows how to read from and write to the same socket in Java.


import org.apache.log4j.Logger;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Minimal TCP proxy. Accepts one client connection at a time and either answers it with a canned
 * HTTP 500 error (when the client's host address or host name is in the filter list) or relays
 * bytes in both directions between the client and a configured web-service host/port.
 *
 * <p>Each connection gets a short fixed window (up to one second per worker-thread join plus a
 * one-second grace sleep) before both of its sockets are closed, so long-running exchanges are
 * cut off. The server stops when {@code accept()} times out or fails, or when the thread running
 * {@link #run()} is interrupted; the listening socket is closed on exit.
 */
public class ProxyServerTest extends Thread {

    final static Logger logger = Logger.getLogger(ProxyServerTest.class);

    /** Printed on start-up and on argument errors. */
    private static final String USAGE =
            "Usage: HOST_NAME PORT_NO TIME_OUT WS_HOST_NAME WS_PORT_NO COMMA_SEPARATED_IP_ADDRESS_TO_SEND_ERROR";

    /** Milliseconds to wait for each worker thread of a connection before moving on. */
    private static final long JOIN_MILLIS = 1000;

    /** Grace period, in milliseconds, before a handled connection's sockets are closed. */
    private static final long CLOSE_DELAY_MILLIS = 1000;

    /** Copy buffer size used by {@link ClientReader}. */
    private static final int COPY_BUFFER_SIZE = 8192;

    private final ServerSocket serverSocket;
    private final String wsServer;
    private final int wsServerPort;
    /** Client host addresses / host names that receive the error response instead of forwarding. */
    private final List<?> ipList;

    /**
     * Binds the listening socket.
     *
     * @param hostName     local host name or address to bind to
     * @param port         local port to listen on
     * @param timeout      accept timeout in milliseconds (0 waits forever); when it expires the
     *                     server stops
     * @param wsServer     host of the web service to forward requests to
     * @param wsServerPort port of the web service to forward requests to
     * @param ipList       client host addresses or host names that get the error response;
     *                     {@code null} is treated as an empty list
     * @throws IOException if the host cannot be resolved or the socket cannot be bound/configured
     */
    public ProxyServerTest(String hostName, int port, int timeout, String wsServer, int wsServerPort, List<?> ipList) throws IOException {
        InetAddress ia = InetAddress.getByName(hostName);
        ServerSocket socket = new ServerSocket(port, 50, ia);
        try {
            socket.setSoTimeout(timeout);
        } catch (IOException e) {
            // Don't leak the bound port if configuration fails.
            closeQuietly(socket);
            throw e;
        }
        this.serverSocket = socket;
        this.wsServer = wsServer;
        this.wsServerPort = wsServerPort;
        this.ipList = (ipList != null) ? ipList : new ArrayList<String>();
    }

    /**
     * Accept loop: serves connections one at a time until accept times out or fails, or the
     * running thread is interrupted. A failure while serving a single connection is logged and
     * does not stop the server.
     */
    @Override
    public void run() {
        logger.info("Filter source ips: " + ipList);
        try {
            while (!Thread.currentThread().isInterrupted()) {
                Socket clientSocket;
                try {
                    logger.info("Waiting for client " + " LocalSocketAddress: " + serverSocket.getLocalSocketAddress()
                            + ", HostAddress: " + serverSocket.getInetAddress().getHostAddress() +
                            ", HostName: " + serverSocket.getInetAddress().getHostName() + " on port " + serverSocket.getLocalPort() + "...");
                    clientSocket = serverSocket.accept();
                } catch (SocketTimeoutException s) {
                    logger.error("Socket timed out!");
                    logger.error(s);
                    break;
                } catch (IOException e) {
                    logger.error("Accept failed, stopping proxy", e);
                    break;
                }
                handleClient(clientSocket);
            }
        } finally {
            closeQuietly(serverSocket);
        }
    }

    /**
     * Serves one accepted connection: answers it with the error response or forwards it to the
     * web service, waits the grace period, then closes every socket it used (even on failure).
     */
    private void handleClient(Socket clientSocket) {
        Socket wsSocket = null;
        try {
            logger.info("Just connected to " + clientSocket.getRemoteSocketAddress());

            String hostName = clientSocket.getInetAddress().getHostName();
            String hostAddress = clientSocket.getInetAddress().getHostAddress();

            logger.info("Filter IPs" + ipList + ", hostName:" + hostName + ", hostAddress:" + hostAddress);

            if (ipList.contains(hostAddress) || ipList.contains(hostName)) {
                returnError(clientSocket);
            } else {
                logger.info("Forwarding request to sighting web service :" + wsServer + ":" + wsServerPort);
                wsSocket = new Socket(wsServer, wsServerPort);
                forward(clientSocket, wsSocket);
            }
            pause(CLOSE_DELAY_MILLIS);
        } catch (IOException e) {
            logger.error("Failed to serve client " + clientSocket.getRemoteSocketAddress(), e);
        } finally {
            logger.info("Close socket: " + clientSocket.getInetAddress());
            closeQuietly(clientSocket);
            // Closing the backend socket also unblocks the ws-to-client copier thread.
            closeQuietly(wsSocket);
        }
    }

    /** Reads the client's request (for logging) while concurrently sending the canned error. */
    private static void returnError(Socket clientSocket) {
        logger.info("Return error message to client");
        Thread reader = new Thread(new ClientReadMessageThread(clientSocket));
        reader.start();
        Thread sender = new Thread(new ClientSendMessageThread(clientSocket));
        sender.start();
        joinQuietly(reader, JOIN_MILLIS);
        joinQuietly(sender, JOIN_MILLIS);
    }

    /** Starts one copier thread per direction between the client and the web-service socket. */
    private static void forward(Socket clientSocket, Socket wsSocket) throws IOException {
        wsSocket.setKeepAlive(true);
        clientSocket.setKeepAlive(true);
        logger.info("isConnected wsSocket:" + wsSocket.isConnected());

        OutputStream os = wsSocket.getOutputStream();
        InputStream is = clientSocket.getInputStream();
        Thread clientToWS = new Thread(new ClientReader(is, os, "ID1:" + clientSocket.getInetAddress().getHostName() + " to " + wsSocket.getInetAddress().getHostName()));
        clientToWS.start();

        InputStream wsIS = wsSocket.getInputStream();
        OutputStream csOS = clientSocket.getOutputStream();
        Thread wsToClient = new Thread(new ClientReader(wsIS, csOS, "ID2:" + wsSocket.getInetAddress().getHostName() + " to " + clientSocket.getInetAddress().getHostName()));
        wsToClient.start();

        joinQuietly(clientToWS, JOIN_MILLIS);
        joinQuietly(wsToClient, JOIN_MILLIS);
    }

    /** Waits up to {@code millis} for {@code t}; restores the interrupt flag if interrupted. */
    private static void joinQuietly(Thread t, long millis) {
        try {
            t.join(millis);
        } catch (InterruptedException e) {
            logger.error(e);
            Thread.currentThread().interrupt();
        }
    }

    /** Sleeps for {@code millis}; restores the interrupt flag if interrupted. */
    private static void pause(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            logger.error(e);
            Thread.currentThread().interrupt();
        }
    }

    /** Closes {@code s} if non-null, logging (not throwing) any failure. */
    private static void closeQuietly(Socket s) {
        if (s == null) {
            return;
        }
        try {
            s.close();
        } catch (IOException e) {
            logger.warn("Failed to close socket " + s, e);
        }
    }

    /** Closes {@code s} if non-null, logging (not throwing) any failure. */
    private static void closeQuietly(ServerSocket s) {
        if (s == null) {
            return;
        }
        try {
            s.close();
        } catch (IOException e) {
            logger.warn("Failed to close server socket " + s, e);
        }
    }

    /** Copies everything from an input stream to an output stream until end-of-stream. */
    public static class ClientReader implements Runnable {
        private final InputStream is;
        private final OutputStream os;
        private final String id;

        /**
         * @param i  source stream
         * @param o  destination stream; flushed after every chunk so data is relayed promptly
         * @param id label used in log messages
         */
        public ClientReader(InputStream i, OutputStream o, String id) {
            this.is = i;
            this.os = o;
            this.id = id;
        }

        @Override
        public void run() {
            try {
                logger.info(id + " start copying.");
                byte[] buffer = new byte[COPY_BUFFER_SIZE];
                long total = 0;
                int n;
                // Chunked copy instead of one read/write/flush syscall per byte.
                while ((n = is.read(buffer)) != -1) {
                    os.write(buffer, 0, n);
                    os.flush();
                    total += n;
                }
                logger.info(id + " data copied. bytes:" + total);
            } catch (IOException e) {
                // Expected when the proxy closes a socket while a copy is still in progress.
                logger.error(id + " copy stopped", e);
            }
        }
    }

    /**
     * Entry point. Arguments: HOST_NAME PORT_NO TIME_OUT WS_HOST_NAME WS_PORT_NO
     * COMMA_SEPARATED_IP_ADDRESS_TO_SEND_ERROR.
     */
    public static void main(String[] args) {
        System.out.println(USAGE);
        if (args.length < 6) {
            logger.error("Expected 6 arguments but got " + args.length);
            return;
        }
        try {
            String hostName = args[0];
            int port = Integer.parseInt(args[1]);
            int timeout = Integer.parseInt(args[2]);
            String wsHostName = args[3];
            int wsPort = Integer.parseInt(args[4]);
            String ipAddressToIgnore = args[5];
            logger.info("hostName:" + hostName);
            logger.info("port:" + port);
            logger.info("timeout:" + timeout);
            logger.info("wsHostName:" + wsHostName);
            logger.info("wsPort:" + wsPort);
            logger.info("ipAddressToIgnore:" + ipAddressToIgnore);

            // Trim entries so "a, b" works; blank entries could never match a client anyway.
            List<String> ipList = new ArrayList<String>();
            for (String ip : ipAddressToIgnore.split(",")) {
                String trimmed = ip.trim();
                if (trimmed.length() > 0) {
                    ipList.add(trimmed);
                }
            }
            Thread t = new ProxyServerTest(hostName, port, timeout, wsHostName, wsPort, ipList);
            t.start();

        } catch (Exception e) {
            logger.error("Failed to start proxy", e);
            System.out.println(USAGE);
        }
    }
}





import org.apache.log4j.Logger;

import java.io.IOException;
import java.io.PrintWriter;
import java.net.Socket;

/**
 * Writes a fixed HTTP 500 error response to a client socket. The socket is left open; its owner
 * is responsible for closing it.
 */
public class ClientSendMessageThread implements Runnable {

    final static Logger logger = Logger.getLogger(ClientSendMessageThread.class);
    private Socket clientSocket;

    /** @param clientSocket connected socket to write the error response to */
    public ClientSendMessageThread(Socket clientSocket) {
        this.clientSocket = clientSocket;
    }

    @Override
    public void run() {
        try {
            String errorResponse = "soap:ServerObject \"org.mule.transport.NullPayload\" not of correct type. It must be of type \"java.util.List\"";
            logger.info("Sending error response:" + clientSocket.getLocalSocketAddress());
            logger.info("Error response:" + '\n' + errorResponse);

            // Content-Length must count encoded bytes, and the body must be exactly that long
            // (no trailing newline), or keep-alive clients will mis-frame the next message.
            byte[] body = errorResponse.getBytes("UTF-8");

            // HTTP requires CRLF line endings; println() would use the platform separator.
            StringBuilder header = new StringBuilder();
            header.append("HTTP/1.1 500 Internal Server Error\r\n");
            header.append("Content-Type: text/xml\r\n");
            header.append("http.status: 500\r\n");
            header.append("http.method: POST\r\n");
            header.append("Content-Length: ").append(body.length).append("\r\n");
            header.append("\r\n");

            // Write raw bytes: unlike PrintWriter, OutputStream reports write failures.
            // Not closed here: closing the stream would close the socket itself.
            java.io.OutputStream out = clientSocket.getOutputStream();
            out.write(header.toString().getBytes("US-ASCII"));
            out.write(body);
            out.flush();
            logger.info("Response send to client:" + clientSocket.getLocalSocketAddress());
        } catch (IOException e) {
            logger.error("Failed to send error response to " + clientSocket.getRemoteSocketAddress(), e);
        }

    }
}




import org.apache.log4j.Logger;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.net.Socket;
import java.net.SocketException;

/**
 * Reads a client socket line by line until end-of-stream, logging each line. The socket is left
 * open; its owner is responsible for closing it.
 */
public class ClientReadMessageThread implements Runnable {

    final static Logger logger = Logger.getLogger(ClientReadMessageThread.class);
    private Socket clientSocket;

    /** @param clientSocket connected socket whose incoming data should be logged */
    public ClientReadMessageThread(Socket clientSocket) {
        this.clientSocket = clientSocket;
    }

    @Override
    public void run() {
        try {
            logger.info("Reading data :" + clientSocket.getLocalSocketAddress());
            // Deliberately not closed: closing a socket's input stream closes the socket, which
            // another thread may still be writing to.
            BufferedReader in = new BufferedReader(new InputStreamReader(clientSocket.getInputStream()));
            String line;
            // Lines are only logged; nothing is retained, so memory stays bounded per line.
            while ((line = in.readLine()) != null) {
                logger.info(line);
            }
            logger.info("Finished reading data :" + clientSocket.getLocalSocketAddress());
        } catch (SocketException se) {
            // Typically raised when the socket is closed by its owner while readLine() blocks.
            logger.warn(se);
        } catch (IOException e) {
            logger.error("Error reading data :" + clientSocket.getLocalSocketAddress(), e);
        }
    }
}