diff --git a/doc/curl.ldoc b/doc/curl.ldoc index cd54ebc..a915d4c 100644 --- a/doc/curl.ldoc +++ b/doc/curl.ldoc @@ -135,6 +135,17 @@ do -- e:perfom{writefunction = assert(io.open("fname.txt", "w+b"))} function perfom() end +--- User data. +-- +-- Please use this field to associate any data with curl handle. +-- +-- @field data +-- +-- @usage +-- f = io.open("lua.org.download", "w+") +-- e = curl.easy{url = "http://lua.org", writefunction = f} +-- e.data = f + end --- Muli curl object @@ -155,4 +166,10 @@ do -- for data, type, easy in m:iperform() do ... end function iperform() end +--- User data. +-- +-- Please use this field to associate any data with curl handle. +-- +-- @field data + end diff --git a/examples/cURLv3/multi-uv.lua b/examples/cURLv3/multi-uv.lua index 223b956..4638765 100644 --- a/examples/cURLv3/multi-uv.lua +++ b/examples/cURLv3/multi-uv.lua @@ -27,9 +27,11 @@ local FLAGS = { } -local trace = function() end or print +local trace = true -local FILES, CONTEXT = {}, {} +trace = trace and print or function() end + +local CONTEXT = {} function create_curl_context(sockfd) local context = { @@ -58,7 +60,7 @@ function add_download(url, num) writefunction = file; } - FILES[handle] = file + handle.data = file curl_handle:add_handle(handle) fprintf(stderr, "Added download %s -> %s\n", url, filename); @@ -66,14 +68,14 @@ end function check_multi_info() while true do - local easy, ok, err = curl_handle:info_read() + local easy, ok, err = curl_handle:info_read(true) if not easy then curl_handle:close() error(err) end if easy == 0 then break end local context = CONTEXT[e] if context then destroy_curl_context(context) end - local file = FILES[easy] - if file then FILES[easy] = nil, file:close() end + local file = assert(easy.data) + file:close() local done_url = easy:getinfo_effective_url() easy:close() if ok then @@ -116,31 +118,28 @@ function start_timeout(timeout_ms) timeout:stop():start(timeout_ms, 0, on_timeout) end -local handle_socket = function(...) - local ok, err = pcall(handle_socket_impl, ...) +function handle_socket(easy, s, action) + local ok, err = pcall(function() + -- calls by curl -- + trace("CURL::SOCKET", easy, s, ACTION_NAMES[action] or action) + + local curl_context = CONTEXT[easy] or create_curl_context(s) + CONTEXT[easy] = curl_context + + assert(curl_context.sockfd == s) + + if action == curl.POLL_IN then + curl_context.poll_handle:start(uv.READABLE, curl_perform) + elseif action == curl.POLL_OUT then + curl_context.poll_handle:start(uv.WRITABLE, curl_perform) + elseif action == curl.POLL_REMOVE then + CONTEXT[easy] = nil + destroy_curl_context(curl_context) + end + end) if not ok then uv.defer(function() error(err) end) end end -function handle_socket_impl(easy, s, action) - -- calls by curl -- - - trace("CURL::SOCKET", easy, s, ACTION_NAMES[action] or action) - - local curl_context = CONTEXT[easy] or create_curl_context(s) - CONTEXT[easy] = curl_context - - assert(curl_context.sockfd == s) - - if action == curl.POLL_IN then - curl_context.poll_handle:start(uv.READABLE, curl_perform) - elseif action == curl.POLL_OUT then - curl_context.poll_handle:start(uv.WRITABLE, curl_perform) - elseif action == curl.POLL_REMOVE then - CONTEXT[easy] = nil - destroy_curl_context(curl_context) - end -end - timeout = uv.timer() curl_handle = curl.multi{ diff --git a/src/lceasy.c b/src/lceasy.c index 48a35f4..2819457 100644 --- a/src/lceasy.c +++ b/src/lceasy.c @@ -100,6 +100,10 @@ static int lcurl_easy_cleanup(lua_State *L){ p->lists[i] = LUA_NOREF; } + lua_settop(L, 1); + lua_pushnil(L); + lua_rawset(L, LCURL_USERVALUES); + return 0; } @@ -941,6 +945,22 @@ static int lcurl_easy_pause(lua_State *L){ return 1; } +static int lcurl_easy_setdata(lua_State *L){ + lcurl_easy_t *p = lcurl_geteasy(L); + lua_settop(L, 2); + lua_pushvalue(L, 1); + lua_insert(L, 2); + lua_rawset(L, LCURL_USERVALUES); + return 1; +} + +static int lcurl_easy_getdata(lua_State *L){ + lcurl_easy_t *p = lcurl_geteasy(L); + lua_settop(L, 1); + lua_rawget(L, LCURL_USERVALUES); + return 1; +} + //} static const struct luaL_Reg lcurl_easy_methods[] = { @@ -982,6 +1002,9 @@ static const struct luaL_Reg lcurl_easy_methods[] = { { "close", lcurl_easy_cleanup }, { "__gc", lcurl_easy_cleanup }, + { "setdata", lcurl_easy_setdata }, + { "getdata", lcurl_easy_getdata }, + {NULL,NULL} }; diff --git a/src/lcmulti.c b/src/lcmulti.c index 7d3fd52..588507e 100644 --- a/src/lcmulti.c +++ b/src/lcmulti.c @@ -78,6 +78,11 @@ static int lcurl_multi_cleanup(lua_State *L){ luaL_unref(L, LCURL_LUA_REGISTRY, p->sc.ud_ref); p->tm.cb_ref = p->tm.ud_ref = LUA_NOREF; p->sc.cb_ref = p->sc.ud_ref = LUA_NOREF; + + lua_settop(L, 1); + lua_pushnil(L); + lua_rawset(L, LCURL_USERVALUES); + return 0; } @@ -441,6 +446,22 @@ static int lcurl_multi_setopt(lua_State *L){ return lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, CURLM_UNKNOWN_OPTION); } +static int lcurl_multi_setdata(lua_State *L){ + lcurl_multi_t *p = lcurl_getmulti(L); + lua_settop(L, 2); + lua_pushvalue(L, 1); + lua_insert(L, 2); + lua_rawset(L, LCURL_USERVALUES); + return 1; +} + +static int lcurl_multi_getdata(lua_State *L){ + lcurl_multi_t *p = lcurl_getmulti(L); + lua_settop(L, 1); + lua_rawget(L, LCURL_USERVALUES); + return 1; +} + //} static const struct luaL_Reg lcurl_multi_methods[] = { @@ -459,6 +480,9 @@ static const struct luaL_Reg lcurl_multi_methods[] = { OPT_ENTRY(socketfunction, SOCKETFUNCTION, TTT, 0) #undef OPT_ENTRY + { "setdata", lcurl_multi_setdata }, + { "getdata", lcurl_multi_getdata }, + {"close", lcurl_multi_cleanup }, {"__gc", lcurl_multi_cleanup }, diff --git a/src/lcurl.c b/src/lcurl.c index 2ca2e12..ff6d986 100644 --- a/src/lcurl.c +++ b/src/lcurl.c @@ -180,6 +180,7 @@ static const lcurl_const_t lcurl_flags[] = { static volatile int LCURL_INIT = 0; static const char* LCURL_REGISTRY = "LCURL Registry"; +static const char* LCURL_USERVAL = "LCURL Uservalues"; static int luaopen_lcurl_(lua_State *L, const struct luaL_Reg *func){ if(!LCURL_INIT){ @@ -192,16 +193,24 @@ static int luaopen_lcurl_(lua_State *L, const struct luaL_Reg *func){ lua_pop(L, 1); lua_newtable(L); } + + lua_rawgetp(L, LUA_REGISTRYINDEX, LCURL_USERVAL); + if(!lua_istable(L, -1)){ /* usevalues */ + lua_pop(L, 1); + lcurl_util_new_weak_table(L, "k"); + } + lua_newtable(L); /* library */ - lua_pushvalue(L, -2); luaL_setfuncs(L, func, 1); - lua_pushvalue(L, -2); lcurl_error_initlib(L, 1); - lua_pushvalue(L, -2); lcurl_hpost_initlib(L, 1); - lua_pushvalue(L, -2); lcurl_easy_initlib (L, 1); - lua_pushvalue(L, -2); lcurl_multi_initlib(L, 1); - lua_pushvalue(L, -2); lcurl_share_initlib(L, 1); + lua_pushvalue(L, -3); lua_pushvalue(L, -3); luaL_setfuncs(L, func, 2); + lua_pushvalue(L, -3); lua_pushvalue(L, -3); lcurl_error_initlib(L, 2); + lua_pushvalue(L, -3); lua_pushvalue(L, -3); lcurl_hpost_initlib(L, 2); + lua_pushvalue(L, -3); lua_pushvalue(L, -3); lcurl_easy_initlib (L, 2); + lua_pushvalue(L, -3); lua_pushvalue(L, -3); lcurl_multi_initlib(L, 2); + lua_pushvalue(L, -3); lua_pushvalue(L, -3); lcurl_share_initlib(L, 2); - lua_pushvalue(L, -2); lua_rawsetp(L, LUA_REGISTRYINDEX, LCURL_REGISTRY); + lua_pushvalue(L, -3); lua_rawsetp(L, LUA_REGISTRYINDEX, LCURL_REGISTRY); + lua_pushvalue(L, -2); lua_rawsetp(L, LUA_REGISTRYINDEX, LCURL_USERVAL); lua_remove(L, -2); /* registry */ diff --git a/src/lcurl.h b/src/lcurl.h index 1aa7478..b9bd48c 100644 --- a/src/lcurl.h +++ b/src/lcurl.h @@ -23,5 +23,6 @@ #define LCURL_LUA_REGISTRY lua_upvalueindex(1) +#define LCURL_USERVALUES lua_upvalueindex(2) #endif diff --git a/src/lua/cURL/impl/cURL.lua b/src/lua/cURL/impl/cURL.lua index a944668..403aa2b 100644 --- a/src/lua/cURL/impl/cURL.lua +++ b/src/lua/cURL/impl/cURL.lua @@ -99,9 +99,8 @@ local function make_iterator(self, perform) if n <= remain then while true do - local h, ok, err = assert(self:info_read()) - if h == 0 then break end - local e = assert(self._easy[h]) + local e, ok, err = assert(self:info_read()) + if e == 0 then break end if ok then ok = e:getinfo_response_code() or ok buffers:append(e, "done", ok) @@ -423,6 +422,22 @@ function Multi:remove_handle(e) return remove_handle(self, h) end +function Multi:info_read(...) + while true do + local h, ok, err = self:handle():info_read(...) + if not h then return nil, ok end + if h == 0 then return h end + + local e = self._easy[h] + if e then + if ... then + self._easy[h], self._easy.n = nil, self._easy.n - 1 + end + return e, ok, err + end + end +end + end ------------------------------------------- @@ -556,14 +571,74 @@ function Multi:remove_handle(e) end function Multi:info_read(...) - local h, ok, err = self:handle():info_read(...) - if not h then return nil, ok end - if h == 0 then return h end + while true do + local h, ok, err = self:handle():info_read(...) + if not h then return nil, ok end + if h == 0 then return h end - if ... and self._easy[h] then - self._easy[h], self._easy.n = nil, self._easy.n - 1 + local e = self._easy[h] + if e then + if ... then + self._easy[h], self._easy.n = nil, self._easy.n - 1 + end + return e, ok, err + end end - return h, ok, err +end + +function wrap_callback(...) + local n = select("#", ...) + local fn, ctx, has_ctx + if n >= 2 then + has_ctx, fn, ctx = true, assert(...) + else + fn = assert(...) + if type(fn) ~= "function" then + has_ctx, fn, ctx = true, assert(fn.socket), fn + end + end + if has_ctx then + return function(...) return fn(ctx, ...) end + end + return function(...) return fn(...) end +end + +function wrap_socketfunction(self, cb) + return function(h, ...) + local e = self._easy[h] + if e then return cb(e, ...) end + return 0 + end +end + +local setopt_socketfunction = wrap_function("setopt_socketfunction") +function Multi:setopt_socketfunction(...) + local cb = wrap_callback(...) + + return setopt_socketfunction(wrap_socketfunction(self, cb)) +end + +local setopt = wrap_function("setopt") +function Multi:setopt(k, v) + if type(k) == 'table' then + local t = k + + local socketfunction = t.socketfunction or t[curl.OPT_SOCKETFUNCTION] + if socketfunction then + t = clone(t) + local fn = wrap_socketfunction(self, socketfunction) + if t.socketfunction then t.socketfunction = fn end + if t[curl.OPT_SOCKETFUNCTION] then t[curl.OPT_SOCKETFUNCTION] = fn end + end + + return setopt(self, t) + end + + if k == curl.OPT_SOCKETFUNCTION then + return self:setopt_httppost(wrap_socketfunction(v)) + end + + return setopt(self, k, v) end end diff --git a/test/test_curl.lua b/test/test_curl.lua index a8864ff..60fa3b6 100644 --- a/test/test_curl.lua +++ b/test/test_curl.lua @@ -93,6 +93,20 @@ function test_add_handle() assert_nil(c) end +function test_info_read() + local url = 'http://httpbin.org/get?key=1' + c = assert(curl.easy{url=url, writefunction=function() end}) + assert_equal(m, m:add_handle(c)) + + while m:perform() > 0 do m:wait() end + + local h, ok, err = m:info_read() + assert_equal(c, h) + + local h, ok, err = m:info_read() + assert_equal(0, h) +end + end local _ENV = TEST_CASE'form' if ENABLE then diff --git a/test/test_easy.lua b/test/test_easy.lua index a4dfea9..3ef22cf 100644 --- a/test/test_easy.lua +++ b/test/test_easy.lua @@ -796,4 +796,36 @@ end end +local _ENV = TEST_CASE'setopt_user_data' if ENABLE then + +local c + +function setup() + if c then c:close() end + c = nil +end + +function test_data() + c = assert(curl.easy()) + assert_nil(c:getdata()) + c:setdata("hello") + assert_equal("hello", c:getdata()) +end + +function test_cleanup() + local ptr do + local t = {} + local e = curl.easy():setdata(t) + ptr = weak_ptr(t) + gc_collect() + + assert_equal(t, ptr.value) + end + + gc_collect() + assert_nil(ptr.value) +end + +end + RUN()