-- load our favourite library
local socket = require"socket"
-- Creates a new set data structure.
--
-- Members live in the array part of the returned table (indices 1..n with no
-- holes), so the set can be passed wherever a plain list is expected. A
-- private reverse map (value -> index) gives constant-time membership tests
-- and removals.
--
-- Parameters:
--   a  ignored; kept so existing calls such as newset(1) keep working
-- Returns:
--   a table whose metatable exposes two methods:
--     set:insert(value)  appends value unless it is already a member
--     set:remove(value)  removes value if present; the last member is moved
--                        into the vacated slot, so order is not preserved
function newset(a)
    local reverse = {}
    local set = {}
    return setmetatable(set, {__index = {
        insert = function(set, value)
            if not reverse[value] then
                -- the length operator replaces table.getn, which is no
                -- longer available in Lua 5.2 and later
                local index = #set + 1
                set[index] = value
                reverse[value] = index
            end
        end,
        remove = function(set, value)
            local index = reverse[value]
            if index then
                reverse[value] = nil
                -- pop the last member; unless it is the one being removed,
                -- drop it into the hole so the array stays dense
                local top = table.remove(set)
                if top ~= value then
                    reverse[top] = index
                    set[index] = top
                end
            end
        end
    }})
end
-- inactivity limit before a connection is kicked; compared against
-- differences of socket.gettime() values
local TIMEOUT = 10
-- scheduler queues: sockets parked until readable / sockets parked until
-- writable (newset ignores its argument, so both start out empty)
local receiving, sending = newset(1), newset()
-- per-socket bookkeeping for connections and servers, keyed by socket object
local context = {}
-- Parks the calling coroutine until the scheduler resumes it.
--   who   socket to queue; it must already have an entry in `context`
--   what  "input" queues the socket in `receiving`; any other value queues
--         it in `sending`
-- Stamps context[who].last with the current time and then yields.
function wait(who, what)
    local queue = sending
    if what == "input" then
        queue = receiving
    end
    queue:insert(who)
    context[who].last = socket.gettime()
    coroutine.yield()
end
-- Initializes the forward server.
--
-- Reads the command-line arguments from the global `arg`; each one must
-- contain an <iport:ohost:oport> triple. For every argument a server socket
-- is bound on "*":iport, switched to a zero timeout, queued in `receiving`,
-- and given a context entry holding its accept coroutine and the forwarding
-- destination.
--
-- Prints a usage message and exits with status 1 when no argument is given.
-- Raises an error when an argument does not match the expected form or when
-- socket.bind fails.
function init()
    -- the length operator replaces table.getn, which is no longer available
    -- in Lua 5.2 and later
    if #arg < 1 then
        print("Usage")
        print(" lua forward.lua <iport:ohost:oport> ...")
        os.exit(1)
    end
    -- for each tunnel, start a new server socket
    for _, v in ipairs(arg) do
        -- capture forwarding parameters
        local iport, ohost, oport =
            string.match(v, "([^:]+):([^:]+):([^:]+)")
        assert(iport, "invalid arguments")
        -- create our server socket
        local server = assert(socket.bind("*", iport))
        server:settimeout(0) -- we don't want to be killed by bad luck
        -- make sure server is tested for readability
        receiving:insert(server)
        -- add server context
        context[server] = {
            thread = coroutine.create(accept),
            ohost = ohost,
            oport = oport
        }
    end
end
-- Starts a connection in a non-blocking way and, once it is established,
-- hands the socket over to the forwarding loop.
--   who         socket to connect
--   host, port  destination passed to who:connect
-- When the first attempt reports "timeout" the coroutine parks until the
-- socket is writable and asks again; if that second attempt fails with
-- anything but "already connected", both who's peer and who are kicked.
-- NOTE(review): a first attempt that fails with an error other than
-- "timeout" still proceeds to forward(who) -- confirm this is intended.
function connect(who, host, port)
    who:settimeout(0)
    local ok, reason = who:connect(host, port)
    local pending = (not ok) and reason == "timeout"
    if not pending then
        return forward(who)
    end
    -- connection still in progress: wait for writability, then re-check
    wait(who, "output")
    ok, reason = who:connect(host, port)
    if ok or reason == "already connected" then
        return forward(who)
    end
    kick(context[who].peer)
    kick(who)
end
-- Gets rid of a client: removes the socket from both scheduler queues,
-- closes it and drops its context entry.
--   who  socket to discard; nil (or false) makes the call a no-op
function kick(who)
    if not who then
        return
    end
    sending:remove(who)
    receiving:remove(who)
    who:close()
    context[who] = nil
