From 156669c28bc62bfddbd5625c4bb4c1f8da94802b Mon Sep 17 00:00:00 2001 From: Sam Roberts Date: Thu, 10 May 2012 14:14:22 -0700 Subject: [PATCH] socket.connect now implemented in the C core This avoid socket.lua duplicating the iteration over the results of getaddrinfo(). Some problems with the C implementation not initializing sockets or the luasocket family have also been fixed, and error reporting made more robust. --- doc/reference.html | 2 ++ doc/socket.html | 11 +++++++--- src/inet.c | 54 +++++++++++++++++++++++++++++++++++----------- src/inet.h | 3 +++ src/socket.lua | 34 ++++++----------------------- src/tcp.c | 11 ++++++---- src/usocket.c | 2 +- 7 files changed, 69 insertions(+), 48 deletions(-) diff --git a/doc/reference.html b/doc/reference.html index f069d47..e9bb5eb 100644 --- a/doc/reference.html +++ b/doc/reference.html @@ -145,6 +145,8 @@ Support, Manual">
bind, connect, +connect4, +connect6, _DEBUG, dns, gettime, diff --git a/doc/socket.html b/doc/socket.html index dcf8b61..b9303cb 100644 --- a/doc/socket.html +++ b/doc/socket.html @@ -73,14 +73,19 @@ set to true.

-socket.connect(address, port [, locaddr, locport]) +socket.connect[46](address, port [, locaddr] [, locport] [, family])

This function is a shortcut that creates and returns a TCP client object -connected to a remote host at a given port. Optionally, +connected to a remote address at a given port. Optionally, the user can also specify the local address and port to bind -(locaddr and locport). +(locaddr and locport), or restrict the socket family +to "inet" or "inet6". +Without specifying family to connect, whether a tcp or tcp6 +connection is created depends on your system configuration. Two variations +of connect are defined as simple helper functions that restrict the +family, socket.connect4 and socket.connect6.

