Changed receive function. Now uniform with all other functions. Returns nil

on error, return partial result in the end.

http.lua rewritten.
This commit is contained in:
Diego Nehab 2004-03-21 07:50:15 +00:00
parent 2a14ac4fe4
commit 4919a83d22
9 changed files with 316 additions and 651 deletions

4
TODO
View File

@ -19,6 +19,10 @@
* Separar as classes em arquivos * Separar as classes em arquivos
* Retorno de sendto em datagram sockets pode ser refused * Retorno de sendto em datagram sockets pode ser refused
change mime.eol to output marker on detection of first candidate, instead
of on the second. that way it works in one pass for strings that end with
one candidate.
colocar um userdata com gc metamethod pra chamar sock_close (WSAClose); colocar um userdata com gc metamethod pra chamar sock_close (WSAClose);
sources ans sinks are always simple in http and ftp and smtp sources ans sinks are always simple in http and ftp and smtp
unify backbone of smtp and ftp unify backbone of smtp and ftp

View File

@ -13,9 +13,9 @@
/*=========================================================================*\ /*=========================================================================*\
* Internal function prototypes * Internal function prototypes
\*=========================================================================*/ \*=========================================================================*/
static int recvraw(lua_State *L, p_buf buf, size_t wanted); static int recvraw(p_buf buf, size_t wanted, luaL_Buffer *b);
static int recvline(lua_State *L, p_buf buf); static int recvline(p_buf buf, luaL_Buffer *b);
static int recvall(lua_State *L, p_buf buf); static int recvall(p_buf buf, luaL_Buffer *b);
static int buf_get(p_buf buf, const char **data, size_t *count); static int buf_get(p_buf buf, const char **data, size_t *count);
static void buf_skip(p_buf buf, size_t count); static void buf_skip(p_buf buf, size_t count);
static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent); static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent);
@ -73,42 +73,34 @@ int buf_meth_send(lua_State *L, p_buf buf)
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
int buf_meth_receive(lua_State *L, p_buf buf) int buf_meth_receive(lua_State *L, p_buf buf)
{ {
int top = lua_gettop(L); int err = IO_DONE, top = lua_gettop(L);
int arg, err = IO_DONE;
p_tm tm = buf->tm; p_tm tm = buf->tm;
luaL_Buffer b;
luaL_buffinit(L, &b);
tm_markstart(tm); tm_markstart(tm);
/* push default pattern if need be */
if (top < 2) {
lua_pushstring(L, "*l");
top++;
}
/* make sure we have enough stack space for all returns */
luaL_checkstack(L, top+LUA_MINSTACK, "too many arguments");
/* receive all patterns */ /* receive all patterns */
for (arg = 2; arg <= top && err == IO_DONE; arg++) { if (!lua_isnumber(L, 2)) {
if (!lua_isnumber(L, arg)) { static const char *patternnames[] = {"*l", "*a", NULL};
static const char *patternnames[] = {"*l", "*a", NULL}; const char *pattern = luaL_optstring(L, 2, "*l");
const char *pattern = lua_isnil(L, arg) ? /* get next pattern */
"*l" : luaL_checkstring(L, arg); int p = luaL_findstring(pattern, patternnames);
/* get next pattern */ if (p == 0) err = recvline(buf, &b);
switch (luaL_findstring(pattern, patternnames)) { else if (p == 1) err = recvall(buf, &b);
case 0: /* line pattern */ else luaL_argcheck(L, 0, 2, "invalid receive pattern");
err = recvline(L, buf); break;
case 1: /* until closed pattern */
err = recvall(L, buf);
if (err == IO_CLOSED) err = IO_DONE;
break;
default: /* else it is an error */
luaL_argcheck(L, 0, arg, "invalid receive pattern");
break;
}
/* get a fixed number of bytes */ /* get a fixed number of bytes */
} else err = recvraw(L, buf, (size_t) lua_tonumber(L, arg)); } else err = recvraw(buf, (size_t) lua_tonumber(L, 2), &b);
/* check if there was an error */
if (err != IO_DONE) {
luaL_pushresult(&b);
io_pusherror(L, err);
lua_pushvalue(L, -2);
lua_pushnil(L);
lua_replace(L, -4);
} else {
luaL_pushresult(&b);
lua_pushnil(L);
lua_pushnil(L);
} }
/* push nil for each pattern after an error */
for ( ; arg <= top; arg++) lua_pushnil(L);
/* last return is an error code */
io_pusherror(L, err);
#ifdef LUASOCKET_DEBUG #ifdef LUASOCKET_DEBUG
/* push time elapsed during operation as the last return value */ /* push time elapsed during operation as the last return value */
lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0); lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0);
@ -150,21 +142,18 @@ int sendraw(p_buf buf, const char *data, size_t count, size_t *sent)
* Reads a fixed number of bytes (buffered) * Reads a fixed number of bytes (buffered)
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static static
int recvraw(lua_State *L, p_buf buf, size_t wanted) int recvraw(p_buf buf, size_t wanted, luaL_Buffer *b)
{ {
int err = IO_DONE; int err = IO_DONE;
size_t total = 0; size_t total = 0;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (total < wanted && (err == IO_DONE || err == IO_RETRY)) { while (total < wanted && (err == IO_DONE || err == IO_RETRY)) {
size_t count; const char *data; size_t count; const char *data;
err = buf_get(buf, &data, &count); err = buf_get(buf, &data, &count);
count = MIN(count, wanted - total); count = MIN(count, wanted - total);
luaL_addlstring(&b, data, count); luaL_addlstring(b, data, count);
buf_skip(buf, count); buf_skip(buf, count);
total += count; total += count;
} }
luaL_pushresult(&b);
return err; return err;
} }
@ -172,19 +161,17 @@ int recvraw(lua_State *L, p_buf buf, size_t wanted)
* Reads everything until the connection is closed (buffered) * Reads everything until the connection is closed (buffered)
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static static
int recvall(lua_State *L, p_buf buf) int recvall(p_buf buf, luaL_Buffer *b)
{ {
int err = IO_DONE; int err = IO_DONE;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (err == IO_DONE || err == IO_RETRY) { while (err == IO_DONE || err == IO_RETRY) {
const char *data; size_t count; const char *data; size_t count;
err = buf_get(buf, &data, &count); err = buf_get(buf, &data, &count);
luaL_addlstring(&b, data, count); luaL_addlstring(b, data, count);
buf_skip(buf, count); buf_skip(buf, count);
} }
luaL_pushresult(&b); if (err == IO_CLOSED) return IO_DONE;
return err; else return err;
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
@ -192,18 +179,16 @@ int recvall(lua_State *L, p_buf buf)
* are not returned by the function and are discarded from the buffer * are not returned by the function and are discarded from the buffer
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
static static
int recvline(lua_State *L, p_buf buf) int recvline(p_buf buf, luaL_Buffer *b)
{ {
int err = IO_DONE; int err = IO_DONE;
luaL_Buffer b;
luaL_buffinit(L, &b);
while (err == IO_DONE || err == IO_RETRY) { while (err == IO_DONE || err == IO_RETRY) {
size_t count, pos; const char *data; size_t count, pos; const char *data;
err = buf_get(buf, &data, &count); err = buf_get(buf, &data, &count);
pos = 0; pos = 0;
while (pos < count && data[pos] != '\n') { while (pos < count && data[pos] != '\n') {
/* we ignore all \r's */ /* we ignore all \r's */
if (data[pos] != '\r') luaL_putchar(&b, data[pos]); if (data[pos] != '\r') luaL_putchar(b, data[pos]);
pos++; pos++;
} }
if (pos < count) { /* found '\n' */ if (pos < count) { /* found '\n' */
@ -212,7 +197,6 @@ int recvline(lua_State *L, p_buf buf)
} else /* reached the end of the buffer */ } else /* reached the end of the buffer */
buf_skip(buf, pos); buf_skip(buf, pos);
} }
luaL_pushresult(&b);
return err; return err;
} }

View File

@ -39,321 +39,146 @@ local function third(a, b, c)
return c return c
end end
----------------------------------------------------------------------------- local function shift(a, b, c, d)
-- Tries to get a pattern from the server and closes socket on error return c, d
-- sock: socket connected to the server
-- pattern: pattern to receive
-- Returns
-- received pattern on success
-- nil followed by error message on error
-----------------------------------------------------------------------------
local function try_receiving(sock, pattern)
local data, err = sock:receive(pattern)
if not data then sock:close() end
--print(data)
return data, err
end end
----------------------------------------------------------------------------- -- resquest_p forward declaration
-- Tries to send data to the server and closes socket on error local request_p
-- sock: socket connected to the server
-- ...: data to send
-- Returns
-- err: error message if any, nil if successfull
-----------------------------------------------------------------------------
local function try_sending(sock, ...)
local sent, err = sock:send(unpack(arg))
if not sent then sock:close() end
--io.write(unpack(arg))
return err
end
-----------------------------------------------------------------------------
-- Receive server reply messages, parsing for status code
-- Input
-- sock: socket connected to the server
-- Returns
-- code: server status code or nil if error
-- line: full HTTP status line
-- err: error message if any
-----------------------------------------------------------------------------
local function receive_status(sock)
local line, err = try_receiving(sock)
if not err then
local code = third(string.find(line, "HTTP/%d*%.%d* (%d%d%d)"))
return tonumber(code), line
else return nil, nil, err end
end
-----------------------------------------------------------------------------
-- Receive and parse response header fields
-- Input
-- sock: socket connected to the server
-- headers: a table that might already contain headers
-- Returns
-- headers: a table with all headers fields in the form
-- {name_1 = "value_1", name_2 = "value_2" ... name_n = "value_n"}
-- all name_i are lowercase
-- nil and error message in case of error
-----------------------------------------------------------------------------
local function receive_headers(sock, headers) local function receive_headers(sock, headers)
local line, err local line, name, value
local name, value, _
headers = headers or {}
-- get first line -- get first line
line, err = try_receiving(sock) line = socket.try(sock:receive())
if err then return nil, err end
-- headers go until a blank line is found -- headers go until a blank line is found
while line ~= "" do while line ~= "" do
-- get field-name and value -- get field-name and value
_,_, name, value = string.find(line, "^(.-):%s*(.*)") name, value = shift(string.find(line, "^(.-):%s*(.*)"))
if not name or not value then assert(name and value, "malformed reponse headers")
sock:close()
return nil, "malformed reponse headers"
end
name = string.lower(name) name = string.lower(name)
-- get next line (value might be folded) -- get next line (value might be folded)
line, err = try_receiving(sock) line = socket.try(sock:receive())
if err then return nil, err end
-- unfold any folded values -- unfold any folded values
while not err and string.find(line, "^%s") do while string.find(line, "^%s") do
value = value .. line value = value .. line
line, err = try_receiving(sock) line = socket.try(sock:receive())
if err then return nil, err end
end end
-- save pair in table -- save pair in table
if headers[name] then headers[name] = headers[name] .. ", " .. value if headers[name] then headers[name] = headers[name] .. ", " .. value
else headers[name] = value end else headers[name] = value end
end end
return headers
end end
-----------------------------------------------------------------------------
-- Aborts a sink with an error message
-- Input
-- cb: callback function
-- err: error message to pass to callback
-- Returns
-- callback return or if nil err
-----------------------------------------------------------------------------
local function abort(cb, err) local function abort(cb, err)
local go, cb_err = cb(nil, err) local go, cb_err = cb(nil, err)
return cb_err or err error(cb_err or err)
end end
----------------------------------------------------------------------------- local function hand(cb, chunk)
-- Receives a chunked message body local go, cb_err = cb(chunk)
-- Input assert(go, cb_err or "aborted by callback")
-- sock: socket connected to the server end
-- headers: header set in which to include trailer headers
-- sink: response message body sink local function receive_body_bychunks(sock, sink)
-- Returns
-- nil if successfull or an error message in case of error
-----------------------------------------------------------------------------
local function receive_body_bychunks(sock, headers, sink)
local chunk, size, line, err, go
while 1 do while 1 do
-- get chunk size, skip extention -- get chunk size, skip extention
line, err = try_receiving(sock) local line, err = sock:receive()
if err then return abort(sink, err) end if err then abort(sink, err) end
size = tonumber(string.gsub(line, ";.*", ""), 16) local size = tonumber(string.gsub(line, ";.*", ""), 16)
if not size then return abort(sink, "invalid chunk size") end if not size then abort(sink, "invalid chunk size") end
-- was it the last chunk? -- was it the last chunk?
if size <= 0 then break end if size <= 0 then break end
-- get chunk -- get chunk
chunk, err = try_receiving(sock, size) local chunk, err = sock:receive(size)
if err then return abort(sink, err) end if err then abort(sink, err) end
-- pass chunk to callback -- pass chunk to callback
go, err = sink(chunk) hand(sink, chunk)
-- see if callback aborted
if not go then return err or "aborted by callback" end
-- skip CRLF on end of chunk -- skip CRLF on end of chunk
err = second(try_receiving(sock)) err = second(sock:receive())
if err then return abort(sink, err) end if err then abort(sink, err) end
end end
-- servers shouldn't send trailer headers, but who trusts them?
err = second(receive_headers(sock, headers))
if err then return abort(sink, err) end
-- let callback know we are done -- let callback know we are done
return second(sink(nil)) hand(sink, nil)
-- servers shouldn't send trailer headers, but who trusts them?
receive_headers(sock, {})
end end
-----------------------------------------------------------------------------
-- Receives a message body by content-length
-- Input
-- sock: socket connected to the server
-- length: message body length
-- sink: response message body sink
-- Returns
-- nil if successfull or an error message in case of error
-----------------------------------------------------------------------------
local function receive_body_bylength(sock, length, sink) local function receive_body_bylength(sock, length, sink)
while length > 0 do while length > 0 do
local size = math.min(BLOCKSIZE, length) local size = math.min(BLOCKSIZE, length)
local chunk, err = sock:receive(size) local chunk, err = sock:receive(size)
local go, cb_err = sink(chunk) if err then abort(sink, err) end
length = length - string.len(chunk) length = length - string.len(chunk)
-- see if callback aborted
if not go then return cb_err or "aborted by callback" end
-- see if there was an error -- see if there was an error
if err and length > 0 then return abort(sink, err) end hand(sink, chunk)
end end
return second(sink(nil)) -- let callback know we are done
hand(sink, nil)
end end
-----------------------------------------------------------------------------
-- Receives a message body until the conection is closed
-- Input
-- sock: socket connected to the server
-- sink: response message body sink
-- Returns
-- nil if successfull or an error message in case of error
-----------------------------------------------------------------------------
local function receive_body_untilclosed(sock, sink) local function receive_body_untilclosed(sock, sink)
while 1 do while true do
local chunk, err = sock:receive(BLOCKSIZE) local chunk, err, partial = sock:receive(BLOCKSIZE)
local go, cb_err = sink(chunk)
-- see if callback aborted
if not go then return cb_err or "aborted by callback" end
-- see if we are done -- see if we are done
if err == "closed" then return chunk and second(sink(nil)) end if err == "closed" then
hand(sink, partial)
break
end
hand(sink, chunk)
-- see if there was an error -- see if there was an error
if err then return abort(sink, err) end if err then abort(sink, err) end
end end
-- let callback know we are done
hand(sink, nil)
end end
----------------------------------------------------------------------------- local function receive_body(reqt, respt)
-- Receives the HTTP response body local sink = reqt.sink or ltn12.sink.null()
-- Input local headers = respt.headers
-- sock: socket connected to the server local sock = respt.tmp.sock
-- headers: response header fields
-- sink: response message body sink
-- Returns
-- nil if successfull or an error message in case of error
-----------------------------------------------------------------------------
local function receive_body(sock, headers, sink)
-- make sure sink is not fancy
sink = ltn12.sink.simplify(sink)
local te = headers["transfer-encoding"] local te = headers["transfer-encoding"]
if te and te ~= "identity" then if te and te ~= "identity" then
-- get by chunked transfer-coding of message body -- get by chunked transfer-coding of message body
return receive_body_bychunks(sock, headers, sink) receive_body_bychunks(sock, sink)
elseif tonumber(headers["content-length"]) then elseif tonumber(headers["content-length"]) then
-- get by content-length -- get by content-length
local length = tonumber(headers["content-length"]) local length = tonumber(headers["content-length"])
return receive_body_bylength(sock, length, sink) receive_body_bylength(sock, length, sink)
else else
-- get it all until connection closes -- get it all until connection closes
return receive_body_untilclosed(sock, sink) receive_body_untilclosed(sock, sink)
end end
end end
-----------------------------------------------------------------------------
-- Sends the HTTP request message body in chunks
-- Input
-- data: data connection
-- source: request message body source
-- Returns
-- nil if successfull, or an error message in case of error
-----------------------------------------------------------------------------
local function send_body_bychunks(data, source) local function send_body_bychunks(data, source)
while 1 do while true do
local chunk, cb_err = source() local chunk, err = source()
-- check if callback aborted assert(chunk or not err, err)
if not chunk then return cb_err or "aborted by callback" end if not chunk then break end
-- if we are done, send last-chunk socket.try(data:send(string.format("%X\r\n", string.len(chunk))))
if chunk == "" then return try_sending(data, "0\r\n\r\n") end socket.try(data:send(chunk, "\r\n"))
-- else send middle chunk
local err = try_sending(data,
string.format("%X\r\n", string.len(chunk)),
chunk,
"\r\n"
)
if err then return err end
end end
socket.try(data:send("0\r\n\r\n"))
end end
-----------------------------------------------------------------------------
-- Sends the HTTP request message body
-- Input
-- data: data connection
-- source: request message body source
-- Returns
-- nil if successfull, or an error message in case of error
-----------------------------------------------------------------------------
local function send_body(data, source) local function send_body(data, source)
while 1 do while true do
local chunk, cb_err = source() local chunk, err = source()
-- check if callback is done assert(chunk or not err, err)
if not chunk then return cb_err end if not chunk then break end
-- send data socket.try(data:send(chunk))
local err = try_sending(data, chunk)
if err then return err end
end end
end end
-----------------------------------------------------------------------------
-- Sends request headers
-- Input
-- sock: server socket
-- headers: table with headers to be sent
-- Returns
-- err: error message if any
-----------------------------------------------------------------------------
local function send_headers(sock, headers) local function send_headers(sock, headers)
local err
headers = headers or {}
-- send request headers -- send request headers
for i, v in headers do for i, v in pairs(headers) do
err = try_sending(sock, i .. ": " .. v .. "\r\n") socket.try(sock:send(i .. ": " .. v .. "\r\n"))
if err then return err end
end end
-- mark end of request headers -- mark end of request headers
return try_sending(sock, "\r\n") socket.try(sock:send("\r\n"))
end end
-----------------------------------------------------------------------------
-- Sends a HTTP request message through socket
-- Input
-- sock: socket connected to the server
-- method: request method to be used
-- uri: request uri
-- headers: request headers to be sent
-- source: request message body source
-- Returns
-- err: nil in case of success, error message otherwise
-----------------------------------------------------------------------------
local function send_request(sock, method, uri, headers, source)
local chunk, size, done, err
-- send request line
err = try_sending(sock, method .. " " .. uri .. " HTTP/1.1\r\n")
if err then return err end
if source and not headers["content-length"] then
headers["transfer-encoding"] = "chunked"
end
-- send request headers
err = send_headers(sock, headers)
if err then return err end
-- send request message body, if any
if source then
-- make sure source is not fancy
source = ltn12.source.simplify(source)
if headers["content-length"] then
return send_body(sock, source)
else
return send_body_bychunks(sock, source)
end
end
end
-----------------------------------------------------------------------------
-- Determines if we should read a message body from the server response
-- Input
-- reqt: a table with the original request information
-- respt: a table with the server response information
-- Returns
-- 1 if a message body should be processed, nil otherwise
-----------------------------------------------------------------------------
local function should_receive_body(reqt, respt) local function should_receive_body(reqt, respt)
if reqt.method == "HEAD" then return nil end if reqt.method == "HEAD" then return nil end
if respt.code == 204 or respt.code == 304 then return nil end if respt.code == 204 or respt.code == 304 then return nil end
@ -361,98 +186,143 @@ local function should_receive_body(reqt, respt)
return 1 return 1
end end
----------------------------------------------------------------------------- local function receive_status(reqt, respt)
-- Converts field names to lowercase and adds a few needed headers local sock = respt.tmp.sock
-- Input local status = socket.try(sock:receive())
-- headers: request header fields local code = third(string.find(status, "HTTP/%d*%.%d* (%d%d%d)"))
-- parsed: parsed request URL -- store results
-- Returns respt.code, respt.status = tonumber(code), status
-- lower: a table with the same headers, but with lowercase field names end
-----------------------------------------------------------------------------
local function fill_headers(headers, parsed) local function request_uri(reqt, respt)
local url
local parsed = respt.tmp.parsed
if not reqt.proxy then
url = {
path = parsed.path,
params = parsed.params,
query = parsed.query,
fragment = parsed.fragment
}
else url = respt.tmp.parsed end
return socket.url.build(url)
end
local function send_request(reqt, respt)
local uri = request_uri(reqt, respt)
local sock = respt.tmp.sock
local headers = respt.tmp.headers
-- send request line
socket.try(sock:send((reqt.method or "GET")
.. " " .. uri .. " HTTP/1.1\r\n"))
-- send request headers headeres
if reqt.source and not headers["content-length"] then
headers["transfer-encoding"] = "chunked"
end
send_headers(sock, headers)
-- send request message body, if any
if reqt.source then
if headers["content-length"] then send_body(sock, reqt.source)
else send_body_bychunks(sock, reqt.source) end
end
end
local function open(reqt, respt)
local parsed = respt.tmp.parsed
local proxy = reqt.proxy or PROXY
local host, port
if proxy then
local pproxy = socket.url.parse(proxy)
assert(pproxy.port and pproxy.host, "invalid proxy")
host, port = pproxy.host, pproxy.port
else
host, port = parsed.host, parsed.port
end
local sock = socket.try(socket.tcp())
-- store results
respt.tmp.sock = sock
sock:settimeout(reqt.timeout or TIMEOUT)
socket.try(sock:connect(host, port))
end
function adjust_headers(reqt, respt)
local lower = {} local lower = {}
headers = headers or {} local headers = reqt.headers or {}
-- set default headers -- set default headers
lower["user-agent"] = USERAGENT lower["user-agent"] = USERAGENT
-- override with user values -- override with user values
for i,v in headers do for i,v in headers do
lower[string.lower(i)] = v lower[string.lower(i)] = v
end end
lower["host"] = parsed.host lower["host"] = respt.tmp.parsed.host
-- this cannot be overriden -- this cannot be overriden
lower["connection"] = "close" lower["connection"] = "close"
return lower -- store results
respt.tmp.headers = lower
end end
----------------------------------------------------------------------------- function parse_url(reqt, respt)
-- Decides wether we should follow retry with authorization formation -- parse url with default fields
-- Input local parsed = socket.url.parse(reqt.url, {
-- reqt: a table with the original request information host = "",
-- parsed: parsed request URL port = PORT,
-- respt: a table with the server response information path ="/",
-- Returns scheme = "http"
-- 1 if we should retry, nil otherwise })
----------------------------------------------------------------------------- -- scheme has to be http
local function should_authorize(reqt, parsed, respt) if parsed.scheme ~= "http" then
error(string.format("unknown scheme '%s'", parsed.scheme))
end
-- explicit authentication info overrides that given by the URL
parsed.user = reqt.user or parsed.user
parsed.password = reqt.password or parsed.password
-- store results
respt.tmp.parsed = parsed
end
local function should_authorize(reqt, respt)
-- if there has been an authorization attempt, it must have failed -- if there has been an authorization attempt, it must have failed
if reqt.headers["authorization"] then return nil end if reqt.headers and reqt.headers["authorization"] then return nil end
-- if we don't have authorization information, we can't retry -- if we don't have authorization information, we can't retry
if parsed.user and parsed.password then return 1 return respt.tmp.parsed.user and respt.tmp.parsed.password
else return nil end
end end
----------------------------------------------------------------------------- local function clone(headers)
-- Returns the result of retrying a request with authorization information if not headers then return nil end
-- Input local copy = {}
-- reqt: a table with the original request information for i,v in pairs(headers) do
-- parsed: parsed request URL copy[i] = v
-- Returns end
-- respt: result of target authorization return copy
----------------------------------------------------------------------------- end
local function authorize(reqt, parsed)
reqt.headers["authorization"] = "Basic " .. local function authorize(reqt, respt)
local headers = clone(reqt.headers) or {}
local parsed = respt.tmp.parsed
headers["authorization"] = "Basic " ..
(mime.b64(parsed.user .. ":" .. parsed.password)) (mime.b64(parsed.user .. ":" .. parsed.password))
local autht = { local autht = {
nredirects = reqt.nredirects,
method = reqt.method, method = reqt.method,
url = reqt.url, url = reqt.url,
source = reqt.source, source = reqt.source,
sink = reqt.sink, sink = reqt.sink,
headers = reqt.headers, headers = headers,
timeout = reqt.timeout, timeout = reqt.timeout,
proxy = reqt.proxy, proxy = reqt.proxy,
} }
return request_cb(autht) request_p(autht, respt)
end end
-----------------------------------------------------------------------------
-- Decides wether we should follow a server redirect message
-- Input
-- reqt: a table with the original request information
-- respt: a table with the server response information
-- Returns
-- 1 if we should redirect, nil otherwise
-----------------------------------------------------------------------------
local function should_redirect(reqt, respt) local function should_redirect(reqt, respt)
return (reqt.redirect ~= false) and return (reqt.redirect ~= false) and
(respt.code == 301 or respt.code == 302) and (respt.code == 301 or respt.code == 302) and
(reqt.method == "GET" or reqt.method == "HEAD") and (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD")
not (reqt.nredirects and reqt.nredirects >= 5) and (not respt.tmp.nredirects or respt.tmp.nredirects < 5)
end end
-----------------------------------------------------------------------------
-- Returns the result of a request following a server redirect message.
-- Input
-- reqt: a table with the original request information
-- respt: response table of previous attempt
-- Returns
-- respt: result of target redirection
-----------------------------------------------------------------------------
local function redirect(reqt, respt) local function redirect(reqt, respt)
local nredirects = reqt.nredirects or 0 respt.tmp.nredirects = (respt.tmp.nredirects or 0) + 1
nredirects = nredirects + 1
local redirt = { local redirt = {
nredirects = nredirects,
method = reqt.method, method = reqt.method,
-- the RFC says the redirect URL has to be absolute, but some -- the RFC says the redirect URL has to be absolute, but some
-- servers do not respect that -- servers do not respect that
@ -463,243 +333,58 @@ local function redirect(reqt, respt)
timeout = reqt.timeout, timeout = reqt.timeout,
proxy = reqt.proxy proxy = reqt.proxy
} }
respt = request_cb(redirt) request_p(redirt, respt)
-- we pass the location header as a clue we tried to redirect -- we pass the location header as a clue we redirected
if respt.headers then respt.headers.location = redirt.url end if respt.headers then respt.headers.location = redirt.url end
return respt
end end
----------------------------------------------------------------------------- function request_p(reqt, respt)
-- Computes the request URI from the parsed request URL parse_url(reqt, respt)
-- If we are using a proxy, we use the absoluteURI format. adjust_headers(reqt, respt)
-- Otherwise, we use the abs_path format. open(reqt, respt)
-- Input send_request(reqt, respt)
-- parsed: parsed URL receive_status(reqt, respt)
-- Returns respt.headers = {}
-- uri: request URI for parsed URL receive_headers(respt.tmp.sock, respt.headers)
-----------------------------------------------------------------------------
local function request_uri(reqt, parsed)
local url
if not reqt.proxy then
url = {
path = parsed.path,
params = parsed.params,
query = parsed.query,
fragment = parsed.fragment
}
else url = parsed end
return socket.url.build(url)
end
-----------------------------------------------------------------------------
-- Builds a request table from a URL or request table
-- Input
-- url_or_request: target url or request table (a table with the fields:
-- url: the target URL
-- user: account user name
-- password: account password)
-- Returns
-- reqt: request table
-----------------------------------------------------------------------------
local function build_request(data)
local reqt = {}
if type(data) == "table" then
for i, v in data
do reqt[i] = v
end
else reqt.url = data end
return reqt
end
-----------------------------------------------------------------------------
-- Connects to a server, be it a proxy or not
-- Input
-- reqt: the request table
-- parsed: the parsed request url
-- Returns
-- sock: connection socket, or nil in case of error
-- err: error message
-----------------------------------------------------------------------------
local function try_connect(reqt, parsed)
reqt.proxy = reqt.proxy or PROXY
local host, port
if reqt.proxy then
local pproxy = socket.url.parse(reqt.proxy)
if not pproxy.port or not pproxy.host then
return nil, "invalid proxy"
end
host, port = pproxy.host, pproxy.port
else
host, port = parsed.host, parsed.port
end
local sock, ret, err
sock, err = socket.tcp()
if not sock then return nil, err end
sock:settimeout(reqt.timeout or TIMEOUT)
ret, err = sock:connect(host, port)
if not ret then
sock:close()
return nil, err
end
return sock
end
-----------------------------------------------------------------------------
-- Sends a HTTP request and retrieves the server reply using callbacks to
-- send the request body and receive the response body
-- Input
-- reqt: a table with the following fields
-- method: "GET", "PUT", "POST" etc (defaults to "GET")
-- url: target uniform resource locator
-- user, password: authentication information
-- headers: request headers to send, or nil if none
-- source: request message body source, or nil if none
-- sink: response message body sink
-- redirect: should we refrain from following a server redirect message?
-- Returns
-- respt: a table with the following fields:
-- headers: response header fields received, or nil if failed
-- status: server response status line, or nil if failed
-- code: server status code, or nil if failed
-- error: error message, or nil if successfull
-----------------------------------------------------------------------------
function request_cb(reqt)
local sock, ret
local parsed = socket.url.parse(reqt.url, {
host = "",
port = PORT,
path ="/",
scheme = "http"
})
local respt = {}
if parsed.scheme ~= "http" then
respt.error = string.format("unknown scheme '%s'", parsed.scheme)
return respt
end
-- explicit authentication info overrides that given by the URL
parsed.user = reqt.user or parsed.user
parsed.password = reqt.password or parsed.password
-- default method
reqt.method = reqt.method or "GET"
-- fill default headers
reqt.headers = fill_headers(reqt.headers, parsed)
-- try to connect to server
sock, respt.error = try_connect(reqt, parsed)
if not sock then return respt end
-- send request message
respt.error = send_request(sock, reqt.method,
request_uri(reqt, parsed), reqt.headers, reqt.source)
if respt.error then
sock:close()
return respt
end
-- get server response message
respt.code, respt.status, respt.error = receive_status(sock)
if respt.error then return respt end
-- deal with continue 100
-- servers should not send them, but some do!
if respt.code == 100 then
respt.headers, respt.error = receive_headers(sock, {})
if respt.error then return respt end
respt.code, respt.status, respt.error = receive_status(sock)
if respt.error then return respt end
end
-- receive all headers
respt.headers, respt.error = receive_headers(sock, {})
if respt.error then return respt end
-- decide what to do based on request and response parameters
if should_redirect(reqt, respt) then if should_redirect(reqt, respt) then
-- drop the body respt.tmp.sock:close()
receive_body(sock, respt.headers, ltn12.sink.null()) redirect(reqt, respt)
-- we are done with this connection elseif should_authorize(reqt, respt) then
sock:close() respt.tmp.sock:close()
return redirect(reqt, respt) authorize(reqt, respt)
elseif should_authorize(reqt, parsed, respt) then
-- drop the body
receive_body(sock, respt.headers, ltn12.sink.null())
-- we are done with this connection
sock:close()
return authorize(reqt, parsed, respt)
elseif should_receive_body(reqt, respt) then elseif should_receive_body(reqt, respt) then
respt.error = receive_body(sock, respt.headers, reqt.sink) receive_body(reqt, respt)
if respt.error then return respt end
sock:close()
return respt
end end
sock:close()
return respt
end end
-----------------------------------------------------------------------------
-- Sends a HTTP request and retrieves the server reply
-- Input
-- reqt: a table with the following fields
-- method: "GET", "PUT", "POST" etc (defaults to "GET")
-- url: request URL, i.e. the document to be retrieved
-- user, password: authentication information
-- headers: request header fields, or nil if none
-- body: request message body as a string, or nil if none
-- redirect: should we refrain from following a server redirect message?
-- Returns
-- respt: a table with the following fields:
-- body: response message body, or nil if failed
-- headers: response header fields, or nil if failed
-- status: server response status line, or nil if failed
-- code: server response status code, or nil if failed
-- error: error message if any
-----------------------------------------------------------------------------
function request(reqt) function request(reqt)
reqt.source = reqt.body and ltn12.source.string(reqt.body) local respt = { tmp = {} }
local t = {} local s, e = pcall(request_p, reqt, respt)
reqt.sink = ltn12.sink.table(t) if not s then respt.error = e end
local respt = request_cb(reqt) if respt.tmp.sock then respt.tmp.sock:close() end
if table.getn(t) > 0 then respt.body = table.concat(t) end respt.tmp = nil
return respt return respt
end end
----------------------------------------------------------------------------- function get(url)
-- Retrieves a URL by the method "GET" local t = {}
-- Input respt = request {
-- url_or_request: target url or request table (a table with the fields: url = url,
-- url: the target URL sink = ltn12.sink.table(t)
-- user: account user name }
-- password: account password) return table.getn(t) > 0 and table.concat(t), respt.headers,
-- Returns respt.code, respt.error
-- body: response message body, or nil if failed
-- headers: response header fields received, or nil if failed
-- code: server response status code, or nil if failed
-- error: error message if any
-----------------------------------------------------------------------------
function get(url_or_request)
local reqt = build_request(url_or_request)
reqt.method = "GET"
local respt = request(reqt)
return respt.body, respt.headers, respt.code, respt.error
end end
----------------------------------------------------------------------------- function post(url, body)
-- Retrieves a URL by the method "POST" local t = {}
-- Input respt = request {
-- url_or_request: target url or request table (a table with the fields: url = url,
-- url: the target URL method = "POST",
-- body: request message body source = ltn12.source.string(body),
-- user: account user name sink = ltn12.sink.table(t),
-- password: account password) headers = { ["content-length"] = string.len(body) }
-- body: request message body, or nil if none }
-- Returns return table.getn(t) > 0 and table.concat(t),
-- body: response message body, or nil if failed respt.headers, respt.code, respt.error
-- headers: response header fields received, or nil if failed
-- code: server response status code, or nil if failed
-- error: error message, or nil if successfull
-----------------------------------------------------------------------------
function post(url_or_request, body)
local reqt = build_request(url_or_request)
reqt.method = "POST"
reqt.body = reqt.body or body
reqt.headers = reqt.headers or
{ ["content-length"] = string.len(reqt.body) }
local respt = request(reqt)
return respt.body, respt.headers, respt.code, respt.error
end end
return socket.http

View File

@ -171,9 +171,8 @@ function sink.file(handle, io_err)
return function(chunk, err) return function(chunk, err)
if not chunk then if not chunk then
handle:close() handle:close()
return nil, err return 1
end else return handle:write(chunk) end
return handle:write(chunk)
end end
else return sink.error(io_err or "unable to open file") end else return sink.error(io_err or "unable to open file") end
end end

View File

@ -619,28 +619,27 @@ static int mime_global_qpwrp(lua_State *L)
* end of line markers each, but \r\n, \n\r etc will only issue *one* * end of line markers each, but \r\n, \n\r etc will only issue *one*
* marker. This covers Mac OS, Mac OS X, VMS, Unix and DOS, as well as * marker. This covers Mac OS, Mac OS X, VMS, Unix and DOS, as well as
* probably other more obscure conventions. * probably other more obscure conventions.
*
* c is the current character being processed
* last is the previous character
\*-------------------------------------------------------------------------*/ \*-------------------------------------------------------------------------*/
#define eolcandidate(c) (c == CR || c == LF) #define eolcandidate(c) (c == CR || c == LF)
static size_t eolprocess(int c, int ctx, const char *marker, static int eolprocess(int c, int last, const char *marker,
luaL_Buffer *buffer) luaL_Buffer *buffer)
{ {
if (eolcandidate(ctx)) { if (eolcandidate(c)) {
luaL_addstring(buffer, marker); if (eolcandidate(last)) {
if (eolcandidate(c)) { if (c == last) luaL_addstring(buffer, marker);
if (c == ctx)
luaL_addstring(buffer, marker);
return 0; return 0;
} else { } else {
luaL_putchar(buffer, c); luaL_addstring(buffer, marker);
return 0; return c;
} }
} else { } else {
if (!eolcandidate(c)) { luaL_putchar(buffer, c);
luaL_putchar(buffer, c); return 0;
return 0;
} else
return c;
} }
} }
/*-------------------------------------------------------------------------*\ /*-------------------------------------------------------------------------*\
@ -661,8 +660,7 @@ static int mime_global_eol(lua_State *L)
luaL_buffinit(L, &buffer); luaL_buffinit(L, &buffer);
/* if the last character was a candidate, we output a new line */ /* if the last character was a candidate, we output a new line */
if (!input) { if (!input) {
if (eolcandidate(ctx)) lua_pushstring(L, marker); lua_pushnil(L);
else lua_pushnil(L);
lua_pushnumber(L, 0); lua_pushnumber(L, 0);
return 2; return 2;
} }

View File

@ -8,7 +8,7 @@ dofile("noglobals.lua")
local host, proxy, request, response, index_file local host, proxy, request, response, index_file
local ignore, expect, index, prefix, cgiprefix, index_crlf local ignore, expect, index, prefix, cgiprefix, index_crlf
socket.http.TIMEOUT = 5 socket.http.TIMEOUT = 10
local t = socket.time() local t = socket.time()
@ -49,7 +49,9 @@ local check_result = function(response, expect, ignore)
for i,v in response do for i,v in response do
if not ignore[i] then if not ignore[i] then
if v ~= expect[i] then if v ~= expect[i] then
print(string.sub(tostring(v), 1, 70)) local f = io.open("err", "w")
f:write(tostring(v), "\n\n versus\n\n", tostring(expect[i]))
f:close()
fail(i .. " differs!") fail(i .. " differs!")
end end
end end
@ -57,8 +59,10 @@ local check_result = function(response, expect, ignore)
for i,v in expect do for i,v in expect do
if not ignore[i] then if not ignore[i] then
if v ~= response[i] then if v ~= response[i] then
local f = io.open("err", "w")
f:write(tostring(response[i]), "\n\n versus\n\n", tostring(v))
v = string.sub(type(v) == "string" and v or "", 1, 70) v = string.sub(type(v) == "string" and v or "", 1, 70)
print(string.sub(tostring(v), 1, 70)) f:close()
fail(i .. " differs!") fail(i .. " differs!")
end end
end end
@ -67,12 +71,14 @@ local check_result = function(response, expect, ignore)
end end
local check_request = function(request, expect, ignore) local check_request = function(request, expect, ignore)
local t
if not request.sink then
request.sink, t = ltn12.sink.table(t)
end
request.source = request.source or
(request.body and ltn12.source.string(request.body))
local response = socket.http.request(request) local response = socket.http.request(request)
check_result(response, expect, ignore) if t and table.getn(t) > 0 then response.body = table.concat(t) end
end
local check_request_cb = function(request, expect, ignore)
local response = socket.http.request_cb(request)
check_result(response, expect, ignore) check_result(response, expect, ignore)
end end
@ -183,7 +189,7 @@ ignore = {
status = 1, status = 1,
headers = 1 headers = 1
} }
check_request_cb(request, expect, ignore) check_request(request, expect, ignore)
back = readfile(index_file .. "-back") back = readfile(index_file .. "-back")
check(back == index) check(back == index)
os.remove(index_file .. "-back") os.remove(index_file .. "-back")
@ -225,19 +231,11 @@ ignore = {
status = 1, status = 1,
headers = 1 headers = 1
} }
check_request_cb(request, expect, ignore) check_request(request, expect, ignore)
back = readfile(index_file .. "-back") back = readfile(index_file .. "-back")
check(back == index) check(back == index)
os.remove(index_file .. "-back") os.remove(index_file .. "-back")
------------------------------------------------------------------------
io.write("testing simple post function with table args: ")
back = socket.http.post {
url = "http://" .. host .. cgiprefix .. "/cat",
body = index
}
check(back == index)
------------------------------------------------------------------------ ------------------------------------------------------------------------
io.write("testing http redirection: ") io.write("testing http redirection: ")
request = { request = {
@ -438,15 +436,6 @@ io.write("testing simple get function: ")
body = socket.http.get("http://" .. host .. prefix .. "/index.html") body = socket.http.get("http://" .. host .. prefix .. "/index.html")
check(body == index) check(body == index)
------------------------------------------------------------------------
io.write("testing simple get function with table args: ")
body = socket.http.get {
url = "http://really:wrong@" .. host .. prefix .. "/auth/index.html",
user = "luasocket",
password = "password"
}
check(body == index)
------------------------------------------------------------------------ ------------------------------------------------------------------------
io.write("testing HEAD method: ") io.write("testing HEAD method: ")
socket.http.TIMEOUT = 1 socket.http.TIMEOUT = 1

View File

@ -17,14 +17,12 @@ function warn(...)
io.stderr:write("WARNING: ", s, "\n") io.stderr:write("WARNING: ", s, "\n")
end end
pad = string.rep(" ", 8192)
function remote(...) function remote(...)
local s = string.format(unpack(arg)) local s = string.format(unpack(arg))
s = string.gsub(s, "\n", ";") s = string.gsub(s, "\n", ";")
s = string.gsub(s, "%s+", " ") s = string.gsub(s, "%s+", " ")
s = string.gsub(s, "^%s*", "") s = string.gsub(s, "^%s*", "")
control:send(pad, s, "\n") control:send(s, "\n")
control:receive() control:receive()
end end
@ -122,7 +120,13 @@ remote (string.format("str = data:receive(%d)",
sent, err = data:send(p1, p2, p3, p4) sent, err = data:send(p1, p2, p3, p4)
if err then fail(err) end if err then fail(err) end
remote "data:send(str); data:close()" remote "data:send(str); data:close()"
bp1, bp2, bp3, bp4, err = data:receive("*l", "*l", string.len(p3), "*a") bp1, err = data:receive()
if err then fail(err) end
bp2, err = data:receive()
if err then fail(err) end
bp3, err = data:receive(string.len(p3))
if err then fail(err) end
bp4, err = data:receive("*a")
if err then fail(err) end if err then fail(err) end
if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then
pass("patterns match") pass("patterns match")
@ -186,7 +190,7 @@ end
------------------------------------------------------------------------ ------------------------------------------------------------------------
function test_totaltimeoutreceive(len, tm, sl) function test_totaltimeoutreceive(len, tm, sl)
reconnect() reconnect()
local str, err, total local str, err, partial
pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl) pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl)
remote (string.format ([[ remote (string.format ([[
data:settimeout(%d) data:settimeout(%d)
@ -198,9 +202,9 @@ function test_totaltimeoutreceive(len, tm, sl)
data:send(str) data:send(str)
]], 2*tm, len, sl, sl)) ]], 2*tm, len, sl, sl))
data:settimeout(tm, "total") data:settimeout(tm, "total")
str, err, elapsed = data:receive(2*len) str, err, partial, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "total", check_timeout(tm, sl, elapsed, err, "receive", "total",
string.len(str) == 2*len) string.len(str or partial) == 2*len)
end end
------------------------------------------------------------------------ ------------------------------------------------------------------------
@ -226,7 +230,7 @@ end
------------------------------------------------------------------------ ------------------------------------------------------------------------
function test_blockingtimeoutreceive(len, tm, sl) function test_blockingtimeoutreceive(len, tm, sl)
reconnect() reconnect()
local str, err, total local str, err, partial
pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl)
remote (string.format ([[ remote (string.format ([[
data:settimeout(%d) data:settimeout(%d)
@ -238,9 +242,9 @@ function test_blockingtimeoutreceive(len, tm, sl)
data:send(str) data:send(str)
]], 2*tm, len, sl, sl)) ]], 2*tm, len, sl, sl))
data:settimeout(tm) data:settimeout(tm)
str, err, elapsed = data:receive(2*len) str, err, partial, elapsed = data:receive(2*len)
check_timeout(tm, sl, elapsed, err, "receive", "blocking", check_timeout(tm, sl, elapsed, err, "receive", "blocking",
string.len(str) == 2*len) string.len(str or partial) == 2*len)
end end
------------------------------------------------------------------------ ------------------------------------------------------------------------
@ -298,7 +302,7 @@ end
------------------------------------------------------------------------ ------------------------------------------------------------------------
function test_closed() function test_closed()
local back, err local back, partial, err
local str = 'little string' local str = 'little string'
reconnect() reconnect()
pass("trying read detection") pass("trying read detection")
@ -308,10 +312,10 @@ function test_closed()
data = nil data = nil
]], str)) ]], str))
-- try to get a line -- try to get a line
back, err = data:receive() back, err, partial = data:receive()
if not err then fail("shold have gotten 'closed'.") if not err then fail("should have gotten 'closed'.")
elseif err ~= "closed" then fail("got '"..err.."' instead of 'closed'.") elseif err ~= "closed" then fail("got '"..err.."' instead of 'closed'.")
elseif str ~= back then fail("didn't receive partial result.") elseif str ~= partial then fail("didn't receive partial result.")
else pass("graceful 'closed' received") end else pass("graceful 'closed' received") end
reconnect() reconnect()
pass("trying write detection") pass("trying write detection")
@ -456,7 +460,6 @@ test_methods(socket.udp(), {
"setpeername", "setpeername",
"setsockname", "setsockname",
"settimeout", "settimeout",
"shutdown",
}) })
test("select function") test("select function")
@ -481,6 +484,7 @@ accept_timeout()
accept_errors() accept_errors()
test("mixed patterns") test("mixed patterns")
test_mixed(1) test_mixed(1)
test_mixed(17) test_mixed(17)

