#!/usr/bin/ruby
# Filename:	proxy
# Author:	David Ljung Madison <DaveSource.com>
# See License:	http://MarginalHacks.com/License/
# Description:	Selenium test that uses a builtin
# 	proxy to authorize a remote proxy and edit
# 	headers, set referers, etc...
##################################################
# Requires: chromedriver from Chrome for the Webdriver project!

require 'rubygems'

# % gem install selenium-webdriver
require 'selenium-webdriver'

# Debian/Ubuntu: sudo apt-get install libsqlite3-ruby
require 'sqlite3'

# Authorization calculations
require 'base64'

# For simple HTTP.get
require 'net/http'

# Proxy requires
require 'socket'
require 'uri'
require 'thread'
require 'timeout'

##################################################
# SETTINGS
##################################################
# Upstream ("remote") proxy that forwarded requests are sent through, and the
# username/password main turns into a Basic Proxy-Authorization header.
# NOTE(review): credentials are hard-coded in the script; consider reading
# them from the environment instead.
REMOTE_HOST = 'localhost'
REMOTE_PORT = 8111
REMOTE_USER = 'username'
REMOTE_PASS = 'password'

# Only allow connections from client IPs that match this (or nil for no limits).
# The default accepts IPv4 loopback only.
LIMIT_CONN_TO_IP = /^127\.0\.0\.1$/

# Pseudo host name for the local proxy: requests to http://LOCALPROXY/...
# are commands handled by the proxy itself and are never forwarded.
LOCALPROXY = 'localseleniumproxy'