diff --git a/src/inet.c b/src/inet.c index 55e89d7..e769cd8 100644 --- a/src/inet.c +++ b/src/inet.c @@ -143,6 +143,22 @@ static int inet_global_toip(lua_State *L) return 2; } +int inet_optfamily(lua_State* L, int narg, const char* def) +{ + static const char* optname[] = { "unspec", "inet", "inet6", NULL }; + static int optvalue[] = { PF_UNSPEC, PF_INET, PF_INET6, 0 }; + + return optvalue[luaL_checkoption(L, narg, def, optname)]; +} + +int inet_optsocktype(lua_State* L, int narg, const char* def) +{ + static const char* optname[] = { "stream", "dgram", NULL }; + static int optvalue[] = { SOCK_STREAM, SOCK_DGRAM, 0 }; + + return optvalue[luaL_checkoption(L, narg, def, optname)]; +} + static int inet_global_getaddrinfo(lua_State *L) { const char *hostname = luaL_checkstring(L, 1); @@ -197,7 +213,7 @@ static int inet_global_gethostname(lua_State *L) name[256] = '\0'; if (gethostname(name, 256) < 0) { lua_pushnil(L); - lua_pushstring(L, "gethostname failed"); + lua_pushstring(L, socket_strerror(errno)); return 2; } else { lua_pushstring(L, name); @@ -222,7 +238,7 @@ int inet_meth_getpeername(lua_State *L, p_socket ps, int family) char name[INET_ADDRSTRLEN]; if (getpeername(*ps, (SA *) &peer, &peer_len) < 0) { lua_pushnil(L); - lua_pushstring(L, "getpeername failed"); + lua_pushstring(L, socket_strerror(errno)); return 2; } else { inet_ntop(family, &peer.sin_addr, name, sizeof(name)); @@ -238,7 +254,7 @@ int inet_meth_getpeername(lua_State *L, p_socket ps, int family) char name[INET6_ADDRSTRLEN]; if (getpeername(*ps, (SA *) &peer, &peer_len) < 0) { lua_pushnil(L); - lua_pushstring(L, "getpeername failed"); + lua_pushstring(L, socket_strerror(errno)); return 2; } else { inet_ntop(family, &peer.sin6_addr, name, sizeof(name)); @@ -251,7 +267,7 @@ int inet_meth_getpeername(lua_State *L, p_socket ps, int family) } default: lua_pushnil(L); - lua_pushstring(L, "unknown family"); + lua_pushfstring(L, "unknown family %d", family); return 2; } } @@ -268,7 +284,7 @@ int inet_meth_getsockname(lua_State *L, p_socket ps, int family) char name[INET_ADDRSTRLEN]; if (getsockname(*ps, (SA *) &local, &local_len) < 0) { lua_pushnil(L); - lua_pushstring(L, "getsockname failed"); + lua_pushstring(L, socket_strerror(errno)); return 2; } else { inet_ntop(family, &local.sin_addr, name, sizeof(name)); @@ -284,7 +300,7 @@ int inet_meth_getsockname(lua_State *L, p_socket ps, int family) char name[INET6_ADDRSTRLEN]; if (getsockname(*ps, (SA *) &local, &local_len) < 0) { lua_pushnil(L); - lua_pushstring(L, "getsockname failed"); + lua_pushstring(L, socket_strerror(errno)); return 2; } else { inet_ntop(family, &local.sin6_addr, name, sizeof(name)); @@ -296,7 +312,7 @@ int inet_meth_getsockname(lua_State *L, p_socket ps, int family) } default: lua_pushnil(L); - lua_pushstring(L, "unknown family"); + lua_pushfstring(L, "unknown family %d", family); return 2; } } @@ -390,6 +406,7 @@ const char *inet_trybind(p_socket ps, const char *address, const char *serv, { struct addrinfo *iterator = NULL, *resolved = NULL; const char *err = NULL; + t_socket sock = *ps; /* translate luasocket special values to C */ if (strcmp(address, "*") == 0) address = NULL; if (!serv) serv = "0"; @@ -402,17 +419,30 @@ const char *inet_trybind(p_socket ps, const char *address, const char *serv, } /* iterate over resolved addresses until one is good */ for (iterator = resolved; iterator; iterator = iterator->ai_next) { + if(sock == SOCKET_INVALID) { + err = socket_strerror( socket_create(&sock, iterator->ai_family, + iterator->ai_socktype, iterator->ai_protocol)); + if(err) + continue; + } /* try binding to local address */ - err = socket_strerror(socket_bind(ps, + err = socket_strerror(socket_bind(&sock, (SA *) iterator->ai_addr, iterator->ai_addrlen)); - /* if faiiled, we try the next one */ - if (err != NULL) socket_destroy(ps); - /* if success, we abort loop */ - else break; + + /* keep trying unless bind succeeded */ + if (err) { + if(sock != *ps) + socket_destroy(&sock); + } else { + /* remember what we connected to, particularly the family */ + *bindhints = *iterator; + break; + } } /* cleanup and return error */ freeaddrinfo(resolved); + *ps = sock; return err; } diff --git a/src/inet.h b/src/inet.h index 1cbe4f4..05633bb 100644 --- a/src/inet.h +++ b/src/inet.h @@ -33,6 +33,9 @@ const char *inet_trybind(p_socket ps, const char *address, const char *serv, int inet_meth_getpeername(lua_State *L, p_socket ps, int family); int inet_meth_getsockname(lua_State *L, p_socket ps, int family); +int inet_optfamily(lua_State* L, int narg, const char* def); +int inet_optsocktype(lua_State* L, int narg, const char* def); + #ifdef INET_ATON int inet_aton(const char *cp, struct in_addr *inp); #endif diff --git a/src/socket.lua b/src/socket.lua index 8c5f231..e8def75 100644 --- a/src/socket.lua +++ b/src/socket.lua @@ -15,34 +15,12 @@ module("socket") ----------------------------------------------------------------------------- -- Exported auxiliar functions ----------------------------------------------------------------------------- -function connect(address, port, laddress, lport) - if address == "*" then address = "0.0.0.0" end - local addrinfo, err = socket.dns.getaddrinfo(address); - if not addrinfo then return nil, err end - local sock, res - err = "no info on address" - for i, alt in base.ipairs(addrinfo) do - if alt.family == "inet" then - sock, err = socket.tcp() - else - sock, err = socket.tcp6() - end - if not sock then return nil, err end - if laddress then - res, err = sock:bind(laddress, lport) - if not res then - sock:close() - return nil, err - end - end - res, err = sock:connect(alt.addr, port) - if not res then - sock:close() - else - return sock - end - end - return nil, err +function connect4(address, port, laddress, lport) + return socket.connect(address, port, laddress, lport, "inet") +end + +function connect6(address, port, laddress, lport) + return socket.connect(address, port, laddress, lport, "inet6") end function bind(host, port, backlog) diff --git a/src/tcp.c b/src/tcp.c index 94148c5..3a7f527 100644 --- a/src/tcp.c +++ b/src/tcp.c @@ -18,7 +18,7 @@ \*=========================================================================*/ static int global_create(lua_State *L); static int global_create6(lua_State *L); -static int global_connect6(lua_State *L); +static int global_connect(lua_State *L); static int meth_connect(lua_State *L); static int meth_listen(lua_State *L); static int meth_getfamily(lua_State *L); @@ -89,7 +89,7 @@ static t_opt optset[] = { static luaL_Reg func[] = { {"tcp", global_create}, {"tcp6", global_create6}, - {"connect6", global_connect6}, + {"connect", global_connect}, {NULL, NULL} }; @@ -408,6 +408,7 @@ static const char *tryconnect6(const char *remoteaddr, const char *remoteserv, freeaddrinfo(resolved); return err; } + tcp->family = iterator->ai_family; /* all sockets initially non-blocking */ socket_setnonblocking(&tcp->sock); } @@ -424,11 +425,12 @@ static const char *tryconnect6(const char *remoteaddr, const char *remoteserv, return err; } -static int global_connect6(lua_State *L) { +static int global_connect(lua_State *L) { const char *remoteaddr = luaL_checkstring(L, 1); const char *remoteserv = luaL_checkstring(L, 2); const char *localaddr = luaL_optstring(L, 3, NULL); const char *localserv = luaL_optstring(L, 4, "0"); + int family = inet_optfamily(L, 5, "unspec"); p_tcp tcp = (p_tcp) lua_newuserdata(L, sizeof(t_tcp)); struct addrinfo bindhints, connecthints; const char *err = NULL; @@ -441,7 +443,7 @@ static int global_connect6(lua_State *L) { /* allow user to pick local address and port */ memset(&bindhints, 0, sizeof(bindhints)); bindhints.ai_socktype = SOCK_STREAM; - bindhints.ai_family = PF_UNSPEC; + bindhints.ai_family = family; bindhints.ai_flags = AI_PASSIVE; if (localaddr) { err = inet_trybind(&tcp->sock, localaddr, localserv, &bindhints); @@ -450,6 +452,7 @@ static int global_connect6(lua_State *L) { lua_pushstring(L, err); return 2; } + tcp->family = bindhints.ai_family; } /* try to connect to remote address and port */ memset(&connecthints, 0, sizeof(connecthints)); diff --git a/src/usocket.c b/src/usocket.c index a168bf6..7150996 100644 --- a/src/usocket.c +++ b/src/usocket.c @@ -447,7 +447,7 @@ const char *socket_gaistrerror(int err) { case EAI_SERVICE: return "service not supported for socket type"; case EAI_SOCKTYPE: return "ai_socktype not supported"; case EAI_SYSTEM: return strerror(errno); - default: return "unknown error"; + default: return gai_strerror(err); } }