View File

@ -6,7 +6,7 @@ mesgt = {
body = { body = {
preamble = "Some attatched stuff", preamble = "Some attatched stuff",
[1] = { [1] = {
body = "Testing stuffing.\r\n.\r\nGot you.\r\n.Hehehe.\r\n" body = mime.eol(0, "Testing stuffing.\n.\nGot you.\n.Hehehe.\n")
}, },
[2] = { [2] = {
headers = { headers = {
@ -29,7 +29,7 @@ mesgt = {
["content-transfer-encoding"] = "QUOTED-PRINTABLE" ["content-transfer-encoding"] = "QUOTED-PRINTABLE"
}, },
body = ltn12.source.chain( body = ltn12.source.chain(
ltn12.source.file(io.open("message.lua", "rb")), ltn12.source.file(io.open("testmesg.lua", "rb")),
ltn12.filter.chain( ltn12.filter.chain(
mime.normalize(), mime.normalize(),
mime.encode("quoted-printable"), mime.encode("quoted-printable"),
@ -46,8 +46,8 @@ mesgt = {
-- ltn12.pump(source, sink) -- ltn12.pump(source, sink)
print(socket.smtp.send { print(socket.smtp.send {
rcpt = {"<db@werx4.com>", "<diego@cs.princeton.edu>"}, rcpt = "<diego@cs.princeton.edu>",
from = "<diego@cs.princeton.edu>", from = "<diego@cs.princeton.edu>",
source = socket.smtp.message(mesgt), source = socket.smtp.message(mesgt),
server = "smtp.princeton.edu" server = "mail.cs.princeton.edu"
}) })

View File

@ -22,6 +22,8 @@ while 1 do
print("server: closing connection...") print("server: closing connection...")
break break
end end
print(command);
(loadstring(command))() (loadstring(command))()
end end
end end