diff --git a/.travis.yml b/.travis.yml index 99a414e..f08fa33 100644 --- a/.travis.yml +++ b/.travis.yml @@ -31,6 +31,7 @@ script: - lunit.sh test_safe.lua - lunit.sh test_form.lua - lunit.sh test_pause02.c.lua + - lunit.sh test_curl.lua after_success: - coveralls -b .. -r .. diff --git a/examples/cURLv3/multi3.lua b/examples/cURLv3/multi3.lua new file mode 100644 index 0000000..fff826c --- /dev/null +++ b/examples/cURLv3/multi3.lua @@ -0,0 +1,28 @@ +local cURL = require("cURL") + +local urls = { + "http://httpbin.org/get?key=1", + "http://httpbin.org/get?key=2", + "http://httpbin.org/get?key=3", + "http://httpbin.org/get?key=4", +} + +local function next_easy() + local url = table.remove(urls, 1) + if url then return cURL.easy{url = url} end +end + +m = cURL.multi():add_handle(next_easy()) +for data, type, easy in m:iperform() do + + if type == "done" or type == "error" then + print("Done", easy:getinfo_effective_url(), ":", data) + easy:close() + easy = next_easy() + if easy then m:add_handle(easy) end + end + + if type == "data" then print(data) end + +end + diff --git a/lakefile b/lakefile index 495ac05..1a2530b 100644 --- a/lakefile +++ b/lakefile @@ -29,6 +29,7 @@ target('test', install, function() run_test('test_safe.lua') run_test('test_form.lua') run_test('test_pause02.c.lua') + run_test('test_curl.lua') if not test_summary() then quit("test fail") diff --git a/src/lua/cURL/impl/cURL.lua b/src/lua/cURL/impl/cURL.lua index f7f7b90..c1bb1f8 100644 --- a/src/lua/cURL/impl/cURL.lua +++ b/src/lua/cURL/impl/cURL.lua @@ -36,10 +36,8 @@ local function wrap_setopt_flags(k, flags) end end -local function make_iterator(self, perform) - local curl = require "lcurl.safe" - - local buffers = {resp = {}, _ = {}} do +local function new_buffers() + local buffers = {resp = {}, _ = {}} function buffers:append(e, ...) local resp = assert(e:getinfo_response_code()) @@ -62,19 +60,35 @@ local function make_iterator(self, perform) end end + return buffers +end + +local function make_iterator(self, perform) + local curl = require "lcurl.safe" + + local buffers = new_buffers() + + -- reset callbacks to all easy handles + local function reset_easy(self) + if not self._easy_mark then -- that means we have some new easy handles + for h, e in pairs(self._easy) do if h ~= 'n' then + e:setopt_writefunction (function(str) buffers:append(e, "data", str) end) + e:setopt_headerfunction(function(str) buffers:append(e, "header", str) end) + end end + self._easy_mark = true + end + return self._easy.n end - local remain = self._easy.n - for h, e in pairs(self._easy) do - if h ~= 'n' then - e:setopt_writefunction (function(str) buffers:append(e, "data", str) end) - e:setopt_headerfunction(function(str) buffers:append(e, "header", str) end) - end - end + if 0 == reset_easy(self) then return end assert(perform(self)) return function() + -- we can add new handle during iteration + local remain = reset_easy(self) + + -- wait next event while true do local e, t = buffers:next() if t then return t[2], t[1], e end @@ -82,7 +96,7 @@ local function make_iterator(self, perform) self:wait() - local n, err = assert(perform(self)) + local n = assert(perform(self)) if n <= remain then while true do @@ -95,8 +109,9 @@ local function make_iterator(self, perform) else buffers:append(e, "error", err) end self:remove_handle(e) end - remain = n end + + remain = n end end end @@ -502,6 +517,8 @@ function Multi:add_handle(e) local ok, err = add_handle(self, h) if not ok then return nil, err end self._easy[h], self._easy.n = e, self._easy.n + 1 + self._easy_mark = nil + return self end diff --git a/test/test_curl.lua b/test/test_curl.lua new file mode 100644 index 0000000..2282147 --- /dev/null +++ b/test/test_curl.lua @@ -0,0 +1,89 @@ +local HAS_RUNNER = not not lunit +local lunit = require "lunit" +local TEST_CASE = assert(lunit.TEST_CASE) +local skip = lunit.skip or function() end + +local curl = require "cURL" +local scurl = require "cURL.safe" +local json = require "dkjson" +local fname = "./test.download" + +local ENABLE = true + +local _ENV = TEST_CASE'multi_iterator' if ENABLE then + +local url = "http://httpbin.org/get" + +local c, t, m + +local function json_data() + return json.decode(table.concat(t)) +end + +function setup() + t = {} + m = assert(scurl.multi()) +end + +function teardown() + if m then m:close() end + if c then c:close() end + m, c, t = nil +end + +function test_add_handle() + + local base_url = 'http://httpbin.org/get?key=' + local urls = { + base_url .. "1", + base_url .. "2", + "###" .. base_url .. "3", + base_url .. "4", + base_url .. "5", + } + + local i = 0 + local function next_easy() + i = i + 1 + local url = urls[i] + if url then + c = assert(scurl.easy{url = url}) + t = {} + return c + end + end + + m = assert_equal(m, m:add_handle(next_easy())) + + for data, type, easy in m:iperform() do + + if type == "done" or type == "error" then + assert_equal(urls[i], easy:getinfo_effective_url()) + assert_equal(easy, c) + easy:close() + c = nil + + if i == 3 then + assert_equal(curl.error(curl.ERROR_EASY, curl.E_UNSUPPORTED_PROTOCOL), data) + else + local data = json_data() + assert_table(data.args) + assert_equal(tostring(i), data.args.key) + end + + easy = next_easy() + if easy then m:add_handle(easy) end + end + + if type == "data" then table.insert(t, data) end + + end + + assert_equal(#urls + 1, i) + assert_nil(c) +end + +end + + +if not HAS_RUNNER then lunit.run() end