# Name of this script without its directory (used in usage/response text).
PROGNAME = $0.gsub(/.*\//,'')

##################################################
# Usage
##################################################
# Current wall-clock time as whole seconds since the Unix epoch
# (used for log-line prefixes).
def time
	now = Time.now
	now.tv_sec
end

# Print each given error message, then the usage text, to stderr and exit
# with status -1.
# msg - zero or more error strings, each printed as "ERROR: <msg>".
# The option list matches what parseArgs actually accepts.
def usage(*msg)
	msg.each { |m| $stderr.puts "ERROR: #{m}" }
	$stderr.puts <<-USAGE

Usage:  #{PROGNAME} [-port #]
  Run a proxy
  -port #    Choose the port to listen on (default 8080)
  -h, -?     Show this help

	USAGE
	exit(-1)
end


# Holder for command-line settings; parseArgs fills in the global OPT.
class Opt
	attr_accessor :port

	# Default settings: listen on port 8080.
	def initialize
		self.port = 8080
	end
end
# The single option set shared by the whole script.
OPT = Opt.new

# Consume ARGV and store the settings in OPT.
# -h / -? print usage and exit. Unknown options or arguments, or a bad
# -port value, print an error plus usage and exit.
def parseArgs
	while (arg = ARGV.shift)
		case arg
		when '-h', '-?'
			usage
		when '-port'
			val = ARGV.shift
			# Validate explicitly: to_i would turn a missing or garbage value
			# into port 0 without complaint.
			unless val =~ /\A\d+\z/ && val.to_i.between?(1, 65535)
				usage("-port needs a number from 1-65535, got [#{val}]")
			end
			OPT.port = val.to_i
		when /^-/
			usage("Unknown option [#{arg}]")
		else
			usage("Unknown arg [#{arg}]")
		end
	end
end

##################################################
# Fix for broken URI.parse (doesn't allow '_' in subdomains)
##################################################
module URI
	class << self
		alias origsplit split

		# URI.split wrapper that tolerates '_' in the authority (userinfo/host),
		# which older URI parsers reject.
		# Every '_' in the authority is swapped for a letters-only marker and
		# the original split is run on that copy. The underscores are then
		# restored in the userinfo and host parts.
		# The caller's string is never modified, so frozen strings work and
		# callers that log the url afterwards see it unchanged.
		# Path, query and fragment are left alone.
		# NOTE(review): newer Ruby URI parsers appear to accept '_' already;
		# verify whether this shim is still needed.
		def split(uri)
			str = uri.to_s
			# scheme:// plus the authority, which ends at the first / ? or #
			m = str.match(/\A([^:\/?#]+:\/\/)([^\/?#]*)/)
			return origsplit(uri) unless m && m[2].include?('_')
			marker = 'UNDERLINEuriSplitISbrokenUNDERLINE'
			masked = m[1] + m[2].gsub('_', marker) + m.post_match
			parts = origsplit(masked)
			# Restore underscores in userinfo (index 1) and host (index 2)
			[1, 2].each { |i| parts[i] = parts[i].gsub(marker, '_') if parts[i] }
			parts
		end
	end
end

# Parse url with URI.parse after percent-escaping the characters that
# URI.parse complains about here: '|' -> %7C, "'" -> %27, ';' -> %3B.
# Raises whatever URI.parse raises for anything else it rejects.
def URIparse(url)
	escaped = url.gsub(/[|';]/, '|' => '%7C', "'" => '%27', ';' => '%3B')
	URI.parse(escaped)
end

# Option #2 - a little cleaner but more invasive to my code:
#   u = URI::Parser.new(:HOSTNAME => "(?:[a-zA-Z0-9\\-._~]|%\\h\\h)+")
#   uri = u.parse("http://www_w.example.com")


##################################################
# Utilities
##################################################
# Seconds to wait for a socket to become readable/writable before giving up.
SocketTimeout = 2

# Write str to sock once it becomes writable, waiting at most timeout seconds.
# Returns true if the write was performed, false if the wait timed out.
def writeTimeout(sock,str,timeout)
	writable = IO::select(nil,[sock],nil,timeout)
	if writable
		sock.write(str)
		true
	else
		false
	end
end

# Read at most len bytes from sock, waiting at most timeout seconds for data.
# Returns the bytes read, or nil on timeout or end-of-file.
def readTimeout(sock,len,timeout)
	# Wait first: IO#eof? blocks until data or EOF arrives, so checking it
	# before select would make the call ignore the timeout entirely.
	return nil unless IO::select([sock],nil,nil,timeout)
	# Readable now, so eof? returns promptly (and buffers any data it sees)
	return nil if sock.eof?
	# read_nonblock serves Ruby's internal buffer first, so nothing is lost
	sock.read_nonblock(len)
end

##################################################
# HTML utilities
##################################################
# Decode %XX percent-escapes in str (URL decoding, despite the name).
# Hex digits may be upper or lower case. Malformed escapes are left as they
# are, and '+' is NOT turned into a space.
def unHTML(str)
	str.gsub(/%(\h\h)/) { $1.hex.chr }
end

# A block of HTTP header lines read from an IO.
# The raw lines are kept in order, with their original line endings, so they
# can be re-emitted. A downcased name -> value index serves lookups.
class Headers
	# "Name: value". The name stops at the first ':' so values that contain
	# ':' (e.g. "X:a:b", "Host:h:80") keep the right key.
	HEADER_LINE = /^([^:\s]+):\s*(\S.*)$/

	# Read lines from fd up to the first blank line (consumed, not stored)
	# or EOF. fd must respond to #readline.
	def initialize(fd)
		@hdrs = Array.new
		@hdr = Hash.new
		@newline = "\n"
		begin
			loop do
				line = fd.readline
				break if line =~ /^\s*$/
				@hdrs.push(line)
				# Remember the sender's line ending so added headers match it
				line = line.sub(/(\r?\n)$/,'')
				@newline = $1 if $1
				next unless line =~ HEADER_LINE
				@hdr[$1.downcase] = $2
			end
		rescue EOFError
			# Ignore (exits loop)
		end
	end

	# All stored lines joined, without the terminating blank line.
	def to_s
		@hdrs.join('')
	end

	def to_a
		@hdrs
	end

	# Value of header key (case-insensitive), or nil.
	def [](key)
		@hdr[key.downcase]
	end

	# Replace header k with v, appended at the end. A nil or empty v only
	# deletes it. CR/LF in v are folded into spaces so a value can never
	# smuggle extra header lines into the output.
	def []=(k,v)
		delete(k)
		return unless v
		v = v.gsub(/[\r\n]+/, ' ').strip if v =~ /[\r\n]/
		return if v.empty?
		@hdrs.push("#{k}: #{v}#{@newline}")
		@hdr[k.downcase] = v
	end

	# First stored line (meaningful when the status/request line was read here)
	def request
		@hdrs[0]
	end
	def response
		@hdrs[0]
	end

	# Remove lines. what is either a Regexp matched against each raw line,
	# or a header name (case-insensitive, matched literally) / exact raw line.
	def delete(what)
		pattern = what.is_a?(Regexp) ? what : /^#{Regexp.escape(what)}:\s/i
		@hdrs.delete_if { |line|
			if line == what || line.match(pattern)
				key = line[/^([^:\s]+):/, 1]
				@hdr.delete(key.downcase) if key
				true
			else
				false
			end
		}
	end
end

# Thread safe counter: each call to inc hands out the next integer,
# starting from num.
class Counter
	def initialize(num=1)
		@lock = Mutex.new
		@cnt = num
	end

	# Atomically return the current value and advance it by one.
	def inc
		@lock.synchronize do
			current = @cnt
			@cnt = current + 1
			current
		end
	end
end

##################################################
# Proxy code
##################################################
# Local HTTP forward proxy that the Selenium-driven browser is pointed at.
# Requests are forwarded to the origin server, or to an upstream proxy when
# proxyHost is set, with the headers in addHdrs added.
# Requests for http://LOCALPROXY/... are commands to the proxy itself:
#   http://LOCALPROXY/addHdr=Name:Value   add a header to every request
#   http://LOCALPROXY/proxy=host[:port]   switch the upstream proxy
class Proxy
	attr_accessor :port, :proxyHost, :proxyPort, :addHdrs

	# How many further ports run tries when the requested one is taken.
	MAX_PORT_ATTEMPTS = 50

	def initialize
		@port = OPT.port
		@txn = Counter.new
		@addHdrs = Hash.new
		@proxyHost = nil
		@proxyPort = nil
		@running = nil
		@socket = nil
	end

	#########################
	# Start the proxy
	#########################
	# Bind the listening socket, then accept connections on a background
	# thread (each request handled in its own thread) and return nil.
	# Binding happens in the caller's thread so an in-use port is detected
	# here. The port is bumped and retried up to MAX_PORT_ATTEMPTS times, and
	# #port holds the port actually in use once run returns.
	# attempt - number of ports already tried.
	# Raises Errno::EADDRINUSE if no free port was found, and any other bind
	# error as-is.
	def run(attempt=0)
		begin
			puts "Start proxy on #{@port}"
			@socket = TCPServer.new @port
		rescue Errno::EADDRINUSE
			attempt += 1
			raise if attempt > MAX_PORT_ATTEMPTS
			@port += 1
			puts "Port in use - attempting port #{@port}"
			retry
		end
		server = @socket
		@running = Thread.new {
			begin
				# Handle every request in another thread
				loop do
					client = server.accept
					Thread.new(client) { |c|
						begin
							handleRequest(c)
						rescue EOFError
							# Client hung up; nothing to report
						rescue => e
							puts "ERROR: #{e} [#{e.class}]"
						end
					}
				end
			rescue IOError
				# Listening socket was closed underneath us
			rescue => e
				puts "YO : #{e.class}"
				puts e.backtrace
			ensure
				# Release the socket on errors, or when killed by #quit
				unless server.closed?
					server.close
					puts 'Closed socket'
				end
			end
		}
		nil
	end

	# Stop accepting connections and close the listening socket.
	# Does nothing if the proxy was never started.
	def quit
		return unless @running
		@running.kill
		# Wait until the server thread's ensure block has closed the socket
		@running.join
	end

	CODES = {
		200 => 'OK',
		400 => 'Bad Request',
		401 => 'Unauthorized',
		403 => 'Forbidden',
		404 => 'Not Found',
		407	=> 'Proxy Authentication Required',
		501 => 'Not Implemented',
		503 => 'Service Unavailable',
	}

	# Direct response from proxy: write a small HTML page with status num and
	# message str (plus optional extra header text) to conn, then close conn.
	def proxyResponse(conn,num,str,extraHdrs='')
		what = CODES[num] || "Unknown"
		conn.print <<-PROXYRESPONSE
HTTP/1.1 #{num} #{what}
Server: #{PROGNAME}
Content-Type: text/html
Accept-Ranges: bytes
Connection: close
#{extraHdrs}

<h1>Proxy: #{PROGNAME}</h1>
<h2>#{num} #{what}</h2>
<p />
#{str}
		PROXYRESPONSE
		conn.close
	end

	# Log-line prefix: "[<epoch secs>] " or "[<epoch secs>, <txn>] ".
	def hdr(txn=nil)
		tt = txn ? "#{time}, #{txn}" : time
		return "[#{tt}] "
	end

	#########################
	# Handle a request
	#########################
	# Serve one client connection. Read the request line, reject clients not
	# matching LIMIT_CONN_TO_IP, and answer LOCALPROXY commands directly.
	# Otherwise forward the request upstream and relay the response.
	# Errors are logged; conn (and any upstream socket) is always closed.
	def handleRequest(conn)
		txn = @txn.inc
		reqLine = nil
		toServer = nil

		begin
			reqLine = conn.readline
		rescue StandardError
			# Client went away before sending a request line (closed below)
			return
		end

		# Only allow connections from permitted client addresses
		return proxyResponse(conn,403, "Request forbidden") unless clientAllowed?(conn)

		verb    = reqLine[/^\w+/]
		url     = reqLine[/^\w+\s+(\S+)/, 1]
		# Fall back to 1.0 so a malformed line never yields "HTTP/" upstream
		version = reqLine[/HTTP\/(1\.\d)\s*$/, 1] || '1.0'

		# Show what got requested
		puts "#{hdr(txn)} Request: %-4s %s" % [verb,url]

		# Proxy variables (matched on the raw url, so no URI parsing needed)
		if url =~ /^http:\/\/#{LOCALPROXY}\/addHdr=(\S+?):(\S+)$/
			k,v = unHTML($1),unHTML($2)
			@addHdrs[k] = v
			return proxyResponse(conn,200, "Now adding header:  #{k} -> #{v}")
		end
		if url =~ /^http:\/\/#{LOCALPROXY}\/proxy=(\S+)$/
			@proxyHost = unHTML($1)
			@proxyPort = $1 if @proxyHost.sub!(/:(\d+)$/,'')
			return proxyResponse(conn,200, "Now using proxy:  #{@proxyHost}:#{@proxyPort}")
		end
		# Anything else at LOCALPROXY, just ignore (such as 'favicon.ico')
		if url =~ /^http:\/\/#{LOCALPROXY}\//
			return proxyResponse(conn,200, "Thanks, we'll just ignore that")
		end

		uri = parseUrl(url)
		# No host: unparseable, origin-form or CONNECT target - can't forward
		unless uri && uri.host
			return proxyResponse(conn,501,"The proxy could not understand the url [#{url}]")
		end

		# Not cached - get it from the originating server (or upstream proxy).
		# Snapshot the upstream setting: another request may change it.
		upHost, upPort = @proxyHost, @proxyPort
		destHost = upHost || uri.host
		destPort = upPort || uri.port || 80
		begin
			toServer = TCPSocket.new(destHost,destPort)
		rescue SystemCallError, SocketError => e
			puts "#{hdr(txn)} ERR: Connection refused [#{e} -> #{uri.host}]"
			return proxyResponse(conn,503, "The proxy could not connect to the host [#{uri.host}]")
		end
		toServer.write("#{verb} #{requestTarget(uri, !upHost.nil?)} HTTP/#{version}\r\n")

		# Strip proxy headers
		hdrs = Headers.new(conn)
		hdrs.delete(/^proxy/i)
		hdrs.delete("Keep-Alive")
		hdrs['Connection'] = 'close'

		# Add any new headers (dup: other threads may add entries meanwhile)
		@addHdrs.dup.each { |k,v| hdrs[k] = v }

		toServer.write(hdrs.to_s)
		# Exactly one blank line ends the header block; anything more would be
		# read by the server as the first bytes of the request body.
		toServer.write("\r\n")

		# Forward any request body
		conLen = hdrs['Content-Length'].to_i
		toServer.write(conn.read(conLen)) if conLen > 0

		# Now get server headers/response
		resHdrs = Headers.new(toServer)
		conn.write(resHdrs.to_s)
		conn.write("\r\n")

		# Relay the body until EOF, a read timeout, or a stalled client
		while (buff = readTimeout(toServer,2048,SocketTimeout))
			break unless writeTimeout(conn,buff,SocketTimeout)
		end
		nil
	rescue => e
		rl = reqLine ? "#{reqLine}" : "before ReqLine read"
		puts "#{hdr(txn)} Error: #{e} [#{rl}]"
		puts e.backtrace
		# Ignore
	ensure
		conn.close if conn && !conn.closed?
		toServer.close if toServer && !toServer.closed?
	end

	private

	# True when LIMIT_CONN_TO_IP is nil or matches the peer's IP address.
	# IPv4-mapped IPv6 peers ("::ffff:1.2.3.4") are checked by their IPv4 form.
	def clientAllowed?(conn)
		return true unless LIMIT_CONN_TO_IP
		ip = conn.remote_address.ip_address.sub(/\A::ffff:/i, '')
		!!(ip =~ LIMIT_CONN_TO_IP)
	end

	# Parse a request URL with URI.parse. If URI.parse rejects it, retry
	# with URIparse, which escapes a few characters URI.parse complains about.
	# Returns nil when the URL is missing or neither can parse it.
	def parseUrl(url)
		return nil unless url
		begin
			URI.parse(url)
		rescue URI::InvalidURIError
			URIparse(url)
		end
	rescue StandardError
		nil
	end

	# Request-target for the forwarded request line. An upstream proxy must
	# get the absolute URL (RFC 7230 section 5.3.2). An origin server gets
	# path plus query, with "/" standing in for an empty path.
	def requestTarget(uri, viaProxy)
		return uri.to_s if viaProxy
		target = uri.path.to_s
		target = '/' if target.empty?
		target += '?' + uri.query if uri.query
		target
	end
end

##################################################
# Main code
##################################################
# Entry point. Parse options and start the local rewrite proxy (pointed at
# the remote proxy, with Basic proxy credentials). Drive Chrome through it
# to an example page, then shut both down even if Selenium fails part way.
def main
	parseArgs

	#########################
	# Setup local Selenium rewrite proxy
	#########################
	proxyObj = Proxy.new
	# And point it to the replay proxy
	if REMOTE_HOST
		proxyObj.proxyHost = REMOTE_HOST
		proxyObj.proxyPort = REMOTE_PORT || 8080
		# strict_encode64, not encode64: encode64 appends "\n" (and wraps long
		# output), which ends the header line early and puts a blank line
		# into the forwarded request headers.
		auth = Base64.strict_encode64("#{REMOTE_USER}:#{REMOTE_PASS}")
		proxyObj.addHdrs['Proxy-Authorization'] = "Basic #{auth}"
	end
	proxyObj.run
	proxyAddr = "localhost:#{proxyObj.port}"

	# Example of how we'd do this through a web request instead
	# of accessing the proxyObj directly. NOTE: LOCALPROXY is not a real host,
	# so these requests must themselves be sent through this proxy.
#	Net::HTTP::get_response(URI("http://#{LOCALPROXY}/proxy=#{REMOTE_HOST}:#{REMOTE_PORT}"))
#	auth = Base64.strict_encode64("#{REMOTE_USER}:#{REMOTE_PASS}")
#	Net::HTTP::get_response(URI("http://#{LOCALPROXY}/addHdr=Proxy-Authorization:Basic%20#{auth}"))

	driver = nil
	begin
		#########################
		# Setup selenium
		#########################
		# Point selenium to the local selenium proxy
		# NOTE(review): newer selenium-webdriver releases deprecate
		# :desired_capabilities in favour of options/capabilities; verify
		# against the installed gem version.
		proxySel = Selenium::WebDriver::Proxy.new(:http => proxyAddr)
		caps = Selenium::WebDriver::Remote::Capabilities.chrome(:proxy => proxySel)
		driver = Selenium::WebDriver.for :chrome , :desired_capabilities => caps

		#########################
		# Now open a test page as an example
		#########################
		# We can set a referer if we like
		proxyObj.addHdrs['Referer'] = 'http://GetDave.com/'
		driver.get('http://MarginalHacks.com/')

		#element = driver.find_element(:name, 'q')
		#element.send_keys "Hello WebDriver!"
		#element.submit
		#puts "GOT TITLE: #{driver.title}"
	ensure
		#########################
		# Shutdown (also when anything above raised)
		#########################
		begin
			proxyObj.quit
		ensure
			driver.quit if driver
		end
	end
end
main

