diff --git a/src/ltn12.lua b/src/ltn12.lua index ed3449b..dac932b 100644 --- a/src/ltn12.lua +++ b/src/ltn12.lua @@ -34,8 +34,23 @@ end local function chain2(f1, f2) if type(f1) ~= 'function' then error('invalid filter', 2) end if type(f2) ~= 'function' then error('invalid filter', 2) end + local co = coroutine.create(function(chunk) + while true do + local filtered1 = f1(chunk) + local filtered2 = f2(filtered1) + local done2 = filtered1 and "" + while true do + if filtered2 == "" or filtered2 == nil then break end + coroutine.yield(filtered2) + filtered2 = f2(done2) + end + if filtered1 == "" then chunk = coroutine.yield(filtered1) + elseif filtered1 == nil then return nil + else chunk = chunk and "" end + end + end) return function(chunk) - return f2(f1(chunk)) + return shift(coroutine.resume(co, chunk)) end end diff --git a/test/mimetest.lua b/test/mimetest.lua index 4a0a20a..3e57557 100644 --- a/test/mimetest.lua +++ b/test/mimetest.lua @@ -34,11 +34,52 @@ local mao = [[ local function random(handle, io_err) if handle then return function() - local chunk = handle:read(math.random(0, 1024)) + local len = math.random(0, 1024) + local chunk = handle:read(len) if not chunk then handle:close() end return chunk end - else source.empty(io_err or "unable to open file") end + else return ltn12.source.empty(io_err or "unable to open file") end +end + +local function format(chunk) + if chunk then + if chunk == "" then return "''" + else return string.len(chunk) end + else return "nil" end +end + +local function show(name, input, output) + local sin = format(input) + local sout = format(output) + io.write(name, ": ", sin, " -> ", sout, "\n") +end + +local function chunked(length) + local tmp + return function(chunk) + local ret + if chunk and chunk ~= "" then + tmp = chunk + end + ret = string.sub(tmp, 1, length) + tmp = string.sub(tmp, length+1) + if not chunk and ret == "" then ret = nil end + return ret + end +end + +--[[ +local function named(f, name) + return function(chunk) + local ret = f(chunk) + show(name, chunk, ret) + return ret + end +end +]] +local function named(f) + return f end local what = nil @@ -153,11 +194,11 @@ local function encode_b64test() end local function decode_b64test() - local d1 = mime.decode("base64") - local d2 = mime.decode("base64") - local d3 = mime.decode("base64") - local d4 = mime.decode("base64") - local chain = ltn12.filter.chain(d1, d2, d3, d4) + local d1 = named(mime.decode("base64"), "d1") + local d2 = named(mime.decode("base64"), "d2") + local d3 = named(mime.decode("base64"), "d3") + local d4 = named(mime.decode("base64"), "d4") + local chain = named(ltn12.filter.chain(d1, d2, d3, d4), "chain") transform(eb64test, db64test, chain) end diff --git a/test/testsupport.lua b/test/testsupport.lua index ca3cd95..acad8f5 100644 --- a/test/testsupport.lua +++ b/test/testsupport.lua @@ -1,5 +1,5 @@ function readfile(name) - local f = io.open(name, "r") + local f = io.open(name, "rb") if not f then return nil end local s = f:read("*a") f:close()