From a038558b44c252161dc0732f6cf759057d408c48 Mon Sep 17 00:00:00 2001 From: Alexey Melnichuk Date: Thu, 7 Apr 2016 16:59:50 +0300 Subject: [PATCH] Fix. call all callback from coroutine where perform was called. --- src/lceasy.c | 20 +++++++- src/lceasy.h | 8 +++ src/lchttppost.c | 3 +- src/lchttppost.h | 7 ++- src/lcmulti.c | 15 ++++++ test/test_curl.lua | 2 +- test/test_easy.lua | 121 ++++++++++++++++++++++++++++++++++++++++++++- 7 files changed, 169 insertions(+), 7 deletions(-) diff --git a/src/lceasy.c b/src/lceasy.c index 2819457..35a1593 100644 --- a/src/lceasy.c +++ b/src/lceasy.c @@ -41,7 +41,10 @@ int lcurl_easy_create(lua_State *L, int error_mode){ p->err_mode = error_mode; if(!p->curl) return lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_EASY, CURLE_FAILED_INIT); - p->L = L; + + p->magic = LCURL_EASY_MAGIC; + p->L = NULL; + p->post = NULL; p->storage = lcurl_storage_init(L); p->wr.cb_ref = p->wr.ud_ref = LUA_NOREF; p->rd.cb_ref = p->rd.ud_ref = LUA_NOREF; @@ -115,6 +118,12 @@ static int lcurl_easy_perform(lua_State *L){ assert(p->rbuffer.ref == LUA_NOREF); + // store reference to current coroutine to callbacks + p->L = L; + if(p->post){ + p->post->L = L; + } + code = curl_easy_perform(p->curl); if(p->rbuffer.ref != LUA_NOREF){ @@ -308,6 +317,8 @@ static int lcurl_easy_set_HTTPPOST(lua_State *L){ curl_easy_setopt(p->curl, CURLOPT_READFUNCTION, lcurl_hpost_read_callback); } + p->post = post; + lua_settop(L, 1); return 1; } @@ -420,6 +431,8 @@ static int lcurl_easy_unset_HTTPPOST(lua_State *L){ lcurl_storage_remove_i(L, p->storage, CURLOPT_HTTPPOST); } + p->post = NULL; + lua_settop(L, 1); return 1; } @@ -758,12 +771,15 @@ static size_t lcurl_read_callback(lua_State *L, static size_t lcurl_easy_read_callback(char *buffer, size_t size, size_t nitems, void *arg){ lcurl_easy_t *p = arg; + if(p->magic == LCURL_HPOST_STREAM_MAGIC){ + return lcurl_hpost_read_callback(buffer, size, nitems, arg); + } return lcurl_read_callback(p->L, &p->rd, &p->rbuffer, buffer, size, nitems); } static size_t lcurl_hpost_read_callback(char *buffer, size_t size, size_t nitems, void *arg){ lcurl_hpost_stream_t *p = arg; - return lcurl_read_callback(p->L, &p->rd, &p->rbuffer, buffer, size, nitems); + return lcurl_read_callback(*p->L, &p->rd, &p->rbuffer, buffer, size, nitems); } static int lcurl_easy_set_READFUNCTION(lua_State *L){ diff --git a/src/lceasy.h b/src/lceasy.h index ee28bb1..c16e0ee 100644 --- a/src/lceasy.h +++ b/src/lceasy.h @@ -32,11 +32,19 @@ enum { #undef LCURL_LNG_INDEX #undef OPT_ENTRY +#define LCURL_EASY_MAGIC 0xEA + +typedef struct lcurl_hpost_tag lcurl_hpost_t; + typedef struct lcurl_easy_tag{ + unsigned char magic; + lua_State *L; lcurl_callback_t rd; lcurl_read_buffer_t rbuffer; + lcurl_hpost_t *post; + CURL *curl; int storage; int lists[LCURL_LIST_COUNT]; diff --git a/src/lchttppost.c b/src/lchttppost.c index cebecbd..1ceff82 100644 --- a/src/lchttppost.c +++ b/src/lchttppost.c @@ -23,7 +23,8 @@ static lcurl_hpost_stream_t *lcurl_hpost_stream_add(lua_State *L, lcurl_hpost_t lcurl_hpost_stream_t *stream = malloc(sizeof(lcurl_hpost_stream_t)); if(!stream) return NULL; - stream->L = L; + stream->magic = LCURL_HPOST_STREAM_MAGIC; + stream->L = &p->L; stream->rbuffer.ref = LUA_NOREF; stream->rd.cb_ref = stream->rd.ud_ref = LUA_NOREF; stream->next = NULL; diff --git a/src/lchttppost.h b/src/lchttppost.h index cf618b8..02cace7 100644 --- a/src/lchttppost.h +++ b/src/lchttppost.h @@ -15,14 +15,19 @@ #include "lcutils.h" #include +#define LCURL_HPOST_STREAM_MAGIC 0xAA + typedef struct lcurl_hpost_stream_tag{ - lua_State *L; + unsigned char magic; + + lua_State **L; lcurl_callback_t rd; lcurl_read_buffer_t rbuffer; struct lcurl_hpost_stream_tag *next; }lcurl_hpost_stream_t; typedef struct lcurl_hpost_tag{ + lua_State *L; struct curl_httppost *post; struct curl_httppost *last; int storage; diff --git a/src/lcmulti.c b/src/lcmulti.c index 588507e..7aee895 100644 --- a/src/lcmulti.c +++ b/src/lcmulti.c @@ -118,6 +118,21 @@ static int lcurl_multi_perform(lua_State *L){ lcurl_multi_t *p = lcurl_getmulti(L); int running_handles = 0; CURLMcode code; + + lua_settop(L, 1); + lua_rawgeti(L, LCURL_LUA_REGISTRY, p->h_ref); + lua_pushnil(L); + while(lua_next(L, 2)){ + lcurl_easy_t *e = lcurl_geteasy_at(L, -1); + e->L = L; + if(e->post){ + e->post->L = L; + } + lua_pop(L, 1); + } + + lua_settop(L, 1); + while((code = curl_multi_perform(p->curl, &running_handles)) == CURLM_CALL_MULTI_PERFORM); if(code != CURLM_OK){ lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, code); diff --git a/test/test_curl.lua b/test/test_curl.lua index 39c6892..cc5d2c7 100644 --- a/test/test_curl.lua +++ b/test/test_curl.lua @@ -63,7 +63,7 @@ function test_add_handle() end end - m = assert_equal(m, m:add_handle(next_easy())) + assert_equal(m, m:add_handle(next_easy())) for data, type, easy in m:iperform() do diff --git a/test/test_easy.lua b/test/test_easy.lua index 6efcbba..d11b6e2 100644 --- a/test/test_easy.lua +++ b/test/test_easy.lua @@ -64,6 +64,29 @@ local function strem(ch, n, m) return n, get_bin_by( (ch):rep(n), m) end +local function Stream(ch, n, m) + local size, reader + + local _stream = {} + + function _stream:read(...) + _stream.called_ctx = self + _stream.called_co = coroutine.running() + return reader(...) + end + + function _stream:size() + return size + end + + function _stream:reset() + size, reader = strem(ch, n, m) + return self + end + + return _stream:reset() +end + local ENABLE = true local _ENV = TEST_CASE'write_callback' if ENABLE then @@ -165,6 +188,31 @@ function test_write_pass_03() assert_equal(c, c:perform()) end +function test_write_coro() + local co1, co2 + local called + + co1 = coroutine.create(function() + c = assert(curl.easy{ + url = url; + writefunction = function() + called = coroutine.running() + return true + end + }) + coroutine.yield() + end) + + co2 = coroutine.create(function() + assert_equal(c, c:perform()) + end) + + coroutine.resume(co1) + coroutine.resume(co2) + + assert_equal(co2, called) +end + end local _ENV = TEST_CASE'progress_callback' if ENABLE then @@ -380,7 +428,7 @@ local _ENV = TEST_CASE'read_stream_callback' if ENABLE and is_curl_ge(7,30,0) th local url = "http://httpbin.org/post" -local c, f, t +local m, c, f, t local function json_data() return json.decode(table.concat(t)) @@ -399,19 +447,88 @@ end function teardown() if f then f:free() end if c then c:close() end - t, f, c = nil + if m then m:close() end + t, f, c, m = nil end function test() assert_equal(f, f:add_stream('SSSSS', strem('X', 128, 13))) assert_equal(c, c:setopt_httppost(f)) + + -- should be called only stream callback + local read_called + assert_equal(c, c:setopt_readfunction(function() + read_called = true + end)) + assert_equal(c, c:perform()) + + assert_nil(read_called) + assert_equal(200, c:getinfo_response_code()) local data = assert_table(json_data()) assert_table(data.form) assert_equal(('X'):rep(128), data.form.SSSSS) end +function test_object() + local s = Stream('X', 128, 13) + + assert_equal(f, f:add_stream('SSSSS', s:size(), s)) + assert_equal(c, c:setopt_httppost(f)) + assert_equal(c, c:perform()) + + assert_equal(s, s.called_ctx) + + assert_equal(200, c:getinfo_response_code()) + local data = assert_table(json_data()) + assert_table(data.form) + assert_equal(('X'):rep(128), data.form.SSSSS) +end + +function test_co_multi() + local s = Stream('X', 128, 13) + assert_equal(f, f:add_stream('SSSSS', s:size(), s)) + assert_equal(c, c:setopt_httppost(f)) + + m = assert(scurl.multi()) + assert_equal(m, m:add_handle(c)) + + co = coroutine.create(function() + while 1== m:perform() do end + end) + + coroutine.resume(co) + + assert_equal(co, s.called_co) + + assert_equal(200, c:getinfo_response_code()) + local data = assert_table(json_data()) + assert_table(data.form) + assert_equal(('X'):rep(128), data.form.SSSSS) +end + +function test_co() + local s = Stream('X', 128, 13) + + assert_equal(f, f:add_stream('SSSSS', s:size(), s)) + assert_equal(c, c:setopt_httppost(f)) + + co = coroutine.create(function() + assert_equal(c, c:perform()) + end) + + coroutine.resume(co) + + assert_equal(co, s.called_co) + + assert_equal(200, c:getinfo_response_code()) + local data = assert_table(json_data()) + assert_table(data.form) + assert_equal(('X'):rep(128), data.form.SSSSS) + +end + function test_abort_01() assert_equal(f, f:add_stream('SSSSS', 128 * 1024, function() end)) assert_equal(c, c:setopt_timeout(5))