end
-- Loops accepting connections and creating new threads to deal with them.
--   server  listening socket; context[server] must provide ohost and oport
--
-- For each accepted client a second TCP socket (the peer) is created, and
-- both get a context entry that references the other side. The peer's
-- coroutine runs `connect` towards ohost:oport, the client's coroutine runs
-- `forward`. After every attempt the coroutine parks until the server socket
-- is readable again. This function never returns.
function accept(server)
    while true do
        -- accept a new connection and start new coroutines to deal with it
        local client = server:accept()
        if client then
            local peer, err = socket.tcp()
            if peer then
                local now = socket.gettime()
                -- keep the threads in locals: the context entries may be
                -- gone by the time the second resume happens if the first
                -- one ends up kicking both sockets
                local client_thread = coroutine.create(forward)
                local peer_thread = coroutine.create(connect)
                context[client] = {
                    last = now,
                    -- client goes straight to forwarding loop
                    thread = client_thread,
                    peer = peer,
                }
                context[peer] = {
                    last = now,
                    peer = client,
                    -- peer first tries to connect to forwarding address
                    thread = peer_thread,
                }
                -- resume peer and client so they can do their thing
                local ohost = context[server].ohost
                local oport = context[server].oport
                local ok, msg = coroutine.resume(peer_thread, peer, ohost, oport)
                -- report failures instead of dropping them silently
                if not ok then print(msg) end
                ok, msg = coroutine.resume(client_thread, client)
                if not ok then print(msg) end
            else
                print(err)
                client:close()
            end
        end
        -- tell scheduler we are done for now
        wait(server, "input")
    end
end
-- Forwards all data arriving on `who` to the appropriate peer.
--   who  socket to read from; context[who].peer is the socket written to
--
-- Each round parks until `who` is readable, reads whatever is available and
-- pushes it to the peer, retrying after "timeout" send results. A receive
-- error other than "timeout", or a send error other than "timeout", kicks
-- `who` and ends the loop. When the receive completes without an error the
-- data is forwarded and then both `who` and its peer are kicked.
function forward(who)
    who:settimeout(0)
    while true do
        -- wait until we have something to read
        wait(who, "input")
        -- try to read as much as possible
        local data, rec_err, partial = who:receive("*a")
        -- if we had an error other than timeout, abort
        if rec_err and rec_err ~= "timeout" then return kick(who) end
        -- if we got a timeout, we probably have partial results to send
        data = data or partial
        -- forward what we got right away
        local peer = context[who].peer
        -- index of the last byte already accepted by the peer; it lives
        -- outside the retry loop so that a partial send resumes where it
        -- stopped instead of retransmitting the chunk from its first byte
        local sent = 0
        while true do
            -- tell scheduler we need to wait until we can send something
            -- NOTE(review): this parks on `who`, not on `peer`, because the
            -- scheduler resumes the thread owned by the selected socket
            wait(who, "output")
            local ret, snd_err, last = peer:send(data, sent + 1)
            if ret then break end
            if snd_err ~= "timeout" then return kick(who) end
            sent = last or sent
        end
        -- if we are done receiving, we are done
        if not rec_err then
            kick(who)
            kick(peer)
            break
        end
    end
end
-- Scheduler loop: waits until something happens, resumes the thread that
-- deals with what happened, and goes back to waiting. Never returns.
--
-- Every pass also collects the connections (context entries that have a
-- peer) whose last recorded activity is more than TIMEOUT old and kicks them.
function go()
    -- resumes the thread owning `who`, reporting a failure instead of
    -- dropping it silently
    local function wake(who)
        local ok, msg = coroutine.resume(context[who].thread, who)
        if not ok then print(msg) end
    end
    while true do
        -- check which sockets are interesting and act on them; the timeout
        -- makes select return even when every socket is idle, so the
        -- inactivity sweep below still gets a chance to run
        local readable, writable = socket.select(receiving, sending, TIMEOUT)
        -- for all readable connections, resume its thread
        for _, who in ipairs(readable) do
            -- an earlier resume in this pass may already have kicked it
            if context[who] then
                receiving:remove(who)
                wake(who)
            end
        end
        -- for all writable connections, do the same
        for _, who in ipairs(writable) do
            if context[who] then
                sending:remove(who)
                wake(who)
            end
        end
        -- put all inactive threads in death row
        local now = socket.gettime()
        local deathrow
        for who, data in pairs(context) do
            -- servers have no peer and are never timed out
            if data.peer then
                if now - data.last > TIMEOUT then
                    -- only create table if at least one is doomed
                    deathrow = deathrow or {}
                    deathrow[who] = true
                end
            end
        end
        -- finally kick everyone in deathrow
        if deathrow then
            for who in pairs(deathrow) do kick(who) end
        end
    end
end
-- program entry point: set up one server per command-line tunnel, then hand
-- control to the scheduler loop
init()
go()