From 4919a83d2271a9e43b83c7d488e3f94c850681e3 Mon Sep 17 00:00:00 2001 From: Diego Nehab Date: Sun, 21 Mar 2004 07:50:15 +0000 Subject: [PATCH] Changed receive function. Now uniform with all other functions. Returns nil on error, return partial result in the end. http.lua rewritten. --- TODO | 4 + src/buffer.c | 84 +++-- src/http.lua | 759 ++++++++++++++-------------------------------- src/ltn12.lua | 5 +- src/mime.c | 28 +- test/httptest.lua | 43 +-- test/testclnt.lua | 34 ++- test/testmesg.lua | 8 +- test/testsrvr.lua | 2 + 9 files changed, 316 insertions(+), 651 deletions(-) diff --git a/TODO b/TODO index 3f1c71b..110a78c 100644 --- a/TODO +++ b/TODO @@ -19,6 +19,10 @@ * Separar as classes em arquivos * Retorno de sendto em datagram sockets pode ser refused +change mime.eol to output marker on detection of first candidate, instead +of on the second. that way it works in one pass for strings that end with +one candidate. + colocar um userdata com gc metamethod pra chamar sock_close (WSAClose); sources ans sinks are always simple in http and ftp and smtp unify backbone of smtp and ftp diff --git a/src/buffer.c b/src/buffer.c index d9ba779..4bcfa1a 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -13,9 +13,9 @@ /*=========================================================================*\ * Internal function prototypes \*=========================================================================*/ -static int recvraw(lua_State *L, p_buf buf, size_t wanted); -static int recvline(lua_State *L, p_buf buf); -static int recvall(lua_State *L, p_buf buf); +static int recvraw(p_buf buf, size_t wanted, luaL_Buffer *b); +static int recvline(p_buf buf, luaL_Buffer *b); +static int recvall(p_buf buf, luaL_Buffer *b); static int buf_get(p_buf buf, const char **data, size_t *count); static void buf_skip(p_buf buf, size_t count); static int sendraw(p_buf buf, const char *data, size_t count, size_t *sent); @@ -73,42 +73,34 @@ int buf_meth_send(lua_State *L, p_buf buf) \*-------------------------------------------------------------------------*/ int buf_meth_receive(lua_State *L, p_buf buf) { - int top = lua_gettop(L); - int arg, err = IO_DONE; + int err = IO_DONE, top = lua_gettop(L); p_tm tm = buf->tm; + luaL_Buffer b; + luaL_buffinit(L, &b); tm_markstart(tm); - /* push default pattern if need be */ - if (top < 2) { - lua_pushstring(L, "*l"); - top++; - } - /* make sure we have enough stack space for all returns */ - luaL_checkstack(L, top+LUA_MINSTACK, "too many arguments"); /* receive all patterns */ - for (arg = 2; arg <= top && err == IO_DONE; arg++) { - if (!lua_isnumber(L, arg)) { - static const char *patternnames[] = {"*l", "*a", NULL}; - const char *pattern = lua_isnil(L, arg) ? - "*l" : luaL_checkstring(L, arg); - /* get next pattern */ - switch (luaL_findstring(pattern, patternnames)) { - case 0: /* line pattern */ - err = recvline(L, buf); break; - case 1: /* until closed pattern */ - err = recvall(L, buf); - if (err == IO_CLOSED) err = IO_DONE; - break; - default: /* else it is an error */ - luaL_argcheck(L, 0, arg, "invalid receive pattern"); - break; - } + if (!lua_isnumber(L, 2)) { + static const char *patternnames[] = {"*l", "*a", NULL}; + const char *pattern = luaL_optstring(L, 2, "*l"); + /* get next pattern */ + int p = luaL_findstring(pattern, patternnames); + if (p == 0) err = recvline(buf, &b); + else if (p == 1) err = recvall(buf, &b); + else luaL_argcheck(L, 0, 2, "invalid receive pattern"); /* get a fixed number of bytes */ - } else err = recvraw(L, buf, (size_t) lua_tonumber(L, arg)); + } else err = recvraw(buf, (size_t) lua_tonumber(L, 2), &b); + /* check if there was an error */ + if (err != IO_DONE) { + luaL_pushresult(&b); + io_pusherror(L, err); + lua_pushvalue(L, -2); + lua_pushnil(L); + lua_replace(L, -4); + } else { + luaL_pushresult(&b); + lua_pushnil(L); + lua_pushnil(L); } - /* push nil for each pattern after an error */ - for ( ; arg <= top; arg++) lua_pushnil(L); - /* last return is an error code */ - io_pusherror(L, err); #ifdef LUASOCKET_DEBUG /* push time elapsed during operation as the last return value */ lua_pushnumber(L, (tm_gettime() - tm_getstart(tm))/1000.0); @@ -150,21 +142,18 @@ int sendraw(p_buf buf, const char *data, size_t count, size_t *sent) * Reads a fixed number of bytes (buffered) \*-------------------------------------------------------------------------*/ static -int recvraw(lua_State *L, p_buf buf, size_t wanted) +int recvraw(p_buf buf, size_t wanted, luaL_Buffer *b) { int err = IO_DONE; size_t total = 0; - luaL_Buffer b; - luaL_buffinit(L, &b); while (total < wanted && (err == IO_DONE || err == IO_RETRY)) { size_t count; const char *data; err = buf_get(buf, &data, &count); count = MIN(count, wanted - total); - luaL_addlstring(&b, data, count); + luaL_addlstring(b, data, count); buf_skip(buf, count); total += count; } - luaL_pushresult(&b); return err; } @@ -172,19 +161,17 @@ int recvraw(lua_State *L, p_buf buf, size_t wanted) * Reads everything until the connection is closed (buffered) \*-------------------------------------------------------------------------*/ static -int recvall(lua_State *L, p_buf buf) +int recvall(p_buf buf, luaL_Buffer *b) { int err = IO_DONE; - luaL_Buffer b; - luaL_buffinit(L, &b); while (err == IO_DONE || err == IO_RETRY) { const char *data; size_t count; err = buf_get(buf, &data, &count); - luaL_addlstring(&b, data, count); + luaL_addlstring(b, data, count); buf_skip(buf, count); } - luaL_pushresult(&b); - return err; + if (err == IO_CLOSED) return IO_DONE; + else return err; } /*-------------------------------------------------------------------------*\ @@ -192,18 +179,16 @@ int recvall(lua_State *L, p_buf buf) * are not returned by the function and are discarded from the buffer \*-------------------------------------------------------------------------*/ static -int recvline(lua_State *L, p_buf buf) +int recvline(p_buf buf, luaL_Buffer *b) { int err = IO_DONE; - luaL_Buffer b; - luaL_buffinit(L, &b); while (err == IO_DONE || err == IO_RETRY) { size_t count, pos; const char *data; err = buf_get(buf, &data, &count); pos = 0; while (pos < count && data[pos] != '\n') { /* we ignore all \r's */ - if (data[pos] != '\r') luaL_putchar(&b, data[pos]); + if (data[pos] != '\r') luaL_putchar(b, data[pos]); pos++; } if (pos < count) { /* found '\n' */ @@ -212,7 +197,6 @@ int recvline(lua_State *L, p_buf buf) } else /* reached the end of the buffer */ buf_skip(buf, pos); } - luaL_pushresult(&b); return err; } diff --git a/src/http.lua b/src/http.lua index 629bf65..a10cf50 100644 --- a/src/http.lua +++ b/src/http.lua @@ -39,321 +39,146 @@ local function third(a, b, c) return c end ------------------------------------------------------------------------------ --- Tries to get a pattern from the server and closes socket on error --- sock: socket connected to the server --- pattern: pattern to receive --- Returns --- received pattern on success --- nil followed by error message on error ------------------------------------------------------------------------------ -local function try_receiving(sock, pattern) - local data, err = sock:receive(pattern) - if not data then sock:close() end ---print(data) - return data, err +local function shift(a, b, c, d) + return c, d end ------------------------------------------------------------------------------ --- Tries to send data to the server and closes socket on error --- sock: socket connected to the server --- ...: data to send --- Returns --- err: error message if any, nil if successfull ------------------------------------------------------------------------------ -local function try_sending(sock, ...) - local sent, err = sock:send(unpack(arg)) - if not sent then sock:close() end ---io.write(unpack(arg)) - return err -end +-- resquest_p forward declaration +local request_p ------------------------------------------------------------------------------ --- Receive server reply messages, parsing for status code --- Input --- sock: socket connected to the server --- Returns --- code: server status code or nil if error --- line: full HTTP status line --- err: error message if any ------------------------------------------------------------------------------ -local function receive_status(sock) - local line, err = try_receiving(sock) - if not err then - local code = third(string.find(line, "HTTP/%d*%.%d* (%d%d%d)")) - return tonumber(code), line - else return nil, nil, err end -end - ------------------------------------------------------------------------------ --- Receive and parse response header fields --- Input --- sock: socket connected to the server --- headers: a table that might already contain headers --- Returns --- headers: a table with all headers fields in the form --- {name_1 = "value_1", name_2 = "value_2" ... name_n = "value_n"} --- all name_i are lowercase --- nil and error message in case of error ------------------------------------------------------------------------------ local function receive_headers(sock, headers) - local line, err - local name, value, _ - headers = headers or {} + local line, name, value -- get first line - line, err = try_receiving(sock) - if err then return nil, err end + line = socket.try(sock:receive()) -- headers go until a blank line is found while line ~= "" do -- get field-name and value - _,_, name, value = string.find(line, "^(.-):%s*(.*)") - if not name or not value then - sock:close() - return nil, "malformed reponse headers" - end + name, value = shift(string.find(line, "^(.-):%s*(.*)")) + assert(name and value, "malformed reponse headers") name = string.lower(name) -- get next line (value might be folded) - line, err = try_receiving(sock) - if err then return nil, err end + line = socket.try(sock:receive()) -- unfold any folded values - while not err and string.find(line, "^%s") do + while string.find(line, "^%s") do value = value .. line - line, err = try_receiving(sock) - if err then return nil, err end + line = socket.try(sock:receive()) end -- save pair in table if headers[name] then headers[name] = headers[name] .. ", " .. value else headers[name] = value end end - return headers end ------------------------------------------------------------------------------ --- Aborts a sink with an error message --- Input --- cb: callback function --- err: error message to pass to callback --- Returns --- callback return or if nil err ------------------------------------------------------------------------------ local function abort(cb, err) local go, cb_err = cb(nil, err) - return cb_err or err + error(cb_err or err) end ------------------------------------------------------------------------------ --- Receives a chunked message body --- Input --- sock: socket connected to the server --- headers: header set in which to include trailer headers --- sink: response message body sink --- Returns --- nil if successfull or an error message in case of error ------------------------------------------------------------------------------ -local function receive_body_bychunks(sock, headers, sink) - local chunk, size, line, err, go +local function hand(cb, chunk) + local go, cb_err = cb(chunk) + assert(go, cb_err or "aborted by callback") +end + +local function receive_body_bychunks(sock, sink) while 1 do -- get chunk size, skip extention - line, err = try_receiving(sock) - if err then return abort(sink, err) end - size = tonumber(string.gsub(line, ";.*", ""), 16) - if not size then return abort(sink, "invalid chunk size") end + local line, err = sock:receive() + if err then abort(sink, err) end + local size = tonumber(string.gsub(line, ";.*", ""), 16) + if not size then abort(sink, "invalid chunk size") end -- was it the last chunk? if size <= 0 then break end -- get chunk - chunk, err = try_receiving(sock, size) - if err then return abort(sink, err) end + local chunk, err = sock:receive(size) + if err then abort(sink, err) end -- pass chunk to callback - go, err = sink(chunk) - -- see if callback aborted - if not go then return err or "aborted by callback" end + hand(sink, chunk) -- skip CRLF on end of chunk - err = second(try_receiving(sock)) - if err then return abort(sink, err) end + err = second(sock:receive()) + if err then abort(sink, err) end end - -- servers shouldn't send trailer headers, but who trusts them? - err = second(receive_headers(sock, headers)) - if err then return abort(sink, err) end -- let callback know we are done - return second(sink(nil)) + hand(sink, nil) + -- servers shouldn't send trailer headers, but who trusts them? + receive_headers(sock, {}) end ------------------------------------------------------------------------------ --- Receives a message body by content-length --- Input --- sock: socket connected to the server --- length: message body length --- sink: response message body sink --- Returns --- nil if successfull or an error message in case of error ------------------------------------------------------------------------------ local function receive_body_bylength(sock, length, sink) while length > 0 do local size = math.min(BLOCKSIZE, length) local chunk, err = sock:receive(size) - local go, cb_err = sink(chunk) + if err then abort(sink, err) end length = length - string.len(chunk) - -- see if callback aborted - if not go then return cb_err or "aborted by callback" end -- see if there was an error - if err and length > 0 then return abort(sink, err) end + hand(sink, chunk) end - return second(sink(nil)) + -- let callback know we are done + hand(sink, nil) end ------------------------------------------------------------------------------ --- Receives a message body until the conection is closed --- Input --- sock: socket connected to the server --- sink: response message body sink --- Returns --- nil if successfull or an error message in case of error ------------------------------------------------------------------------------ local function receive_body_untilclosed(sock, sink) - while 1 do - local chunk, err = sock:receive(BLOCKSIZE) - local go, cb_err = sink(chunk) - -- see if callback aborted - if not go then return cb_err or "aborted by callback" end + while true do + local chunk, err, partial = sock:receive(BLOCKSIZE) -- see if we are done - if err == "closed" then return chunk and second(sink(nil)) end + if err == "closed" then + hand(sink, partial) + break + end + hand(sink, chunk) -- see if there was an error - if err then return abort(sink, err) end + if err then abort(sink, err) end end + -- let callback know we are done + hand(sink, nil) end ------------------------------------------------------------------------------ --- Receives the HTTP response body --- Input --- sock: socket connected to the server --- headers: response header fields --- sink: response message body sink --- Returns --- nil if successfull or an error message in case of error ------------------------------------------------------------------------------ -local function receive_body(sock, headers, sink) - -- make sure sink is not fancy - sink = ltn12.sink.simplify(sink) +local function receive_body(reqt, respt) + local sink = reqt.sink or ltn12.sink.null() + local headers = respt.headers + local sock = respt.tmp.sock local te = headers["transfer-encoding"] if te and te ~= "identity" then -- get by chunked transfer-coding of message body - return receive_body_bychunks(sock, headers, sink) + receive_body_bychunks(sock, sink) elseif tonumber(headers["content-length"]) then -- get by content-length local length = tonumber(headers["content-length"]) - return receive_body_bylength(sock, length, sink) + receive_body_bylength(sock, length, sink) else -- get it all until connection closes - return receive_body_untilclosed(sock, sink) + receive_body_untilclosed(sock, sink) end end ------------------------------------------------------------------------------ --- Sends the HTTP request message body in chunks --- Input --- data: data connection --- source: request message body source --- Returns --- nil if successfull, or an error message in case of error ------------------------------------------------------------------------------ local function send_body_bychunks(data, source) - while 1 do - local chunk, cb_err = source() - -- check if callback aborted - if not chunk then return cb_err or "aborted by callback" end - -- if we are done, send last-chunk - if chunk == "" then return try_sending(data, "0\r\n\r\n") end - -- else send middle chunk - local err = try_sending(data, - string.format("%X\r\n", string.len(chunk)), - chunk, - "\r\n" - ) - if err then return err end + while true do + local chunk, err = source() + assert(chunk or not err, err) + if not chunk then break end + socket.try(data:send(string.format("%X\r\n", string.len(chunk)))) + socket.try(data:send(chunk, "\r\n")) end + socket.try(data:send("0\r\n\r\n")) end ------------------------------------------------------------------------------ --- Sends the HTTP request message body --- Input --- data: data connection --- source: request message body source --- Returns --- nil if successfull, or an error message in case of error ------------------------------------------------------------------------------ local function send_body(data, source) - while 1 do - local chunk, cb_err = source() - -- check if callback is done - if not chunk then return cb_err end - -- send data - local err = try_sending(data, chunk) - if err then return err end + while true do + local chunk, err = source() + assert(chunk or not err, err) + if not chunk then break end + socket.try(data:send(chunk)) end end ------------------------------------------------------------------------------ --- Sends request headers --- Input --- sock: server socket --- headers: table with headers to be sent --- Returns --- err: error message if any ------------------------------------------------------------------------------ local function send_headers(sock, headers) - local err - headers = headers or {} -- send request headers - for i, v in headers do - err = try_sending(sock, i .. ": " .. v .. "\r\n") - if err then return err end + for i, v in pairs(headers) do + socket.try(sock:send(i .. ": " .. v .. "\r\n")) end -- mark end of request headers - return try_sending(sock, "\r\n") + socket.try(sock:send("\r\n")) end ------------------------------------------------------------------------------ --- Sends a HTTP request message through socket --- Input --- sock: socket connected to the server --- method: request method to be used --- uri: request uri --- headers: request headers to be sent --- source: request message body source --- Returns --- err: nil in case of success, error message otherwise ------------------------------------------------------------------------------ -local function send_request(sock, method, uri, headers, source) - local chunk, size, done, err - -- send request line - err = try_sending(sock, method .. " " .. uri .. " HTTP/1.1\r\n") - if err then return err end - if source and not headers["content-length"] then - headers["transfer-encoding"] = "chunked" - end - -- send request headers - err = send_headers(sock, headers) - if err then return err end - -- send request message body, if any - if source then - -- make sure source is not fancy - source = ltn12.source.simplify(source) - if headers["content-length"] then - return send_body(sock, source) - else - return send_body_bychunks(sock, source) - end - end -end - ------------------------------------------------------------------------------ --- Determines if we should read a message body from the server response --- Input --- reqt: a table with the original request information --- respt: a table with the server response information --- Returns --- 1 if a message body should be processed, nil otherwise ------------------------------------------------------------------------------ local function should_receive_body(reqt, respt) if reqt.method == "HEAD" then return nil end if respt.code == 204 or respt.code == 304 then return nil end @@ -361,125 +186,17 @@ local function should_receive_body(reqt, respt) return 1 end ------------------------------------------------------------------------------ --- Converts field names to lowercase and adds a few needed headers --- Input --- headers: request header fields --- parsed: parsed request URL --- Returns --- lower: a table with the same headers, but with lowercase field names ------------------------------------------------------------------------------ -local function fill_headers(headers, parsed) - local lower = {} - headers = headers or {} - -- set default headers - lower["user-agent"] = USERAGENT - -- override with user values - for i,v in headers do - lower[string.lower(i)] = v - end - lower["host"] = parsed.host - -- this cannot be overriden - lower["connection"] = "close" - return lower +local function receive_status(reqt, respt) + local sock = respt.tmp.sock + local status = socket.try(sock:receive()) + local code = third(string.find(status, "HTTP/%d*%.%d* (%d%d%d)")) + -- store results + respt.code, respt.status = tonumber(code), status end ------------------------------------------------------------------------------ --- Decides wether we should follow retry with authorization formation --- Input --- reqt: a table with the original request information --- parsed: parsed request URL --- respt: a table with the server response information --- Returns --- 1 if we should retry, nil otherwise ------------------------------------------------------------------------------ -local function should_authorize(reqt, parsed, respt) - -- if there has been an authorization attempt, it must have failed - if reqt.headers["authorization"] then return nil end - -- if we don't have authorization information, we can't retry - if parsed.user and parsed.password then return 1 - else return nil end -end - ------------------------------------------------------------------------------ --- Returns the result of retrying a request with authorization information --- Input --- reqt: a table with the original request information --- parsed: parsed request URL --- Returns --- respt: result of target authorization ------------------------------------------------------------------------------ -local function authorize(reqt, parsed) - reqt.headers["authorization"] = "Basic " .. - (mime.b64(parsed.user .. ":" .. parsed.password)) - local autht = { - nredirects = reqt.nredirects, - method = reqt.method, - url = reqt.url, - source = reqt.source, - sink = reqt.sink, - headers = reqt.headers, - timeout = reqt.timeout, - proxy = reqt.proxy, - } - return request_cb(autht) -end - ------------------------------------------------------------------------------ --- Decides wether we should follow a server redirect message --- Input --- reqt: a table with the original request information --- respt: a table with the server response information --- Returns --- 1 if we should redirect, nil otherwise ------------------------------------------------------------------------------ -local function should_redirect(reqt, respt) - return (reqt.redirect ~= false) and - (respt.code == 301 or respt.code == 302) and - (reqt.method == "GET" or reqt.method == "HEAD") and - not (reqt.nredirects and reqt.nredirects >= 5) -end - ------------------------------------------------------------------------------ --- Returns the result of a request following a server redirect message. --- Input --- reqt: a table with the original request information --- respt: response table of previous attempt --- Returns --- respt: result of target redirection ------------------------------------------------------------------------------ -local function redirect(reqt, respt) - local nredirects = reqt.nredirects or 0 - nredirects = nredirects + 1 - local redirt = { - nredirects = nredirects, - method = reqt.method, - -- the RFC says the redirect URL has to be absolute, but some - -- servers do not respect that - url = socket.url.absolute(reqt.url, respt.headers["location"]), - source = reqt.source, - sink = reqt.sink, - headers = reqt.headers, - timeout = reqt.timeout, - proxy = reqt.proxy - } - respt = request_cb(redirt) - -- we pass the location header as a clue we tried to redirect - if respt.headers then respt.headers.location = redirt.url end - return respt -end - ------------------------------------------------------------------------------ --- Computes the request URI from the parsed request URL --- If we are using a proxy, we use the absoluteURI format. --- Otherwise, we use the abs_path format. --- Input --- parsed: parsed URL --- Returns --- uri: request URI for parsed URL ------------------------------------------------------------------------------ -local function request_uri(reqt, parsed) +local function request_uri(reqt, respt) local url + local parsed = respt.tmp.parsed if not reqt.proxy then url = { path = parsed.path, @@ -487,219 +204,187 @@ local function request_uri(reqt, parsed) query = parsed.query, fragment = parsed.fragment } - else url = parsed end + else url = respt.tmp.parsed end return socket.url.build(url) end ------------------------------------------------------------------------------ --- Builds a request table from a URL or request table --- Input --- url_or_request: target url or request table (a table with the fields: --- url: the target URL --- user: account user name --- password: account password) --- Returns --- reqt: request table ------------------------------------------------------------------------------ -local function build_request(data) - local reqt = {} - if type(data) == "table" then - for i, v in data - do reqt[i] = v - end - else reqt.url = data end - return reqt +local function send_request(reqt, respt) + local uri = request_uri(reqt, respt) + local sock = respt.tmp.sock + local headers = respt.tmp.headers + -- send request line + socket.try(sock:send((reqt.method or "GET") + .. " " .. uri .. " HTTP/1.1\r\n")) + -- send request headers headeres + if reqt.source and not headers["content-length"] then + headers["transfer-encoding"] = "chunked" + end + send_headers(sock, headers) + -- send request message body, if any + if reqt.source then + if headers["content-length"] then send_body(sock, reqt.source) + else send_body_bychunks(sock, reqt.source) end + end end ------------------------------------------------------------------------------ --- Connects to a server, be it a proxy or not --- Input --- reqt: the request table --- parsed: the parsed request url --- Returns --- sock: connection socket, or nil in case of error --- err: error message ------------------------------------------------------------------------------ -local function try_connect(reqt, parsed) - reqt.proxy = reqt.proxy or PROXY +local function open(reqt, respt) + local parsed = respt.tmp.parsed + local proxy = reqt.proxy or PROXY local host, port - if reqt.proxy then - local pproxy = socket.url.parse(reqt.proxy) - if not pproxy.port or not pproxy.host then - return nil, "invalid proxy" - end + if proxy then + local pproxy = socket.url.parse(proxy) + assert(pproxy.port and pproxy.host, "invalid proxy") host, port = pproxy.host, pproxy.port else host, port = parsed.host, parsed.port end - local sock, ret, err - sock, err = socket.tcp() - if not sock then return nil, err end + local sock = socket.try(socket.tcp()) + -- store results + respt.tmp.sock = sock sock:settimeout(reqt.timeout or TIMEOUT) - ret, err = sock:connect(host, port) - if not ret then - sock:close() - return nil, err - end - return sock + socket.try(sock:connect(host, port)) end ------------------------------------------------------------------------------ --- Sends a HTTP request and retrieves the server reply using callbacks to --- send the request body and receive the response body --- Input --- reqt: a table with the following fields --- method: "GET", "PUT", "POST" etc (defaults to "GET") --- url: target uniform resource locator --- user, password: authentication information --- headers: request headers to send, or nil if none --- source: request message body source, or nil if none --- sink: response message body sink --- redirect: should we refrain from following a server redirect message? --- Returns --- respt: a table with the following fields: --- headers: response header fields received, or nil if failed --- status: server response status line, or nil if failed --- code: server status code, or nil if failed --- error: error message, or nil if successfull ------------------------------------------------------------------------------ -function request_cb(reqt) - local sock, ret +function adjust_headers(reqt, respt) + local lower = {} + local headers = reqt.headers or {} + -- set default headers + lower["user-agent"] = USERAGENT + -- override with user values + for i,v in headers do + lower[string.lower(i)] = v + end + lower["host"] = respt.tmp.parsed.host + -- this cannot be overriden + lower["connection"] = "close" + -- store results + respt.tmp.headers = lower +end + +function parse_url(reqt, respt) + -- parse url with default fields local parsed = socket.url.parse(reqt.url, { host = "", port = PORT, path ="/", scheme = "http" }) - local respt = {} + -- scheme has to be http if parsed.scheme ~= "http" then - respt.error = string.format("unknown scheme '%s'", parsed.scheme) - return respt - end + error(string.format("unknown scheme '%s'", parsed.scheme)) + end -- explicit authentication info overrides that given by the URL parsed.user = reqt.user or parsed.user parsed.password = reqt.password or parsed.password - -- default method - reqt.method = reqt.method or "GET" - -- fill default headers - reqt.headers = fill_headers(reqt.headers, parsed) - -- try to connect to server - sock, respt.error = try_connect(reqt, parsed) - if not sock then return respt end - -- send request message - respt.error = send_request(sock, reqt.method, - request_uri(reqt, parsed), reqt.headers, reqt.source) - if respt.error then - sock:close() - return respt + -- store results + respt.tmp.parsed = parsed +end + +local function should_authorize(reqt, respt) + -- if there has been an authorization attempt, it must have failed + if reqt.headers and reqt.headers["authorization"] then return nil end + -- if we don't have authorization information, we can't retry + return respt.tmp.parsed.user and respt.tmp.parsed.password +end + +local function clone(headers) + if not headers then return nil end + local copy = {} + for i,v in pairs(headers) do + copy[i] = v end - -- get server response message - respt.code, respt.status, respt.error = receive_status(sock) - if respt.error then return respt end - -- deal with continue 100 - -- servers should not send them, but some do! - if respt.code == 100 then - respt.headers, respt.error = receive_headers(sock, {}) - if respt.error then return respt end - respt.code, respt.status, respt.error = receive_status(sock) - if respt.error then return respt end - end - -- receive all headers - respt.headers, respt.error = receive_headers(sock, {}) - if respt.error then return respt end - -- decide what to do based on request and response parameters + return copy +end + +local function authorize(reqt, respt) + local headers = clone(reqt.headers) or {} + local parsed = respt.tmp.parsed + headers["authorization"] = "Basic " .. + (mime.b64(parsed.user .. ":" .. parsed.password)) + local autht = { + method = reqt.method, + url = reqt.url, + source = reqt.source, + sink = reqt.sink, + headers = headers, + timeout = reqt.timeout, + proxy = reqt.proxy, + } + request_p(autht, respt) +end + +local function should_redirect(reqt, respt) + return (reqt.redirect ~= false) and + (respt.code == 301 or respt.code == 302) and + (not reqt.method or reqt.method == "GET" or reqt.method == "HEAD") + and (not respt.tmp.nredirects or respt.tmp.nredirects < 5) +end + +local function redirect(reqt, respt) + respt.tmp.nredirects = (respt.tmp.nredirects or 0) + 1 + local redirt = { + method = reqt.method, + -- the RFC says the redirect URL has to be absolute, but some + -- servers do not respect that + url = socket.url.absolute(reqt.url, respt.headers["location"]), + source = reqt.source, + sink = reqt.sink, + headers = reqt.headers, + timeout = reqt.timeout, + proxy = reqt.proxy + } + request_p(redirt, respt) + -- we pass the location header as a clue we redirected + if respt.headers then respt.headers.location = redirt.url end +end + +function request_p(reqt, respt) + parse_url(reqt, respt) + adjust_headers(reqt, respt) + open(reqt, respt) + send_request(reqt, respt) + receive_status(reqt, respt) + respt.headers = {} + receive_headers(respt.tmp.sock, respt.headers) if should_redirect(reqt, respt) then - -- drop the body - receive_body(sock, respt.headers, ltn12.sink.null()) - -- we are done with this connection - sock:close() - return redirect(reqt, respt) - elseif should_authorize(reqt, parsed, respt) then - -- drop the body - receive_body(sock, respt.headers, ltn12.sink.null()) - -- we are done with this connection - sock:close() - return authorize(reqt, parsed, respt) + respt.tmp.sock:close() + redirect(reqt, respt) + elseif should_authorize(reqt, respt) then + respt.tmp.sock:close() + authorize(reqt, respt) elseif should_receive_body(reqt, respt) then - respt.error = receive_body(sock, respt.headers, reqt.sink) - if respt.error then return respt end - sock:close() - return respt + receive_body(reqt, respt) end - sock:close() - return respt end ------------------------------------------------------------------------------ --- Sends a HTTP request and retrieves the server reply --- Input --- reqt: a table with the following fields --- method: "GET", "PUT", "POST" etc (defaults to "GET") --- url: request URL, i.e. the document to be retrieved --- user, password: authentication information --- headers: request header fields, or nil if none --- body: request message body as a string, or nil if none --- redirect: should we refrain from following a server redirect message? --- Returns --- respt: a table with the following fields: --- body: response message body, or nil if failed --- headers: response header fields, or nil if failed --- status: server response status line, or nil if failed --- code: server response status code, or nil if failed --- error: error message if any ------------------------------------------------------------------------------ function request(reqt) - reqt.source = reqt.body and ltn12.source.string(reqt.body) - local t = {} - reqt.sink = ltn12.sink.table(t) - local respt = request_cb(reqt) - if table.getn(t) > 0 then respt.body = table.concat(t) end + local respt = { tmp = {} } + local s, e = pcall(request_p, reqt, respt) + if not s then respt.error = e end + if respt.tmp.sock then respt.tmp.sock:close() end + respt.tmp = nil return respt end ------------------------------------------------------------------------------ --- Retrieves a URL by the method "GET" --- Input --- url_or_request: target url or request table (a table with the fields: --- url: the target URL --- user: account user name --- password: account password) --- Returns --- body: response message body, or nil if failed --- headers: response header fields received, or nil if failed --- code: server response status code, or nil if failed --- error: error message if any ------------------------------------------------------------------------------ -function get(url_or_request) - local reqt = build_request(url_or_request) - reqt.method = "GET" - local respt = request(reqt) - return respt.body, respt.headers, respt.code, respt.error +function get(url) + local t = {} + respt = request { + url = url, + sink = ltn12.sink.table(t) + } + return table.getn(t) > 0 and table.concat(t), respt.headers, + respt.code, respt.error end ------------------------------------------------------------------------------ --- Retrieves a URL by the method "POST" --- Input --- url_or_request: target url or request table (a table with the fields: --- url: the target URL --- body: request message body --- user: account user name --- password: account password) --- body: request message body, or nil if none --- Returns --- body: response message body, or nil if failed --- headers: response header fields received, or nil if failed --- code: server response status code, or nil if failed --- error: error message, or nil if successfull ------------------------------------------------------------------------------ -function post(url_or_request, body) - local reqt = build_request(url_or_request) - reqt.method = "POST" - reqt.body = reqt.body or body - reqt.headers = reqt.headers or - { ["content-length"] = string.len(reqt.body) } - local respt = request(reqt) - return respt.body, respt.headers, respt.code, respt.error +function post(url, body) + local t = {} + respt = request { + url = url, + method = "POST", + source = ltn12.source.string(body), + sink = ltn12.sink.table(t), + headers = { ["content-length"] = string.len(body) } + } + return table.getn(t) > 0 and table.concat(t), + respt.headers, respt.code, respt.error end - -return socket.http diff --git a/src/ltn12.lua b/src/ltn12.lua index ef6247d..dc49d80 100644 --- a/src/ltn12.lua +++ b/src/ltn12.lua @@ -171,9 +171,8 @@ function sink.file(handle, io_err) return function(chunk, err) if not chunk then handle:close() - return nil, err - end - return handle:write(chunk) + return 1 + else return handle:write(chunk) end end else return sink.error(io_err or "unable to open file") end end diff --git a/src/mime.c b/src/mime.c index 77f3ae1..7bfa6aa 100644 --- a/src/mime.c +++ b/src/mime.c @@ -619,28 +619,27 @@ static int mime_global_qpwrp(lua_State *L) * end of line markers each, but \r\n, \n\r etc will only issue *one* * marker. This covers Mac OS, Mac OS X, VMS, Unix and DOS, as well as * probably other more obscure conventions. +* +* c is the current character being processed +* last is the previous character \*-------------------------------------------------------------------------*/ #define eolcandidate(c) (c == CR || c == LF) -static size_t eolprocess(int c, int ctx, const char *marker, +static int eolprocess(int c, int last, const char *marker, luaL_Buffer *buffer) { - if (eolcandidate(ctx)) { - luaL_addstring(buffer, marker); - if (eolcandidate(c)) { - if (c == ctx) - luaL_addstring(buffer, marker); + if (eolcandidate(c)) { + if (eolcandidate(last)) { + if (c == last) luaL_addstring(buffer, marker); return 0; } else { - luaL_putchar(buffer, c); - return 0; + luaL_addstring(buffer, marker); + return c; } } else { - if (!eolcandidate(c)) { - luaL_putchar(buffer, c); - return 0; - } else - return c; + luaL_putchar(buffer, c); + return 0; } + } /*-------------------------------------------------------------------------*\ @@ -661,8 +660,7 @@ static int mime_global_eol(lua_State *L) luaL_buffinit(L, &buffer); /* if the last character was a candidate, we output a new line */ if (!input) { - if (eolcandidate(ctx)) lua_pushstring(L, marker); - else lua_pushnil(L); + lua_pushnil(L); lua_pushnumber(L, 0); return 2; } diff --git a/test/httptest.lua b/test/httptest.lua index 04c0ed0..ddeea50 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -8,7 +8,7 @@ dofile("noglobals.lua") local host, proxy, request, response, index_file local ignore, expect, index, prefix, cgiprefix, index_crlf -socket.http.TIMEOUT = 5 +socket.http.TIMEOUT = 10 local t = socket.time() @@ -49,7 +49,9 @@ local check_result = function(response, expect, ignore) for i,v in response do if not ignore[i] then if v ~= expect[i] then - print(string.sub(tostring(v), 1, 70)) + local f = io.open("err", "w") + f:write(tostring(v), "\n\n versus\n\n", tostring(expect[i])) + f:close() fail(i .. " differs!") end end @@ -57,8 +59,10 @@ local check_result = function(response, expect, ignore) for i,v in expect do if not ignore[i] then if v ~= response[i] then + local f = io.open("err", "w") + f:write(tostring(response[i]), "\n\n versus\n\n", tostring(v)) v = string.sub(type(v) == "string" and v or "", 1, 70) - print(string.sub(tostring(v), 1, 70)) + f:close() fail(i .. " differs!") end end @@ -67,12 +71,14 @@ local check_result = function(response, expect, ignore) end local check_request = function(request, expect, ignore) + local t + if not request.sink then + request.sink, t = ltn12.sink.table(t) + end + request.source = request.source or + (request.body and ltn12.source.string(request.body)) local response = socket.http.request(request) - check_result(response, expect, ignore) -end - -local check_request_cb = function(request, expect, ignore) - local response = socket.http.request_cb(request) + if t and table.getn(t) > 0 then response.body = table.concat(t) end check_result(response, expect, ignore) end @@ -183,7 +189,7 @@ ignore = { status = 1, headers = 1 } -check_request_cb(request, expect, ignore) +check_request(request, expect, ignore) back = readfile(index_file .. "-back") check(back == index) os.remove(index_file .. "-back") @@ -225,19 +231,11 @@ ignore = { status = 1, headers = 1 } -check_request_cb(request, expect, ignore) +check_request(request, expect, ignore) back = readfile(index_file .. "-back") check(back == index) os.remove(index_file .. "-back") ------------------------------------------------------------------------- -io.write("testing simple post function with table args: ") -back = socket.http.post { - url = "http://" .. host .. cgiprefix .. "/cat", - body = index -} -check(back == index) - ------------------------------------------------------------------------ io.write("testing http redirection: ") request = { @@ -438,15 +436,6 @@ io.write("testing simple get function: ") body = socket.http.get("http://" .. host .. prefix .. "/index.html") check(body == index) ------------------------------------------------------------------------- -io.write("testing simple get function with table args: ") -body = socket.http.get { - url = "http://really:wrong@" .. host .. prefix .. "/auth/index.html", - user = "luasocket", - password = "password" -} -check(body == index) - ------------------------------------------------------------------------ io.write("testing HEAD method: ") socket.http.TIMEOUT = 1 diff --git a/test/testclnt.lua b/test/testclnt.lua index 1b64abd..ecf419b 100644 --- a/test/testclnt.lua +++ b/test/testclnt.lua @@ -17,14 +17,12 @@ function warn(...) io.stderr:write("WARNING: ", s, "\n") end -pad = string.rep(" ", 8192) - function remote(...) local s = string.format(unpack(arg)) s = string.gsub(s, "\n", ";") s = string.gsub(s, "%s+", " ") s = string.gsub(s, "^%s*", "") - control:send(pad, s, "\n") + control:send(s, "\n") control:receive() end @@ -122,7 +120,13 @@ remote (string.format("str = data:receive(%d)", sent, err = data:send(p1, p2, p3, p4) if err then fail(err) end remote "data:send(str); data:close()" - bp1, bp2, bp3, bp4, err = data:receive("*l", "*l", string.len(p3), "*a") + bp1, err = data:receive() + if err then fail(err) end + bp2, err = data:receive() + if err then fail(err) end + bp3, err = data:receive(string.len(p3)) + if err then fail(err) end + bp4, err = data:receive("*a") if err then fail(err) end if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then pass("patterns match") @@ -186,7 +190,7 @@ end ------------------------------------------------------------------------ function test_totaltimeoutreceive(len, tm, sl) reconnect() - local str, err, total + local str, err, partial pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl) remote (string.format ([[ data:settimeout(%d) @@ -198,9 +202,9 @@ function test_totaltimeoutreceive(len, tm, sl) data:send(str) ]], 2*tm, len, sl, sl)) data:settimeout(tm, "total") - str, err, elapsed = data:receive(2*len) + str, err, partial, elapsed = data:receive(2*len) check_timeout(tm, sl, elapsed, err, "receive", "total", - string.len(str) == 2*len) + string.len(str or partial) == 2*len) end ------------------------------------------------------------------------ @@ -226,7 +230,7 @@ end ------------------------------------------------------------------------ function test_blockingtimeoutreceive(len, tm, sl) reconnect() - local str, err, total + local str, err, partial pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) remote (string.format ([[ data:settimeout(%d) @@ -238,9 +242,9 @@ function test_blockingtimeoutreceive(len, tm, sl) data:send(str) ]], 2*tm, len, sl, sl)) data:settimeout(tm) - str, err, elapsed = data:receive(2*len) + str, err, partial, elapsed = data:receive(2*len) check_timeout(tm, sl, elapsed, err, "receive", "blocking", - string.len(str) == 2*len) + string.len(str or partial) == 2*len) end ------------------------------------------------------------------------ @@ -298,7 +302,7 @@ end ------------------------------------------------------------------------ function test_closed() - local back, err + local back, partial, err local str = 'little string' reconnect() pass("trying read detection") @@ -308,10 +312,10 @@ function test_closed() data = nil ]], str)) -- try to get a line - back, err = data:receive() - if not err then fail("shold have gotten 'closed'.") + back, err, partial = data:receive() + if not err then fail("should have gotten 'closed'.") elseif err ~= "closed" then fail("got '"..err.."' instead of 'closed'.") - elseif str ~= back then fail("didn't receive partial result.") + elseif str ~= partial then fail("didn't receive partial result.") else pass("graceful 'closed' received") end reconnect() pass("trying write detection") @@ -456,7 +460,6 @@ test_methods(socket.udp(), { "setpeername", "setsockname", "settimeout", - "shutdown", }) test("select function") @@ -481,6 +484,7 @@ accept_timeout() accept_errors() + test("mixed patterns") test_mixed(1) test_mixed(17) diff --git a/test/testmesg.lua b/test/testmesg.lua index 228bbe4..8b33133 100644 --- a/test/testmesg.lua +++ b/test/testmesg.lua @@ -6,7 +6,7 @@ mesgt = { body = { preamble = "Some attatched stuff", [1] = { - body = "Testing stuffing.\r\n.\r\nGot you.\r\n.Hehehe.\r\n" + body = mime.eol(0, "Testing stuffing.\n.\nGot you.\n.Hehehe.\n") }, [2] = { headers = { @@ -29,7 +29,7 @@ mesgt = { ["content-transfer-encoding"] = "QUOTED-PRINTABLE" }, body = ltn12.source.chain( - ltn12.source.file(io.open("message.lua", "rb")), + ltn12.source.file(io.open("testmesg.lua", "rb")), ltn12.filter.chain( mime.normalize(), mime.encode("quoted-printable"), @@ -46,8 +46,8 @@ mesgt = { -- ltn12.pump(source, sink) print(socket.smtp.send { - rcpt = {"", ""}, + rcpt = "", from = "", source = socket.smtp.message(mesgt), - server = "smtp.princeton.edu" + server = "mail.cs.princeton.edu" }) diff --git a/test/testsrvr.lua b/test/testsrvr.lua index 99b54e5..5c05239 100644 --- a/test/testsrvr.lua +++ b/test/testsrvr.lua @@ -22,6 +22,8 @@ while 1 do print("server: closing connection...") break end +print(command); + (loadstring(command))() end end