Fix. call all callback from coroutine where perform was called.

This commit is contained in:
Alexey Melnichuk 2016-04-07 16:59:50 +03:00
parent f42a0e48b1
commit a038558b44
7 changed files with 169 additions and 7 deletions

View File

@ -41,7 +41,10 @@ int lcurl_easy_create(lua_State *L, int error_mode){
p->err_mode = error_mode;
if(!p->curl) return lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_EASY, CURLE_FAILED_INIT);
p->L = L;
p->magic = LCURL_EASY_MAGIC;
p->L = NULL;
p->post = NULL;
p->storage = lcurl_storage_init(L);
p->wr.cb_ref = p->wr.ud_ref = LUA_NOREF;
p->rd.cb_ref = p->rd.ud_ref = LUA_NOREF;
@ -115,6 +118,12 @@ static int lcurl_easy_perform(lua_State *L){
assert(p->rbuffer.ref == LUA_NOREF);
// store reference to current coroutine to callbacks
p->L = L;
if(p->post){
p->post->L = L;
}
code = curl_easy_perform(p->curl);
if(p->rbuffer.ref != LUA_NOREF){
@ -308,6 +317,8 @@ static int lcurl_easy_set_HTTPPOST(lua_State *L){
curl_easy_setopt(p->curl, CURLOPT_READFUNCTION, lcurl_hpost_read_callback);
}
p->post = post;
lua_settop(L, 1);
return 1;
}
@ -420,6 +431,8 @@ static int lcurl_easy_unset_HTTPPOST(lua_State *L){
lcurl_storage_remove_i(L, p->storage, CURLOPT_HTTPPOST);
}
p->post = NULL;
lua_settop(L, 1);
return 1;
}
@ -758,12 +771,15 @@ static size_t lcurl_read_callback(lua_State *L,
static size_t lcurl_easy_read_callback(char *buffer, size_t size, size_t nitems, void *arg){
lcurl_easy_t *p = arg;
if(p->magic == LCURL_HPOST_STREAM_MAGIC){
return lcurl_hpost_read_callback(buffer, size, nitems, arg);
}
return lcurl_read_callback(p->L, &p->rd, &p->rbuffer, buffer, size, nitems);
}
static size_t lcurl_hpost_read_callback(char *buffer, size_t size, size_t nitems, void *arg){
lcurl_hpost_stream_t *p = arg;
return lcurl_read_callback(p->L, &p->rd, &p->rbuffer, buffer, size, nitems);
return lcurl_read_callback(*p->L, &p->rd, &p->rbuffer, buffer, size, nitems);
}
static int lcurl_easy_set_READFUNCTION(lua_State *L){

View File

@ -32,11 +32,19 @@ enum {
#undef LCURL_LNG_INDEX
#undef OPT_ENTRY
#define LCURL_EASY_MAGIC 0xEA
typedef struct lcurl_hpost_tag lcurl_hpost_t;
typedef struct lcurl_easy_tag{
unsigned char magic;
lua_State *L;
lcurl_callback_t rd;
lcurl_read_buffer_t rbuffer;
lcurl_hpost_t *post;
CURL *curl;
int storage;
int lists[LCURL_LIST_COUNT];

View File

@ -23,7 +23,8 @@ static lcurl_hpost_stream_t *lcurl_hpost_stream_add(lua_State *L, lcurl_hpost_t
lcurl_hpost_stream_t *stream = malloc(sizeof(lcurl_hpost_stream_t));
if(!stream) return NULL;
stream->L = L;
stream->magic = LCURL_HPOST_STREAM_MAGIC;
stream->L = &p->L;
stream->rbuffer.ref = LUA_NOREF;
stream->rd.cb_ref = stream->rd.ud_ref = LUA_NOREF;
stream->next = NULL;

View File

@ -15,14 +15,19 @@
#include "lcutils.h"
#include <stdlib.h>
#define LCURL_HPOST_STREAM_MAGIC 0xAA
typedef struct lcurl_hpost_stream_tag{
lua_State *L;
unsigned char magic;
lua_State **L;
lcurl_callback_t rd;
lcurl_read_buffer_t rbuffer;
struct lcurl_hpost_stream_tag *next;
}lcurl_hpost_stream_t;
typedef struct lcurl_hpost_tag{
lua_State *L;
struct curl_httppost *post;
struct curl_httppost *last;
int storage;

View File

@ -118,6 +118,21 @@ static int lcurl_multi_perform(lua_State *L){
lcurl_multi_t *p = lcurl_getmulti(L);
int running_handles = 0;
CURLMcode code;
lua_settop(L, 1);
lua_rawgeti(L, LCURL_LUA_REGISTRY, p->h_ref);
lua_pushnil(L);
while(lua_next(L, 2)){
lcurl_easy_t *e = lcurl_geteasy_at(L, -1);
e->L = L;
if(e->post){
e->post->L = L;
}
lua_pop(L, 1);
}
lua_settop(L, 1);
while((code = curl_multi_perform(p->curl, &running_handles)) == CURLM_CALL_MULTI_PERFORM);
if(code != CURLM_OK){
lcurl_fail_ex(L, p->err_mode, LCURL_ERROR_MULTI, code);

View File

@ -63,7 +63,7 @@ function test_add_handle()
end
end
m = assert_equal(m, m:add_handle(next_easy()))
assert_equal(m, m:add_handle(next_easy()))
for data, type, easy in m:iperform() do

View File

@ -64,6 +64,29 @@ local function strem(ch, n, m)
return n, get_bin_by( (ch):rep(n), m)
end
local function Stream(ch, n, m)
local size, reader
local _stream = {}
function _stream:read(...)
_stream.called_ctx = self
_stream.called_co = coroutine.running()
return reader(...)
end
function _stream:size()
return size
end
function _stream:reset()
size, reader = strem(ch, n, m)
return self
end
return _stream:reset()
end
local ENABLE = true
local _ENV = TEST_CASE'write_callback' if ENABLE then
@ -165,6 +188,31 @@ function test_write_pass_03()
assert_equal(c, c:perform())
end
function test_write_coro()
local co1, co2
local called
co1 = coroutine.create(function()
c = assert(curl.easy{
url = url;
writefunction = function()
called = coroutine.running()
return true
end
})
coroutine.yield()
end)
co2 = coroutine.create(function()
assert_equal(c, c:perform())
end)
coroutine.resume(co1)
coroutine.resume(co2)
assert_equal(co2, called)
end
end
local _ENV = TEST_CASE'progress_callback' if ENABLE then
@ -380,7 +428,7 @@ local _ENV = TEST_CASE'read_stream_callback' if ENABLE and is_curl_ge(7,30,0) th
local url = "http://httpbin.org/post"
local c, f, t
local m, c, f, t
local function json_data()
return json.decode(table.concat(t))
@ -399,19 +447,88 @@ end
function teardown()
if f then f:free() end
if c then c:close() end
t, f, c = nil
if m then m:close() end
t, f, c, m = nil
end
function test()
assert_equal(f, f:add_stream('SSSSS', strem('X', 128, 13)))
assert_equal(c, c:setopt_httppost(f))
-- should be called only stream callback
local read_called
assert_equal(c, c:setopt_readfunction(function()
read_called = true
end))
assert_equal(c, c:perform())
assert_nil(read_called)
assert_equal(200, c:getinfo_response_code())
local data = assert_table(json_data())
assert_table(data.form)
assert_equal(('X'):rep(128), data.form.SSSSS)
end
function test_object()
local s = Stream('X', 128, 13)
assert_equal(f, f:add_stream('SSSSS', s:size(), s))
assert_equal(c, c:setopt_httppost(f))
assert_equal(c, c:perform())
assert_equal(s, s.called_ctx)
assert_equal(200, c:getinfo_response_code())
local data = assert_table(json_data())
assert_table(data.form)
assert_equal(('X'):rep(128), data.form.SSSSS)
end
function test_co_multi()
local s = Stream('X', 128, 13)
assert_equal(f, f:add_stream('SSSSS', s:size(), s))
assert_equal(c, c:setopt_httppost(f))
m = assert(scurl.multi())
assert_equal(m, m:add_handle(c))
co = coroutine.create(function()
while 1== m:perform() do end
end)
coroutine.resume(co)
assert_equal(co, s.called_co)
assert_equal(200, c:getinfo_response_code())
local data = assert_table(json_data())
assert_table(data.form)
assert_equal(('X'):rep(128), data.form.SSSSS)
end
function test_co()
local s = Stream('X', 128, 13)
assert_equal(f, f:add_stream('SSSSS', s:size(), s))
assert_equal(c, c:setopt_httppost(f))
co = coroutine.create(function()
assert_equal(c, c:perform())
end)
coroutine.resume(co)
assert_equal(co, s.called_co)
assert_equal(200, c:getinfo_response_code())
local data = assert_table(json_data())
assert_table(data.form)
assert_equal(('X'):rep(128), data.form.SSSSS)
end
function test_abort_01()
assert_equal(f, f:add_stream('SSSSS', 128 * 1024, function() end))
assert_equal(c, c:setopt_timeout(5))