From 8d4e240f6ae50d9b22ddc44f5e207018935da907 Mon Sep 17 00:00:00 2001
From: Diego Nehab
@@ -150,7 +150,7 @@ error.
@@ -206,7 +206,7 @@ Here are a few examples with the simple interface:
@@ -239,7 +239,7 @@ and
-- loads the FTP module and any libraries it requires
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
-- load the ftp support
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
-- Log as user "anonymous" on server "ftp.tecgraf.puc-rio.br",
-- and get file "lua.tar.gz" from directory "pub/lua" as binary.
@@ -159,9 +159,9 @@ f, e = ftp.get("ftp://ftp.tecgraf.puc-rio.br/pub/lua/lua.tar.gz;type=i")
-- load needed modules
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
local ltn12 = require("ltn12")
-local url = require("url")
+local url = require("socket.url")
-- a function that returns a directory listing
function nlst(u)
@@ -230,7 +230,7 @@ message describing the reason for failure.
-- load the ftp support
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
-- Log as user "fulano" on server "ftp.example.com",
-- using password "silva", and store a file "README" with contents
@@ -241,7 +241,7 @@ f, e = ftp.put("ftp://fulano:silva@ftp.example.com/README",
-- load the ftp support
-local ftp = require("ftp")
+local ftp = require("socket.ftp")
local ltn12 = require("ltn12")
-- Log as user "fulano" on server "ftp.example.com",
diff --git a/doc/http.html b/doc/http.html
index 4cbbe95..af58571 100644
--- a/doc/http.html
+++ b/doc/http.html
@@ -62,7 +62,7 @@ To obtain the http namespace, run:
-- loads the HTTP module and any libraries it requires
-local http = require("http")
+local http = require("socket.http")
-- load the http module
-http = require("http")
+http = require("socket.http")
-- connect to server "www.tecgraf.puc-rio.br" and retrieves this manual
-- file from "/luasocket/http.html"
@@ -231,7 +231,7 @@ And here is an example using the generic interface:
-- load the http module
-http = require("http")
+http = require("socket.http")
-- Requests information about a document, without downloading it.
-- Useful, for example, if you want to display a download gauge and need
@@ -276,7 +276,7 @@ authentication is required.
-- load required modules
-http = require("http")
+http = require("socket.http")
mime = require("mime")
-- Connect to server "www.example.com" and tries to retrieve
diff --git a/doc/introduction.html b/doc/introduction.html
index f8fe078..c88fa40 100644
--- a/doc/introduction.html
+++ b/doc/introduction.html
@@ -182,7 +182,7 @@ program.
-- load namespace
local socket = require("socket")
-- create a TCP socket and bind it to the local host, at any port
-local server = socket.try(socket.bind("*", 0))
+local server = assert(socket.bind("*", 0))
-- find out which port the OS chose for us
local ip, port = server:getsockname()
-- print a message informing what's up
@@ -287,13 +287,13 @@ local host, port = "localhost", 13
-- load namespace
local socket = require("socket")
-- convert host name to ip address
-local ip = socket.try(socket.dns.toip(host))
+local ip = assert(socket.dns.toip(host))
-- create a new UDP object
-local udp = socket.try(socket.udp())
+local udp = assert(socket.udp())
-- contact daytime host
-socket.try(udp:sendto("anything", ip, port))
+assert(udp:sendto("anything", ip, port))
-- retrieve the answer and print results
-io.write(socket.try((udp:receive())))
+io.write(assert(udp:receive()))
diff --git a/doc/ltn12.html b/doc/ltn12.html
index 44fcbe4..c5a0f59 100644
--- a/doc/ltn12.html
+++ b/doc/ltn12.html
@@ -271,7 +271,7 @@ The function returns the sink and the table used to store the chunks.
-- load needed modules
-local http = require("http")
+local http = require("socket.http")
local ltn12 = require("ltn12")
-- a simplified http.get function
diff --git a/doc/smtp.html b/doc/smtp.html
index 8feae3e..bd18bfa 100644
--- a/doc/smtp.html
+++ b/doc/smtp.html
@@ -69,7 +69,7 @@ To obtain the smtp namespace, run:
-- loads the SMTP module and everything it requires
-local smtp = require("smtp")
+local smtp = require("socket.smtp")
-- load the smtp support
-local smtp = require("smtp")
+local smtp = require("socket.smtp")
-- Connects to server "localhost" and sends a message to users
-- "fulano@example.com", "beltrano@example.com",
@@ -329,7 +329,7 @@ as listed in the introduction.
-- load the smtp support and its friends -local smtp = require("smtp") +local smtp = require("socket.smtp") local mime = require("mime") local ltn12 = require("ltn12") diff --git a/doc/socket.html b/doc/socket.html index f638fd9..18c71d1 100644 --- a/doc/socket.html +++ b/doc/socket.html @@ -145,7 +145,10 @@ socket.protect(func)-Converts a function that throws exceptions into a safe function. +Converts a function that throws exceptions into a safe function. This +function only catches exceptions thrown by the try +and newtry functions. It does not catch normal +Lua errors.
@@ -346,7 +349,9 @@ socket.try(ret1 [, ret2 ... retN])
-Throws an exception in case of error. +Throws an exception in case of error. The exception can only be caught +by the protect function. It does not explode +into an error message.
diff --git a/doc/url.html b/doc/url.html index 56e1ef5..ac84d24 100644 --- a/doc/url.html +++ b/doc/url.html @@ -52,7 +52,7 @@ To obtain the url namespace, run:
-- loads the URL module -local url = require("url") +local url = require("socket.url")@@ -193,7 +193,7 @@ The function returns the encoded string.
-- load url module -url = require("url") +url = require("socket.url") code = url.escape("/#?;") -- code = "%2f%23%3f%3b" @@ -239,7 +239,7 @@ parsed_url = {
-- load url module -url = require("url") +url = require("socket.url") parsed_url = url.parse("http://www.example.com/cgilua/index.lua?a=2#there") -- parsed_url = { diff --git a/src/buffer.c b/src/buffer.c index 0ec7b4d..45cd0f2 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -123,7 +123,7 @@ int buf_meth_receive(lua_State *L, p_buf buf) { else if (p[0] == '*' && p[1] == 'a') err = recvall(buf, &b); else luaL_argcheck(L, 0, 2, "invalid receive pattern"); /* get a fixed number of bytes */ - } else err = recvraw(buf, (size_t) lua_tonumber(L, 2), &b); + } else err = recvraw(buf, (size_t) lua_tonumber(L, 2)-size, &b); /* check if there was an error */ if (err != IO_DONE) { /* we can't push anyting in the stack before pushing the diff --git a/src/inet.c b/src/inet.c index e2afcdf..d713643 100644 --- a/src/inet.c +++ b/src/inet.c @@ -220,7 +220,6 @@ const char *inet_tryconnect(p_sock ps, const char *address, } } else remote.sin_family = AF_UNSPEC; err = sock_connect(ps, (SA *) &remote, sizeof(remote), tm); - if (err != IO_DONE) sock_destroy(ps); return sock_strerror(err); } diff --git a/src/luasocket.c b/src/luasocket.c index 4b829f8..8f13dbc 100644 --- a/src/luasocket.c +++ b/src/luasocket.c @@ -87,7 +87,7 @@ static int global_unload(lua_State *L) { static int base_open(lua_State *L) { if (sock_open()) { /* export functions (and leave namespace table on top of stack) */ - luaL_module(L, "socket", func, 0); + luaL_openlib(L, "socket", func, 0); #ifdef LUASOCKET_DEBUG lua_pushstring(L, "DEBUG"); lua_pushboolean(L, 1); @@ -108,7 +108,7 @@ static int base_open(lua_State *L) { /*-------------------------------------------------------------------------*\ * Initializes all library modules. \*-------------------------------------------------------------------------*/ -LUASOCKET_API int luaopen_lsocket(lua_State *L) { +LUASOCKET_API int luaopen_csocket(lua_State *L) { int i; base_open(L); for (i = 0; mod[i].name; i++) mod[i].func(L); diff --git a/src/luasocket.h b/src/luasocket.h index db54a18..768e335 100644 --- a/src/luasocket.h +++ b/src/luasocket.h @@ -13,7 +13,7 @@ /*-------------------------------------------------------------------------*\ * Current luasocket version \*-------------------------------------------------------------------------*/ -#define LUASOCKET_VERSION "LuaSocket 2.0 (beta3)" +#define LUASOCKET_VERSION "LuaSocket 2.0" #define LUASOCKET_COPYRIGHT "Copyright (C) 2004-2005 Diego Nehab" #define LUASOCKET_AUTHORS "Diego Nehab" @@ -27,6 +27,6 @@ /*-------------------------------------------------------------------------*\ * Initializes the library. \*-------------------------------------------------------------------------*/ -LUASOCKET_API int luaopen_socket(lua_State *L); +LUASOCKET_API int luaopen_csocket(lua_State *L); #endif /* LUASOCKET_H */ diff --git a/src/mime.c b/src/mime.c index dcc4af3..67f9f5b 100644 --- a/src/mime.c +++ b/src/mime.c @@ -78,9 +78,9 @@ static UC b64unbase[256]; /*-------------------------------------------------------------------------*\ * Initializes module \*-------------------------------------------------------------------------*/ -MIME_API int luaopen_lmime(lua_State *L) +MIME_API int luaopen_cmime(lua_State *L) { - luaL_module(L, "mime", func, 0); + luaL_openlib(L, "mime", func, 0); /* initialize lookup tables */ qpsetup(qpclass, qpunbase); b64setup(b64unbase); diff --git a/src/mime.h b/src/mime.h index 688d043..d596861 100644 --- a/src/mime.h +++ b/src/mime.h @@ -19,6 +19,6 @@ #define MIME_API extern #endif -MIME_API int luaopen_mime(lua_State *L); +MIME_API int luaopen_cmime(lua_State *L); #endif /* MIME_H */ diff --git a/src/mime.lua b/src/mime.lua index 4d5bdba..6492a96 100644 --- a/src/mime.lua +++ b/src/mime.lua @@ -8,9 +8,10 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- +package.loaded.base = _G local base = require("base") local ltn12 = require("ltn12") -local mime = require("lmime") +local mime = require("cmime") module("mime") -- encode, decode and wrap algorithm tables diff --git a/src/socket.h b/src/socket.h index 368c2b6..639229d 100644 --- a/src/socket.h +++ b/src/socket.h @@ -45,11 +45,15 @@ int sock_sendto(p_sock ps, const char *data, size_t count, size_t *sent, SA *addr, socklen_t addr_len, p_tm tm); int sock_recvfrom(p_sock ps, char *data, size_t count, size_t *got, SA *addr, socklen_t *addr_len, p_tm tm); + void sock_setnonblocking(p_sock ps); void sock_setblocking(p_sock ps); + +int sock_waitfd(int fd, int sw, p_tm tm); int sock_select(int n, fd_set *rfds, fd_set *wfds, fd_set *efds, p_tm tm); int sock_connect(p_sock ps, SA *addr, socklen_t addr_len, p_tm tm); +int sock_connected(p_sock ps, p_tm tm); int sock_create(p_sock ps, int domain, int type, int protocol); int sock_bind(p_sock ps, SA *addr, socklen_t addr_len); int sock_listen(p_sock ps, int backlog); diff --git a/src/socket.lua b/src/socket.lua index 1c82750..f3563e7 100644 --- a/src/socket.lua +++ b/src/socket.lua @@ -7,10 +7,11 @@ ----------------------------------------------------------------------------- -- Declare module and import dependencies ----------------------------------------------------------------------------- +package.loaded.base = _G local base = require("base") local string = require("string") local math = require("math") -local socket = require("lsocket") +local socket = require("csocket") module("socket") ----------------------------------------------------------------------------- diff --git a/src/tcp.c b/src/tcp.c index 0b3706b..3a84191 100644 --- a/src/tcp.c +++ b/src/tcp.c @@ -20,6 +20,7 @@ \*=========================================================================*/ static int global_create(lua_State *L); static int meth_connect(lua_State *L); +static int meth_connected(lua_State *L); static int meth_listen(lua_State *L); static int meth_bind(lua_State *L); static int meth_send(lua_State *L); @@ -45,6 +46,7 @@ static luaL_reg tcp[] = { {"bind", meth_bind}, {"close", meth_close}, {"connect", meth_connect}, + {"connected", meth_connected}, {"dirty", meth_dirty}, {"getfd", meth_getfd}, {"getpeername", meth_getpeername}, @@ -113,12 +115,12 @@ static int meth_receive(lua_State *L) { } static int meth_getstats(lua_State *L) { - p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1); + p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1); return buf_meth_getstats(L, &tcp->buf); } static int meth_setstats(lua_State *L) { - p_tcp tcp = (p_tcp) aux_checkgroup(L, "tcp{any}", 1); + p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{client}", 1); return buf_meth_setstats(L, &tcp->buf); } @@ -224,6 +226,22 @@ static int meth_connect(lua_State *L) return 1; } +static int meth_connected(lua_State *L) +{ + static t_tm tm = {-1, -1}; + p_tcp tcp = (p_tcp) aux_checkclass(L, "tcp{master}", 1); + int err = sock_connected(&tcp->sock, &tm); + if (err != IO_DONE) { + lua_pushnil(L); + lua_pushstring(L, sock_strerror(err)); + return 2; + } + /* turn master object into a client object */ + aux_setclass(L, "tcp{client}", 1); + lua_pushnumber(L, 1); + return 1; +} + /*-------------------------------------------------------------------------*\ * Closes socket used by object \*-------------------------------------------------------------------------*/ diff --git a/src/unix.c b/src/unix.c index 1e0e252..c169268 100644 --- a/src/unix.c +++ b/src/unix.c @@ -32,6 +32,8 @@ static int meth_settimeout(lua_State *L); static int meth_getfd(lua_State *L); static int meth_setfd(lua_State *L); static int meth_dirty(lua_State *L); +static int meth_getstats(lua_State *L); +static int meth_setstats(lua_State *L); static const char *unix_tryconnect(p_unix un, const char *path); static const char *unix_trybind(p_unix un, const char *path); @@ -46,6 +48,8 @@ static luaL_reg un[] = { {"connect", meth_connect}, {"dirty", meth_dirty}, {"getfd", meth_getfd}, + {"getstats", meth_getstats}, + {"setstats", meth_setstats}, {"listen", meth_listen}, {"receive", meth_receive}, {"send", meth_send}, @@ -75,7 +79,7 @@ static luaL_reg func[] = { /*-------------------------------------------------------------------------*\ * Initializes module \*-------------------------------------------------------------------------*/ -int unix_open(lua_State *L) { +int luaopen_socketunix(lua_State *L) { /* create classes */ aux_newclass(L, "unix{master}", un); aux_newclass(L, "unix{client}", un); @@ -84,11 +88,9 @@ int unix_open(lua_State *L) { aux_add2group(L, "unix{master}", "unix{any}"); aux_add2group(L, "unix{client}", "unix{any}"); aux_add2group(L, "unix{server}", "unix{any}"); - aux_add2group(L, "unix{client}", "unix{client,server}"); - aux_add2group(L, "unix{server}", "unix{client,server}"); /* define library functions */ - luaL_openlib(L, NULL, func, 0); - return 0; + luaL_openlib(L, "socket", func, 0); + return 1; } /*=========================================================================*\ @@ -107,6 +109,16 @@ static int meth_receive(lua_State *L) { return buf_meth_receive(L, &un->buf); } +static int meth_getstats(lua_State *L) { + p_unix un = (p_unix) aux_checkclass(L, "unix{client}", 1); + return buf_meth_getstats(L, &un->buf); +} + +static int meth_setstats(lua_State *L) { + p_unix un = (p_unix) aux_checkclass(L, "unix{client}", 1); + return buf_meth_setstats(L, &un->buf); +} + /*-------------------------------------------------------------------------*\ * Just call option handler \*-------------------------------------------------------------------------*/ @@ -250,7 +262,8 @@ static int meth_close(lua_State *L) { p_unix un = (p_unix) aux_checkgroup(L, "unix{any}", 1); sock_destroy(&un->sock); - return 0; + lua_pushnumber(L, 1); + return 1; } /*-------------------------------------------------------------------------*\ @@ -277,7 +290,7 @@ static int meth_listen(lua_State *L) \*-------------------------------------------------------------------------*/ static int meth_shutdown(lua_State *L) { - p_unix un = (p_unix) aux_checkgroup(L, "unix{client}", 1); + p_unix un = (p_unix) aux_checkclass(L, "unix{client}", 1); const char *how = luaL_optstring(L, 2, "both"); switch (how[0]) { case 'b': diff --git a/src/unix.h b/src/unix.h index 7b2a5c5..aaaef3d 100644 --- a/src/unix.h +++ b/src/unix.h @@ -23,6 +23,6 @@ typedef struct t_unix_ { } t_unix; typedef t_unix *p_unix; -int unix_open(lua_State *L); +int luaopen_socketunix(lua_State *L); #endif /* UNIX_H */ diff --git a/src/usocket.c b/src/usocket.c index c1ab725..3428a0c 100644 --- a/src/usocket.c +++ b/src/usocket.c @@ -22,7 +22,7 @@ #define WAITFD_R POLLIN #define WAITFD_W POLLOUT #define WAITFD_C (POLLIN|POLLOUT) -static int sock_waitfd(int fd, int sw, p_tm tm) { +int sock_waitfd(int fd, int sw, p_tm tm) { int ret; struct pollfd pfd; pfd.fd = fd; @@ -44,7 +44,7 @@ static int sock_waitfd(int fd, int sw, p_tm tm) { #define WAITFD_W 2 #define WAITFD_C (WAITFD_R|WAITFD_W) -static int sock_waitfd(int fd, int sw, p_tm tm) { +int sock_waitfd(int fd, int sw, p_tm tm) { int ret; fd_set rfds, wfds, *rp, *wp; struct timeval tv, *tp; @@ -166,12 +166,20 @@ int sock_connect(p_sock ps, SA *addr, socklen_t len, p_tm tm) { while ((err = errno) == EINTR); /* if connection failed immediately, return error code */ if (err != EINPROGRESS && err != EAGAIN) return err; + /* zero timeout case optimization */ + if (tm_iszero(tm)) return IO_TIMEOUT; /* wait until we have the result of the connection attempt or timeout */ - if ((err = sock_waitfd(*ps, WAITFD_C, tm)) == IO_CLOSED) { - /* finaly find out if we succeeded connecting */ + return sock_connected(ps, tm); +} + +/*-------------------------------------------------------------------------*\ +* Checks if socket is connected, or return reason for failure +\*-------------------------------------------------------------------------*/ +int sock_connected(p_sock ps, p_tm tm) { + int err; + if ((err = sock_waitfd(*ps, WAITFD_C, tm) == IO_CLOSED)) { if (recv(*ps, (char *) &err, 0, 0) == 0) return IO_DONE; else return errno; - /* timed out or some weirder error */ } else return err; } @@ -321,13 +329,17 @@ void sock_setnonblocking(p_sock ps) { int sock_gethostbyaddr(const char *addr, socklen_t len, struct hostent **hp) { *hp = gethostbyaddr(addr, len, AF_INET); if (*hp) return IO_DONE; - else return h_errno; + else if (h_errno) return h_errno; + else if (errno) return errno; + else return IO_UNKNOWN; } int sock_gethostbyname(const char *addr, struct hostent **hp) { *hp = gethostbyname(addr); if (*hp) return IO_DONE; - else return h_errno; + else if (h_errno) return h_errno; + else if (errno) return errno; + else return IO_UNKNOWN; } /*-------------------------------------------------------------------------*\ diff --git a/src/wsocket.c b/src/wsocket.c index 69fac4d..c0686cd 100644 --- a/src/wsocket.c +++ b/src/wsocket.c @@ -45,7 +45,7 @@ int sock_close(void) { #define WAITFD_E 4 #define WAITFD_C (WAITFD_E|WAITFD_W) -static int sock_waitfd(t_sock fd, int sw, p_tm tm) { +int sock_waitfd(t_sock fd, int sw, p_tm tm) { int ret; fd_set rfds, wfds, efds, *rp = NULL, *wp = NULL, *ep = NULL; struct timeval tv, *tp = NULL; @@ -118,7 +118,17 @@ int sock_connect(p_sock ps, SA *addr, socklen_t len, p_tm tm) { /* make sure the system is trying to connect */ err = WSAGetLastError(); if (err != WSAEWOULDBLOCK && err != WSAEINPROGRESS) return err; + /* zero timeout case optimization */ + if (tm_iszero(tm)) return IO_TIMEOUT; /* we wait until something happens */ + return sock_connected(ps, tm); +} + +/*-------------------------------------------------------------------------*\ +* Check if socket is connected +\*-------------------------------------------------------------------------*/ +int sock_connected(p_sock ps) { + int err; if ((err = sock_waitfd(*ps, WAITFD_C, tm)) == IO_CLOSED) { int len = sizeof(err); /* give windows time to set the error (yes, disgusting) */ @@ -126,9 +136,8 @@ int sock_connect(p_sock ps, SA *addr, socklen_t len, p_tm tm) { /* find out why we failed */ getsockopt(*ps, SOL_SOCKET, SO_ERROR, (char *)&err, &len); /* we KNOW there was an error. if why is 0, we will return - * "unknown error", but it's not really our fault */ + * "unknown error", but it's not really our fault */ return err > 0? err: IO_UNKNOWN; - /* here we deal with the case in which it worked, timedout or weird errors */ } else return err; } diff --git a/test/dicttest.lua b/test/dicttest.lua index a37ec8d..7ac7811 100644 --- a/test/dicttest.lua +++ b/test/dicttest.lua @@ -1,3 +1,3 @@ local dict = require"socket.dict" -for i,v in dict.get("dict://dell-diego/d:banana") do print(v) end +for i,v in dict.get("dict://localhost/d:teste") do print(v) end diff --git a/test/httptest.lua b/test/httptest.lua index 8862ceb..2335fcb 100644 --- a/test/httptest.lua +++ b/test/httptest.lua @@ -23,7 +23,7 @@ http.TIMEOUT = 10 local t = socket.gettime() host = host or "diego.student.princeton.edu" -proxy = proxy or "http://dell-diego:3128" +proxy = proxy or "http://localhost:3128" prefix = prefix or "/luasocket-test" cgiprefix = cgiprefix or "/luasocket-test-cgi" index_file = "test/index.html" diff --git a/test/testclnt.lua b/test/testclnt.lua index c2c782c..e3f2b94 100644 --- a/test/testclnt.lua +++ b/test/testclnt.lua @@ -465,16 +465,14 @@ print("Testing " .. 2*size .. " bytes") remote(string.format([[ data:send(string.rep("a", %d)) socket.sleep(0.5) - data:send(string.rep("b", %d)) + data:send(string.rep("b", %d) .. "\n") ]], size, size)) local err = "timeout" local part = "" local str data:settimeout(0) while 1 do - local needed = 2*size - string.len(part) - assert(needed > 0, "weird") - str, err, part = data:receive(needed, part) + str, err, part = data:receive("*l", part) if err ~= "timeout" then break end end assert(str == (string.rep("a", size) .. string.rep("b", size))) @@ -482,15 +480,14 @@ remote(string.format([[ remote(string.format([[ str = data:receive(%d) socket.sleep(0.5) - str = data:receive(%d, str) + str = data:receive(2*%d, str) data:send(str) ]], size, size)) data:settimeout(0) - local sofar = 1 + local start = 0 while 1 do - _, err, part = data:send(str, sofar) + ret, err, start = data:send(str, start+1) if err ~= "timeout" then break end - sofar = sofar + part end data:send("\n") data:settimeout(-1) @@ -501,6 +498,7 @@ end ------------------------------------------------------------------------ + test("method registration") test_methods(socket.tcp(), { "accept", @@ -622,7 +620,7 @@ test_nonblocking(17) test_nonblocking(200) test_nonblocking(4091) test_nonblocking(80199) -test_nonblocking(8000000) +test_nonblocking(800000) test_nonblocking(80199) test_nonblocking(4091) test_nonblocking(200) diff --git a/test/testsrvr.lua b/test/testsrvr.lua index 2408e83..f1972c2 100644 --- a/test/testsrvr.lua +++ b/test/testsrvr.lua @@ -9,6 +9,7 @@ while 1 do while 1 do command = assert(control:receive()); assert(control:send(ack)); + print(command); (loadstring(command))(); end end diff --git a/test/utestclnt.lua b/test/utestclnt.lua new file mode 100644 index 0000000..f002c6e --- /dev/null +++ b/test/utestclnt.lua @@ -0,0 +1,644 @@ +require"socket" +local socket = require"socket.unix" + +host = "luasocket" + +function pass(...) + local s = string.format(unpack(arg)) + io.stderr:write(s, "\n") +end + +function fail(...) + local s = string.format(unpack(arg)) + io.stderr:write("ERROR: ", s, "!\n") +socket.sleep(3) + os.exit() +end + +function warn(...) + local s = string.format(unpack(arg)) + io.stderr:write("WARNING: ", s, "\n") +end + +function remote(...) + local s = string.format(unpack(arg)) + s = string.gsub(s, "\n", ";") + s = string.gsub(s, "%s+", " ") + s = string.gsub(s, "^%s*", "") + control:send(s .. "\n") + control:receive() +end + +function test(test) + io.stderr:write("----------------------------------------------\n", + "testing: ", test, "\n", + "----------------------------------------------\n") +end + +function uconnect(path) + local u = assert(socket.unix()) + assert(u:connect(path)) + return u +end + +function ubind(path) + local u = assert(socket.unix()) + assert(u:bind(path)) + assert(u:listen(5)) + return u +end + +function check_timeout(tm, sl, elapsed, err, opp, mode, alldone) + if tm < sl then + if opp == "send" then + if not err then warn("must be buffered") + elseif err == "timeout" then pass("proper timeout") + else fail("unexpected error '%s'", err) end + else + if err ~= "timeout" then fail("should have timed out") + else pass("proper timeout") end + end + else + if mode == "total" then + if elapsed > tm then + if err ~= "timeout" then fail("should have timed out") + else pass("proper timeout") end + elseif elapsed < tm then + if err then fail(err) + else pass("ok") end + else + if alldone then + if err then fail("unexpected error '%s'", err) + else pass("ok") end + else + if err ~= "timeout" then fail(err) + else pass("proper timeoutk") end + end + end + else + if err then fail(err) + else pass("ok") end + end + end +end + +if not socket.DEBUG then + fail("Please define LUASOCKET_DEBUG and recompile LuaSocket") +end + +io.stderr:write("----------------------------------------------\n", +"LuaSocket Test Procedures\n", +"----------------------------------------------\n") + +start = socket.gettime() + +function reconnect() + io.stderr:write("attempting data connection... ") + if data then data:close() end + remote [[ + i = i or 1 + if data then data:close() data = nil end + print("accepting") + data = server:accept() + i = i + 1 + print("done " .. i) + ]] + data, err = uconnect(host, port) + if not data then fail(err) + else pass("connected!") end +end + +pass("attempting control connection...") +control, err = uconnect(host, port) +if err then fail(err) +else pass("connected!") end + +------------------------------------------------------------------------ +function test_methods(sock, methods) + for _, v in methods do + if type(sock[v]) ~= "function" then + fail(sock.class .. " method '" .. v .. "' not registered") + end + end + pass(sock.class .. " methods are ok") +end + +------------------------------------------------------------------------ +function test_mixed(len) + reconnect() + local inter = math.ceil(len/4) + local p1 = "unix " .. string.rep("x", inter) .. "line\n" + local p2 = "dos " .. string.rep("y", inter) .. "line\r\n" + local p3 = "raw " .. string.rep("z", inter) .. "bytes" + local p4 = "end" .. string.rep("w", inter) .. "bytes" + local bp1, bp2, bp3, bp4 +remote (string.format("str = data:receive(%d)", + string.len(p1)+string.len(p2)+string.len(p3)+string.len(p4))) + sent, err = data:send(p1..p2..p3..p4) + if err then fail(err) end +remote "data:send(str); data:close()" + bp1, err = data:receive() + if err then fail(err) end + bp2, err = data:receive() + if err then fail(err) end + bp3, err = data:receive(string.len(p3)) + if err then fail(err) end + bp4, err = data:receive("*a") + if err then fail(err) end + if bp1.."\n" == p1 and bp2.."\r\n" == p2 and bp3 == p3 and bp4 == p4 then + pass("patterns match") + else fail("patterns don't match") end +end + +------------------------------------------------------------------------ +function test_asciiline(len) + reconnect() + local str, str10, back, err + str = string.rep("x", math.mod(len, 10)) + str10 = string.rep("aZb.c#dAe?", math.floor(len/10)) + str = str .. str10 +remote "str = data:receive()" + sent, err = data:send(str.."\n") + if err then fail(err) end +remote "data:send(str ..'\\n')" + back, err = data:receive() + if err then fail(err) end + if back == str then pass("lines match") + else fail("lines don't match") end +end + +------------------------------------------------------------------------ +function test_rawline(len) + reconnect() + local str, str10, back, err + str = string.rep(string.char(47), math.mod(len, 10)) + str10 = string.rep(string.char(120,21,77,4,5,0,7,36,44,100), + math.floor(len/10)) + str = str .. str10 +remote "str = data:receive()" + sent, err = data:send(str.."\n") + if err then fail(err) end +remote "data:send(str..'\\n')" + back, err = data:receive() + if err then fail(err) end + if back == str then pass("lines match") + else fail("lines don't match") end +end + +------------------------------------------------------------------------ +function test_raw(len) + reconnect() + local half = math.floor(len/2) + local s1, s2, back, err + s1 = string.rep("x", half) + s2 = string.rep("y", len-half) +remote (string.format("str = data:receive(%d)", len)) + sent, err = data:send(s1) + if err then fail(err) end + sent, err = data:send(s2) + if err then fail(err) end +remote "data:send(str)" + back, err = data:receive(len) + if err then fail(err) end + if back == s1..s2 then pass("blocks match") + else fail("blocks don't match") end +end + +------------------------------------------------------------------------ +function test_totaltimeoutreceive(len, tm, sl) + reconnect() + local str, err, partial + pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl) + remote (string.format ([[ + data:settimeout(%d) + str = string.rep('a', %d) + data:send(str) + print('server: sleeping for %ds') + socket.sleep(%d) + print('server: woke up') + data:send(str) + ]], 2*tm, len, sl, sl)) + data:settimeout(tm, "total") +local t = socket.gettime() + str, err, partial, elapsed = data:receive(2*len) + check_timeout(tm, sl, elapsed, err, "receive", "total", + string.len(str or partial) == 2*len) +end + +------------------------------------------------------------------------ +function test_totaltimeoutsend(len, tm, sl) + reconnect() + local str, err, total + pass("%d bytes, %ds total timeout, %ds pause", len, tm, sl) + remote (string.format ([[ + data:settimeout(%d) + str = data:receive(%d) + print('server: sleeping for %ds') + socket.sleep(%d) + print('server: woke up') + str = data:receive(%d) + ]], 2*tm, len, sl, sl, len)) + data:settimeout(tm, "total") + str = string.rep("a", 2*len) + total, err, partial, elapsed = data:send(str) + check_timeout(tm, sl, elapsed, err, "send", "total", + total == 2*len) +end + +------------------------------------------------------------------------ +function test_blockingtimeoutreceive(len, tm, sl) + reconnect() + local str, err, partial + pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) + remote (string.format ([[ + data:settimeout(%d) + str = string.rep('a', %d) + data:send(str) + print('server: sleeping for %ds') + socket.sleep(%d) + print('server: woke up') + data:send(str) + ]], 2*tm, len, sl, sl)) + data:settimeout(tm) + str, err, partial, elapsed = data:receive(2*len) + check_timeout(tm, sl, elapsed, err, "receive", "blocking", + string.len(str or partial) == 2*len) +end + +------------------------------------------------------------------------ +function test_blockingtimeoutsend(len, tm, sl) + reconnect() + local str, err, total + pass("%d bytes, %ds blocking timeout, %ds pause", len, tm, sl) + remote (string.format ([[ + data:settimeout(%d) + str = data:receive(%d) + print('server: sleeping for %ds') + socket.sleep(%d) + print('server: woke up') + str = data:receive(%d) + ]], 2*tm, len, sl, sl, len)) + data:settimeout(tm) + str = string.rep("a", 2*len) + total, err, partial, elapsed = data:send(str) + check_timeout(tm, sl, elapsed, err, "send", "blocking", + total == 2*len) +end + +------------------------------------------------------------------------ +function empty_connect() + reconnect() + if data then data:close() data = nil end + remote [[ + if data then data:close() data = nil end + data = server:accept() + ]] + data, err = socket.connect("", port) + if not data then + pass("ok") + data = socket.connect(host, port) + else + pass("gethostbyname returns localhost on empty string...") + end +end + +------------------------------------------------------------------------ +function isclosed(c) + return c:getfd() == -1 or c:getfd() == (2^32-1) +end + +function active_close() + reconnect() + if isclosed(data) then fail("should not be closed") end + data:close() + if not isclosed(data) then fail("should be closed") end + data = nil + local udp = socket.udp() + if isclosed(udp) then fail("should not be closed") end + udp:close() + if not isclosed(udp) then fail("should be closed") end + pass("ok") +end + +------------------------------------------------------------------------ +function test_closed() + local back, partial, err + local str = 'little string' + reconnect() + pass("trying read detection") + remote (string.format ([[ + data:send('%s') + data:close() + data = nil + ]], str)) + -- try to get a line + back, err, partial = data:receive() + if not err then fail("should have gotten 'closed'.") + elseif err ~= "closed" then fail("got '"..err.."' instead of 'closed'.") + elseif str ~= partial then fail("didn't receive partial result.") + else pass("graceful 'closed' received") end + reconnect() + pass("trying write detection") + remote [[ + data:close() + data = nil + ]] + total, err, partial = data:send(string.rep("ugauga", 100000)) + if not err then + pass("failed: output buffer is at least %d bytes long!", total) + elseif err ~= "closed" then + fail("got '"..err.."' instead of 'closed'.") + else + pass("graceful 'closed' received after %d bytes were sent", partial) + end +end + +------------------------------------------------------------------------ +function test_selectbugs() + local r, s, e = socket.select(nil, nil, 0.1) + assert(type(r) == "table" and type(s) == "table" and + (e == "timeout" or e == "error")) + pass("both nil: ok") + local udp = socket.udp() + udp:close() + r, s, e = socket.select({ udp }, { udp }, 0.1) + assert(type(r) == "table" and type(s) == "table" and + (e == "timeout" or e == "error")) + pass("closed sockets: ok") + e = pcall(socket.select, "wrong", 1, 0.1) + assert(e == false) + e = pcall(socket.select, {}, 1, 0.1) + assert(e == false) + pass("invalid input: ok") +end + +------------------------------------------------------------------------ +function accept_timeout() + io.stderr:write("accept with timeout (if it hangs, it failed): ") + local s, e = socket.bind("*", 0, 0) + assert(s, e) + local t = socket.gettime() + s:settimeout(1) + local c, e = s:accept() + assert(not c, "should not accept") + assert(e == "timeout", string.format("wrong error message (%s)", e)) + t = socket.gettime() - t + assert(t < 2, string.format("took to long to give up (%gs)", t)) + s:close() + pass("good") +end + +------------------------------------------------------------------------ +function connect_timeout() + io.stderr:write("connect with timeout (if it hangs, it failed!): ") + local t = socket.gettime() + local c, e = socket.tcp() + assert(c, e) + c:settimeout(0.1) + local t = socket.gettime() + local r, e = c:connect("127.0.0.2", 80) + assert(not r, "should not connect") + assert(socket.gettime() - t < 2, "took too long to give up.") + c:close() + print("ok") +end + +------------------------------------------------------------------------ +function accept_errors() + io.stderr:write("not listening: ") + local d, e = socket.bind("*", 0) + assert(d, e); + local c, e = socket.tcp(); + assert(c, e); + d:setfd(c:getfd()) + d:settimeout(2) + local r, e = d:accept() + assert(not r and e) + print("ok: ", e) + io.stderr:write("not supported: ") + local c, e = socket.udp() + assert(c, e); + d:setfd(c:getfd()) + local r, e = d:accept() + assert(not r and e) + print("ok: ", e) +end + +------------------------------------------------------------------------ +function connect_errors() + io.stderr:write("connection refused: ") + local c, e = socket.connect("localhost", 1); + assert(not c and e) + print("ok: ", e) + io.stderr:write("host not found: ") + local c, e = socket.connect("host.is.invalid", 1); + assert(not c and e, e) + print("ok: ", e) +end + +------------------------------------------------------------------------ +function rebind_test() + local c = socket.bind("localhost", 0) + local i, p = c:getsockname() + local s, e = socket.tcp() + assert(s, e) + s:setoption("reuseaddr", false) + r, e = s:bind("localhost", p) + assert(not r, "managed to rebind!") + assert(e) + print("ok: ", e) +end + +------------------------------------------------------------------------ +function getstats_test() + reconnect() + local t = 0 + for i = 1, 25 do + local c = math.random(1, 100) + remote (string.format ([[ + str = data:receive(%d) + data:send(str) + ]], c)) + data:send(string.rep("a", c)) + data:receive(c) + t = t + c + local r, s, a = data:getstats() + assert(r == t, "received count failed" .. tostring(r) + .. "/" .. tostring(t)) + assert(s == t, "sent count failed" .. tostring(s) + .. "/" .. tostring(t)) + end + print("ok") +end + + +------------------------------------------------------------------------ +function test_nonblocking(size) + reconnect() +print("Testing " .. 2*size .. " bytes") +remote(string.format([[ + data:send(string.rep("a", %d)) + socket.sleep(0.5) + data:send(string.rep("b", %d) .. "\n") +]], size, size)) + local err = "timeout" + local part = "" + local str + data:settimeout(0) + while 1 do + str, err, part = data:receive("*l", part) + if err ~= "timeout" then break end + end + assert(str == (string.rep("a", size) .. string.rep("b", size))) + reconnect() +remote(string.format([[ + str = data:receive(%d) + socket.sleep(0.5) + str = data:receive(%d, str) + data:send(str) +]], size, size)) + data:settimeout(0) + local start = 0 + while 1 do + ret, err, start = data:send(str, start+1) + if err ~= "timeout" then break end + end + data:send("\n") + data:settimeout(-1) + local back = data:receive(2*size) + assert(back == str, "'" .. back .. "' vs '" .. str .. "'") + print("ok") +end + +------------------------------------------------------------------------ + +test("method registration") +test_methods(socket.unix(), { + "accept", + "bind", + "close", + "connect", + "dirty", + "getfd", + "getstats", + "setstats", + "listen", + "receive", + "send", + "setfd", + "setoption", + "setpeername", + "setsockname", + "settimeout", + "shutdown", +}) + +test("connect function") +--connect_timeout() +--empty_connect() +--connect_errors() + +--test("rebinding: ") +--rebind_test() + +test("active close: ") +active_close() + +test("closed connection detection: ") +test_closed() + +test("accept function: ") +accept_timeout() +accept_errors() + +test("getstats test") +getstats_test() + +test("character line") +test_asciiline(1) +test_asciiline(17) +test_asciiline(200) +test_asciiline(4091) +test_asciiline(80199) +test_asciiline(8000000) +test_asciiline(80199) +test_asciiline(4091) +test_asciiline(200) +test_asciiline(17) +test_asciiline(1) + +test("mixed patterns") +test_mixed(1) +test_mixed(17) +test_mixed(200) +test_mixed(4091) +test_mixed(801990) +test_mixed(4091) +test_mixed(200) +test_mixed(17) +test_mixed(1) + +test("binary line") +test_rawline(1) +test_rawline(17) +test_rawline(200) +test_rawline(4091) +test_rawline(80199) +test_rawline(8000000) +test_rawline(80199) +test_rawline(4091) +test_rawline(200) +test_rawline(17) +test_rawline(1) + +test("raw transfer") +test_raw(1) +test_raw(17) +test_raw(200) +test_raw(4091) +test_raw(80199) +test_raw(8000000) +test_raw(80199) +test_raw(4091) +test_raw(200) +test_raw(17) +test_raw(1) + +test("non-blocking transfer") +test_nonblocking(1) +test_nonblocking(17) +test_nonblocking(200) +test_nonblocking(4091) +test_nonblocking(80199) +test_nonblocking(8000000) +test_nonblocking(80199) +test_nonblocking(4091) +test_nonblocking(200) +test_nonblocking(17) +test_nonblocking(1) + +test("total timeout on send") +test_totaltimeoutsend(800091, 1, 3) +test_totaltimeoutsend(800091, 2, 3) +test_totaltimeoutsend(800091, 5, 2) +test_totaltimeoutsend(800091, 3, 1) + +test("total timeout on receive") +test_totaltimeoutreceive(800091, 1, 3) +test_totaltimeoutreceive(800091, 2, 3) +test_totaltimeoutreceive(800091, 3, 2) +test_totaltimeoutreceive(800091, 3, 1) + +test("blocking timeout on send") +test_blockingtimeoutsend(800091, 1, 3) +test_blockingtimeoutsend(800091, 2, 3) +test_blockingtimeoutsend(800091, 3, 2) +test_blockingtimeoutsend(800091, 3, 1) + +test("blocking timeout on receive") +test_blockingtimeoutreceive(800091, 1, 3) +test_blockingtimeoutreceive(800091, 2, 3) +test_blockingtimeoutreceive(800091, 3, 2) +test_blockingtimeoutreceive(800091, 3, 1) + +test(string.format("done in %.2fs", socket.gettime() - start)) diff --git a/test/utestsrvr.lua b/test/utestsrvr.lua new file mode 100644 index 0000000..f7be196 --- /dev/null +++ b/test/utestsrvr.lua @@ -0,0 +1,17 @@ +require("socket"); +os.remove("/tmp/luasocket") +socket = require("socket.unix"); +host = "luasocket"; +server = socket.unix() +print(server:bind(host)) +print(server:listen(5)) +ack = "\n"; +while 1 do + print("server: waiting for client connection..."); + control = assert(server:accept()); + while 1 do + command = assert(control:receive()); + assert(control:send(ack)); + (loadstring(command))(); + end +end