#!/usr/bin/ruby
# Filename:	proxy
# Author:	David Ljung Madison <DaveSource.com>
# See License:	http://MarginalHacks.com/License/
# Description:	simple pass-thru proxy with ability to
# 	use another proxy and change outbound headers.
##################################################
require 'socket'
require 'uri'
# Needed on some systems for Mutex
require 'thread'
# Timeouts for connections
require 'timeout'

# True only on the author's development host; enables extra diagnostics.
# This is a real boolean: 0 is truthy in Ruby, so a 1/0 flag would make
# every `if TESTING` check succeed.  The hostname is read through the
# socket library instead of spawning a `hostname` subprocess.
TESTING = (Socket.gethostname == 'thinkdave')

# Only allow connections from IPs that match this (or nil for no limits).
# Anchored with \A and \z so the entire address string has to match.
LIMIT_CONN_TO_IP = /\A127\.0\.0\.1\z/

##################################################
# SETTINGS
##################################################
# Name of the running script with any leading directories removed.
PROGNAME = $0.gsub(/.*\//,'')

##################################################
# Usage
##################################################
# Current wall-clock time as whole seconds since the Unix epoch.
def time
	Time.now.tv_sec
end

# Print optional error messages followed by the usage text, then exit.
#
# msg:: zero or more strings; each one is written to stderr as
#       "ERROR: <msg>" before the usage text.
#
# Never returns: the process exits with status -1.
# The text lists only the options that parseArgs actually accepts.
def usage(*msg)
	msg.each { |m| $stderr.puts "ERROR: #{m}" }
	$stderr.puts <<-USAGE

Usage:  #{PROGNAME} [-port <num>]
  Run a proxy
  -port #    Choose a port

  -h, -?     Show this usage

	USAGE
	exit(-1)
end


# Holder for command-line settings.
class Opt
	# TCP port the proxy listens on (readable and writable).
	attr_reader :port
	attr_writer :port

	# Start with the default listening port.
	def initialize
		self.port = 8080
	end
end
# Process-wide options instance, filled in by parseArgs.
OPT = Opt.new

# Parse the command line, consuming ARGV and recording settings in OPT.
#
# Recognised arguments:
#   -h, -?      show usage (usage does not return)
#   -port NUM   listening port, 1-65535
# Anything else, and a missing or invalid port, is reported through usage.
def parseArgs
	while (arg = ARGV.shift)
		case arg
		when '-h', '-?'
			usage
		when '-port'
			val = ARGV.shift
			# nil.to_i and "abc".to_i are both 0, so a missing or garbled
			# value would silently become port 0; validate before converting.
			port = (val.to_s =~ /\A\d+\z/) ? val.to_i : 0
			usage("Invalid port [#{val}]") unless port.between?(1, 65535)
			OPT.port = port
		when /^-/
			usage("Unknown option [#{arg}]")
		else
			usage("Unknown arg [#{arg}]")
		end
	end
end

##################################################
# Fix for broken URI.parse (doesn't allow '_' in subdomains)
#
# URI.split is wrapped: underscores in the authority part (the text between
# "scheme://" and the first "/") are swapped for an alphanumeric placeholder
# before the stock splitter runs, then put back in the resulting components.
##################################################
module URI
	class << self
		alias origsplit split

		# Split +uri+ into its components like the stock URI.split.
		#
		# uri:: the URI string; it is never modified, so frozen strings work
		#       and the caller's copy never ends up holding the placeholder.
		#
		# Returns the component array produced by the stock splitter, with
		# every underscore restored.
		def split(uri)
			mark = 'UNDERLINEuriSplitISbrokenUNDERLINE'
			m = %r{\A([^:]+://)([^/]+)}.match(uri.to_s)
			return origsplit(uri) unless m && m[2].include?('_')

			# Mask every underscore of the authority, not just one of them.
			masked = m[1] + m[2].gsub('_', mark) + m.post_match
			# The authority can spill into several components (userinfo,
			# host, ...), so restore the underscores in all of them.
			origsplit(masked).map { |part| part.is_a?(String) ? part.gsub(mark, '_') : part }
		end
	end
end

# Parse +url+ with URI.parse after percent-encoding the three characters
# this helper treats as troublesome: '|', "'" and ';'.
def URIparse(url)
	escapes = { '|' => '%7C', "'" => '%27', ';' => '%3B' }
	URI.parse(url.gsub(/[|';]/, escapes))
end

# Option #2 - a little cleaner but more invasive to my code:
#   u = URI::Parser.new(:HOSTNAME => "(?:[a-zA-Z0-9\\-._~]|%\\h\\h)+")
#   uri = u.parse("http://www_w.example.com")


##################################################
# Utilities
##################################################
# Seconds to wait on a socket before giving up on a single read or write.
SocketTimeout = 2

# Write +str+ to +sock+ once it becomes writable.
#
# sock::    IO object to write to
# str::     data to send
# timeout:: seconds to wait for +sock+ to become writable
#
# Returns true after writing, or false when +sock+ did not become writable
# within +timeout+ (in which case nothing is written).
def writeTimeout(sock,str,timeout)
	ready = IO.select(nil, [sock], nil, timeout)
	if ready
		sock.write(str)
		true
	else
		false
	end
end

# Read up to +len+ bytes from +sock+, waiting at most +timeout+ seconds.
#
# sock::    IO object to read from
# len::     maximum number of bytes to return
# timeout:: seconds to wait for +sock+ to become readable
#
# Returns the data read, or nil on timeout or end of stream.
#
# The wait is done with IO.select only.  IO#eof? is deliberately not called
# first: it blocks until data or EOF arrives, which would defeat the timeout.
# End of stream is detected from the EOFError raised by read_nonblock.
def readTimeout(sock,len,timeout)
	return nil unless IO.select([sock],nil,nil,timeout)
	sock.read_nonblock(len)
rescue EOFError
	nil
end

# Decode %XX escapes in +str+ (XX being two hexadecimal digits, any case).
#
# Sequences that are not a '%' followed by two hex digits are left as-is.
# Returns a new string; +str+ is not modified.
def unHTML(str)
	# Match hex digits, not just decimal ones, so that escapes such as
	# %7C or %2f are decoded too.
	str.gsub(/%([0-9a-fA-F]{2})/) { $1.hex.chr }
end

##################################################
# Downloader
##################################################
# Ordered collection of header lines read from an IO-like object.
#
# The raw lines, including their line endings, are kept in arrival order so
# they can be written out again unchanged.  A lower-cased name => value
# index provides case-insensitive lookup.
class Headers
	# Read lines from +fd+ up to and including the first blank line, or
	# until EOF.  Anything after the blank line is left unread in +fd+.
	#
	# fd:: object responding to readline and raising EOFError at the end
	def initialize(fd)
		@hdrs = []
		@hdr = {}
		# Line ending used for headers added later; it follows whatever the
		# most recently read line used.
		@newline = "\n"
		begin
			loop do
				line = fd.readline
				break if line =~ /^\s*$/
				@hdrs.push(line)
				line = line.sub(/(\r?\n)$/,'')
				@newline = $1 if $1
				# The name stops at the first colon, so "Host:example.com:8080"
				# (no space after the colon) is still indexed under "host".
				next unless line =~ /^([^:\s]+):\s*(\S.*)$/
				@hdr[$1.downcase] = $2
			end
		rescue EOFError
			# Ignore (exits loop)
		end
	end

	# All raw lines joined into one string.
	def to_s
		@hdrs.join('')
	end

	# The raw lines as an array (the internal array, not a copy).
	def to_a
		@hdrs
	end

	# Value of header +key+ (case-insensitive), or nil when absent.
	def [](key)
		@hdr[key.downcase]
	end

	# Replace any existing +k+ header with a single "k: v" line appended at
	# the end.
	def []=(k,v)
		delete(k)
		@hdrs.push("#{k}: #{v}#{@newline}")
		@hdr[k.downcase] = v
	end

	# First stored line.
	def request
		@hdrs[0]
	end

	# First stored line.
	def response
		@hdrs[0]
	end

	# Remove matching lines and drop their names from the lookup index.
	#
	# what:: a Regexp matched against each raw line, or a header name
	#        (case-insensitive).  A string equal to a whole raw line also
	#        matches.
	#
	# The name is regex-escaped so characters such as '.' or '+' are taken
	# literally, and no whitespace is required after the colon.
	def delete(what)
		matcher = what.is_a?(Regexp) ? what : /\A#{Regexp.escape(what.to_s)}:/i
		@hdrs.delete_if { |line|
			next false unless line == what || line =~ matcher
			@hdr.delete(line.split(':', 2).first.to_s.downcase)
			true
		}
	end
end

# Thread safe counter: hands out consecutive integers under a mutex.
class Counter
	# num:: first value that inc will return
	def initialize(num=1)
		@cnt = num
		@sem = Mutex.new
	end

	# Return the current value and advance the counter by one.
	def inc
		@sem.synchronize do
			current = @cnt
			@cnt = current + 1
			current
		end
	end
end

##################################################
# Proxy code
##################################################
# Minimal threaded HTTP pass-through proxy.
#
# Every accepted connection is served in its own thread: the request line
# and headers are read, proxy-specific headers are stripped, configured
# extra headers are added, and the request is replayed to the origin server
# or to a configured upstream proxy.  The response is copied back to the
# client.
#
# Two control URLs change the behaviour at run time (name, value and host
# are passed through unHTML):
#   http://localproxy/addHdr=<Name>:<Value>   add/override an outbound header
#   http://localproxy/proxy=<host>[:<port>]   forward through another proxy
# For example:
#   http://localproxy/addHdr=Proxy-Authorization:Basic%20<base64 of user:pass>
#   http://localproxy/proxy=localhost:8111
class Proxy
	def initialize
		@port = OPT.port
		@txn = Counter.new
		# Extra outbound headers (name => value) set through the addHdr URL
		@addHdrs = Hash.new
		# Upstream proxy set through the proxy= URL (nil means go direct)
		@proxyHost = nil
		@proxyPort = nil
	end

	#########################
	# Start the proxy
	#########################
	# Listen on the configured port and serve each connection in a new
	# thread until interrupted.  The listening socket is always closed on
	# the way out.
	def run
		begin
			puts "Start proxy on #{@port}"
			# Start our server to handle connections (will raise things on errors)
			@socket = TCPServer.new @port

			# Handle every request in another thread
			loop do
				Thread.new(@socket.accept) { |client|
					begin
						handleRequest(client)
					rescue EOFError
						# Ignore
					rescue => e
						puts "ERROR: #{e} [#{e.class}]"
					end
				}
			end

		# CTRL-C
		rescue Interrupt
			puts 'Saw interrupt'
			puts
		# Ensure that we release the socket on errors
		ensure
			if @socket
				@socket.close
				puts 'Closed socket'
			end
			puts 'Exit.'
		end
	end

	# Reason phrases for the status codes the proxy itself may send
	CODES = {
		200 => 'OK',
		400 => 'Bad Request',
		401 => 'Unauthorized',
		403 => 'Forbidden',
		404 => 'Not Found',
		407 => 'Proxy Authentication Required',
		501 => 'Not Implemented',
		503 => 'Service Unavailable',
	}

	# Direct response from proxy
	#
	# conn::      client connection; it is closed after the response is sent
	# num::       numeric status code (reason phrase looked up in CODES)
	# str::       HTML fragment placed in the body
	# extraHdrs:: optional additional header text
	def proxyResponse(conn,num,str,extraHdrs='')
		what = CODES[num] || "Unknown"
		conn.print <<-PROXYRESPONSE
HTTP/1.1 #{num} #{what}
Server: #{PROGNAME}
Content-Type: text/html
Accept-Ranges: bytes
Connection: close
#{extraHdrs}

<h1>Proxy: #{PROGNAME}</h1>
<h2>#{num} #{what}</h2>
<p />
#{str}
		PROXYRESPONSE
		conn.close
	end

	# Log line prefix: "[<time>] " or "[<time>, <txn>] "
	def hdr(txn=nil)
		tt = txn ? "#{time}, #{txn}" : time
		return "[#{tt}] "
	end

	#########################
	# Handle a request
	#########################
	# Serve one client connection: parse the request line, apply the access
	# limit, handle the control URLs, then forward the request and relay the
	# response.  Both sockets are closed on every exit path.
	def handleRequest(conn)
		txn = @txn.inc
		reqLine = nil
		toServer = nil
		begin
			begin
				reqLine = conn.readline
			rescue
				# Nothing usable arrived.  Returning from inside the outer
				# begin lets the ensure clause close the client socket.
				return
			end

			verb    = reqLine[/^\w+/]
			url     = reqLine[/^\w+\s+(\S+)/, 1]
			# Without a version the outbound request line would end in a bare
			# "HTTP/", so fall back to 1.0.
			version = reqLine[/HTTP\/(1\.\d)\s*$/, 1] || '1.0'
			begin
				uri = URI::parse url
			rescue
				return proxyResponse(conn,501,"The proxy could not understand the url [#{url}]")
			end

			# Only allow connections from localhost
			if LIMIT_CONN_TO_IP
				# Unpacks the peer sockaddr as two 16-bit fields plus four
				# address bytes (an IPv4 layout).
				_fam,_port,*addr = conn.getpeername.unpack('nnC4')
				return proxyResponse(conn,403, "Request forbidden") unless addr.join('.').match(LIMIT_CONN_TO_IP)
			end

			# Show what got requested
			puts "#{hdr(txn)} Request: %-4s %s" % [verb,url]

			# Proxy variables
			if url =~ /^http:\/\/localproxy\/addHdr=(\S+?):(\S+)$/
				k,v = unHTML($1),unHTML($2)
				@addHdrs[k] = v
				return proxyResponse(conn,200, "Now adding header:  #{k} -> #{v}")
			end
			if url =~ /^http:\/\/localproxy\/proxy=(\S+)$/
				@proxyHost = unHTML($1)
				@proxyPort = $1 if @proxyHost.sub!(/:(\d+)$/,'')
				return proxyResponse(conn,200, "Now using proxy:  #{@proxyHost}:#{@proxyPort}")
			end

			# Read the upstream setting once so the connection and the request
			# line agree even if another thread changes it meanwhile.
			viaProxy = @proxyHost
			begin
				destHost = viaProxy || uri.host
				destPort = @proxyPort || uri.port || 80
				toServer = TCPSocket.new(destHost,destPort)
				toServer.write("#{verb} #{requestTarget(uri,url,viaProxy)} HTTP/#{version}\r\n")
			rescue Errno::ECONNREFUSED, SocketError => e
				puts "#{hdr(txn)} ERR: Connection refused [#{e} -> #{uri.host}]"
				return proxyResponse(conn,503, "The proxy could not connect to the host [#{uri.host}]")
			rescue => e
				puts "#{hdr(txn)} ERR? #{e}  [#{e.class}]"
				puts "  uri #{uri}"
				return
			end

			# Strip proxy headers
			hdrs = Headers.new(conn)
			hdrs.delete(/^proxy/i)
			hdrs.delete("Keep-Alive")
			hdrs['Connection'] = 'close'

			# Add any new headers
			@addHdrs.each { |k,v| hdrs[k] = v }

			forwardRequest(conn, toServer, hdrs)
			relayResponse(conn, toServer)

			conn.close
			toServer.close

		rescue => e
			rl = reqLine ? "#{reqLine}" : "before ReqLine read"
			puts "#{hdr(txn)} Error: #{e} [#{rl}]"
			# 0 is truthy in Ruby, so a numeric 0 flag is checked explicitly;
			# this works for both a 1/0 and a true/false TESTING.
			puts e.backtrace if TESTING && TESTING != 0
			# Ignore
		ensure
			conn.close if conn && !conn.closed?
			toServer.close if toServer && !toServer.closed?
		end
	end

	private

	# Request target for the outbound request line.
	#
	# Through an upstream proxy the original absolute URL is sent.  Directly
	# to the origin it is the path plus optional query, with "/" standing in
	# for an empty path so the request line stays well-formed.
	def requestTarget(uri, url, viaProxy)
		return url if viaProxy
		target = uri.path.to_s
		target = '/' if target.empty?
		target += '?'+uri.query if uri.query
		target
	end

	# Send the rewritten headers upstream, followed by the request body when
	# a positive Content-Length is present.
	def forwardRequest(conn, toServer, hdrs)
		toServer.write(hdrs.to_s)
		# Every header line carries its own line ending, so exactly one empty
		# line ends the header section.  Anything more would be read by the
		# server as the first bytes of the body.
		toServer.write("\r\n")
		# Any content to read?
		conLen = hdrs['Content-Length']
		if conLen && conLen.to_i>0
			toServer.write(conn.read(conLen.to_i))
		end
	end

	# Copy the server's headers and body back to the client.
	def relayResponse(conn, toServer)
		resHdrs = Headers.new(toServer)

		conn.write(resHdrs.to_s)
		conn.write("\r\n")

		loop do
			buff = readTimeout(toServer,2048,SocketTimeout)
			break unless buff
			# A chunk that could not be delivered would leave a hole in the
			# response, so stop relaying instead of carrying on.
			break unless writeTimeout(conn,buff,SocketTimeout)
			break if toServer.eof?
		end
	end
end


##################################################
# Main code
##################################################
# Entry point: parse the command line, then run the proxy.
def main
	parseArgs

	proxy = Proxy.new
	proxy.run
end
main
