diff --git a/.travis.yml b/.travis.yml index db03623..0967a9a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -45,6 +45,7 @@ script: - lua -e "print(require 'cURL.utils'.find_ca_bundle())" - lunit.sh run.lua - lua test_pause02.c.lua + - lua test_multi_callback.lua # - lunit.sh test_easy.lua # - lunit.sh test_safe.lua # - lunit.sh test_form.lua diff --git a/appveyor.yml b/appveyor.yml index e805dde..92a1427 100644 --- a/appveyor.yml +++ b/appveyor.yml @@ -56,6 +56,7 @@ test_script: - cd %APPVEYOR_BUILD_FOLDER%\test - lua run.lua - lua test_pause02.c.lua + - lua test_multi_callback.lua after_test: - cd %APPVEYOR_BUILD_FOLDER% diff --git a/examples/cURLv3/uvwget.lua b/examples/cURLv3/uvwget.lua index 4119ae8..8341e1d 100644 --- a/examples/cURLv3/uvwget.lua +++ b/examples/cURLv3/uvwget.lua @@ -182,7 +182,7 @@ end on_curl_action = function(easy, fd, action) local ok, err = pcall(function() - trace("CURL::SOCKET", easy, s, ACTION_NAMES[action] or action) + trace("CURL::SOCKET", easy, fd, ACTION_NAMES[action] or action) local context = easy.data.context diff --git a/src/lceasy.c b/src/lceasy.c index a0f89ad..82ef5ee 100644 --- a/src/lceasy.c +++ b/src/lceasy.c @@ -14,6 +14,7 @@ #include "lcutils.h" #include "lchttppost.h" #include "lcshare.h" +#include "lcmulti.h" #include static const char *LCURL_ERROR_TAG = "LCURL_ERROR_TAG"; @@ -45,6 +46,7 @@ int lcurl_easy_create(lua_State *L, int error_mode){ p->magic = LCURL_EASY_MAGIC; p->L = NULL; p->post = NULL; + p->multi = NULL; p->storage = lcurl_storage_init(L); p->wr.cb_ref = p->wr.ud_ref = LUA_NOREF; p->rd.cb_ref = p->rd.ud_ref = LUA_NOREF; @@ -75,7 +77,19 @@ static int lcurl_easy_cleanup(lua_State *L){ int i; if(p->curl){ + p->L = L; + if(p->post){ + p->post->L = L; + } + // In my tests when I cleanup some easy handle. + // timerfunction called only for single multi handle. + if(p->multi){ + p->multi->L = L; + } curl_easy_cleanup(p->curl); + if(p->multi){ + p->multi->L = NULL; + } p->curl = NULL; } diff --git a/src/lceasy.h b/src/lceasy.h index e64f2b4..1093a65 100644 --- a/src/lceasy.h +++ b/src/lceasy.h @@ -35,6 +35,8 @@ enum { #define LCURL_EASY_MAGIC 0xEA +typedef struct lcurl_multi_tag lcurl_multi_t; + typedef struct lcurl_easy_tag{ unsigned char magic; @@ -44,6 +46,8 @@ typedef struct lcurl_easy_tag{ lcurl_hpost_t *post; + lcurl_multi_t *multi; + CURL *curl; int storage; int lists[LCURL_LIST_COUNT]; diff --git a/src/lcmulti.c b/src/lcmulti.c index 7aee895..2213f33 100644 --- a/src/lcmulti.c +++ b/src/lcmulti.c @@ -28,18 +28,36 @@ #define LCURL_MULTI_NAME LCURL_PREFIX" Multi" static const char *LCURL_MULTI = LCURL_MULTI_NAME; +static void lcurl__multi_assign_lua(lua_State *L, lcurl_multi_t *p, int assign_easy){ + p->L = L; + + if(assign_easy){ + lua_rawgeti(L, LCURL_LUA_REGISTRY, p->h_ref); + lua_pushnil(L); + while(lua_next(L, -2)){ + lcurl_easy_t *e = lcurl_geteasy_at(L, -1); + e->L = L; + if(e->post){ + e->post->L = L; + } + lua_pop(L, 1); + } + lua_pop(L, 1); + } +} + //{ int lcurl_multi_create(lua_State *L, int error_mode){ lcurl_multi_t *p; - + lua_settop(L, 1); p = lutil_newudatap(L, lcurl_multi_t, LCURL_MULTI); p->curl = curl_multi_init(); p->err_mode = error_mode; if(!p->curl) return lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, CURLM_INTERNAL_ERROR); - p->L = L; + p->L = NULL; lcurl_util_new_weak_table(L, "v"); p->h_ref = luaL_ref(L, LCURL_LUA_REGISTRY); p->tm.cb_ref = p->tm.ud_ref = LUA_NOREF; @@ -89,24 +107,65 @@ static int lcurl_multi_cleanup(lua_State *L){ static int lcurl_multi_add_handle(lua_State *L){ lcurl_multi_t *p = lcurl_getmulti(L); lcurl_easy_t *e = lcurl_geteasy_at(L, 2); - CURLMcode code = curl_multi_add_handle(p->curl, e->curl); - if(code != CURLM_OK){ - lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, code); + CURLMcode code; + + if(e->multi){ + return lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, CURLM_ADDED_ALREADY); } + + // From doc: + // If you have CURLMOPT_TIMERFUNCTION set in the multi handle, + // that callback will be called from within this function to ask + // for an updated timer so that your main event loop will get + // the activity on this handle to get started. + // + // So we should add easy before this call + // call chain may be like => timerfunction->socket_action->socketfunction + lua_settop(L, 2); lua_rawgeti(L, LCURL_LUA_REGISTRY, p->h_ref); lua_pushvalue(L, 2); lua_rawsetp(L, -2, e->curl); lua_settop(L, 1); + + e->multi = p; + + lcurl__multi_assign_lua(L, p, 0); + code = curl_multi_add_handle(p->curl, e->curl); + p->L = NULL; + + if(code != CURLM_OK){ + // remove + lua_rawgeti(L, LCURL_LUA_REGISTRY, p->h_ref); + lua_pushnil(L); + lua_rawsetp(L, -2, e->curl); + e->multi = NULL; + + return lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, code); + } return 1; } static int lcurl_multi_remove_handle(lua_State *L){ lcurl_multi_t *p = lcurl_getmulti(L); lcurl_easy_t *e = lcurl_geteasy_at(L, 2); - CURLMcode code = curl_multi_remove_handle(p->curl, e->curl); + CURLMcode code; + + if(e->multi != p){ + // cURL returns CURLM_OK for such call so we do the same. + // tested on 7.37.1 + lua_settop(L, 1); + return 1; + } + + lcurl__multi_assign_lua(L, p, 0); + code = curl_multi_remove_handle(p->curl, e->curl); + p->L = NULL; + if(code != CURLM_OK){ lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, code); } + + e->multi = NULL; lua_rawgeti(L, LCURL_LUA_REGISTRY, p->h_ref); lua_pushnil(L); lua_rawsetp(L, -2, e->curl); @@ -119,21 +178,10 @@ static int lcurl_multi_perform(lua_State *L){ int running_handles = 0; CURLMcode code; - lua_settop(L, 1); - lua_rawgeti(L, LCURL_LUA_REGISTRY, p->h_ref); - lua_pushnil(L); - while(lua_next(L, 2)){ - lcurl_easy_t *e = lcurl_geteasy_at(L, -1); - e->L = L; - if(e->post){ - e->post->L = L; - } - lua_pop(L, 1); - } - - lua_settop(L, 1); - + lcurl__multi_assign_lua(L, p, 1); while((code = curl_multi_perform(p->curl, &running_handles)) == CURLM_CALL_MULTI_PERFORM); + p->L = NULL; + if(code != CURLM_OK){ lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, code); } @@ -256,7 +304,11 @@ static int lcurl_multi_socket_action(lua_State *L){ CURLMcode code; int n, mask; if(s == CURL_SOCKET_TIMEOUT) mask = lutil_optint64(L, 3, 0); else mask = lutil_checkint64(L, 3); + + lcurl__multi_assign_lua(L, p, 0); code = curl_multi_socket_action(p->curl, s, mask, &n); + p->L = NULL; + if(code != CURLM_OK){ lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, code); } @@ -412,6 +464,7 @@ static int lcurl_multi_socket_callback(CURL *easy, curl_socket_t s, int what, vo lutil_pushint64(L, s); lua_pushinteger(L, what); + e->L = L; if(lua_pcall(L, n+2, 0, 0)){ assert(lua_gettop(L) >= top); lua_settop(L, top); diff --git a/test/test_multi_callback.lua b/test/test_multi_callback.lua new file mode 100644 index 0000000..3a81b00 --- /dev/null +++ b/test/test_multi_callback.lua @@ -0,0 +1,109 @@ +local curl = require "lcurl" + +local called, active_coroutine = 0 + +function on_timer() + called = called + 1 + -- use `os.exit` because now Lua-cURL did not propogate error from callback + if coroutine.running() ~= active_coroutine then os.exit(-1) end +end + +local function test_1() + io.write('Test #1 - ') + + called, active_coroutine = 0 + + local e = curl.easy() + local m = curl.multi{ timerfunction = on_timer } + + active_coroutine = coroutine.create(function() + m:add_handle(e) + end) + + coroutine.resume(active_coroutine) + assert(called == 1) + + active_coroutine = nil + m:remove_handle(e) + assert(called == 2) + + io.write('pass!\n') +end + +local function test_2() + io.write('Test #2 - ') + + called, active_coroutine = 0 + + local e = curl.easy() + local m = curl.multi{ timerfunction = on_timer } + + active_coroutine = coroutine.create(function() + m:add_handle(e) + end) + + coroutine.resume(active_coroutine) + assert(called == 1) + + active_coroutine = coroutine.create(function() + m:remove_handle(e) + end) + coroutine.resume(active_coroutine) + assert(called == 2) + + io.write('pass!\n') +end + +local function test_3() + io.write('Test #3 - ') + + called, active_coroutine = 0 + + local e = curl.easy() + local m = curl.multi{ timerfunction = on_timer } + + active_coroutine = coroutine.create(function() + m:add_handle(e) + end) + + coroutine.resume(active_coroutine) + assert(called == 1) + + active_coroutine = nil + e:close() + assert(called == 2) + + io.write('pass!\n') +end + +local function test_4() + io.write('Test #4 - ') + + called, active_coroutine = 0 + + local e = curl.easy() + local m = curl.multi{ timerfunction = on_timer } + + active_coroutine = coroutine.create(function() + m:add_handle(e) + end) + + coroutine.resume(active_coroutine) + assert(called == 1) + + active_coroutine = coroutine.create(function() + e:close() + end) + coroutine.resume(active_coroutine) + assert(called == 2) + + io.write('pass!\n') +end + +test_1() + +test_2() + +test_3() + +test_4()