2024-02-21 14:36:19 +01:00

627 lines
15 KiB
Lua

--- Bytecode parsing.
-- Please note that this module is experimental and subject to change.
-- @module advtrains_doc_integration.bc
-- Module table: the public functions are attached to it and it is returned
-- at the end of the file.
local bc = {}
-- Local shorthands for the bitwise AND / logical right shift functions of
-- the global `bit` library.
local band = bit.band
local brshift = bit.rshift
-- Read a single byte.
-- @tparam string str The string to read from.
-- @tparam number pos The 1-based position of the byte.
-- @treturn number The numeric value of the byte at `pos` (nothing if `pos` is out of range).
local function read_u8(str, pos)
	-- string.byte defaults its end index to the start index.
	return string.byte(str, pos)
end
-- Read an unsigned 16-bit little-endian integer.
-- @tparam string str The string to read from.
-- @tparam number pos The 1-based position of the low byte.
-- @treturn number The decoded value in the range 0..65535.
local function read_u16le(str, pos)
	local lo, hi = string.byte(str, pos, pos+1)
	-- The second byte carries bits 8..15, so its weight is 256.
	return hi*256+lo
end
-- Read an unsigned 32-bit little-endian integer.
-- @tparam string str The string to read from.
-- @tparam number pos The 1-based position of the least significant byte.
-- @treturn number The decoded value in the range 0..2^32-1.
local function read_u32le(str, pos)
	local byte0, byte1, byte2, byte3 = string.byte(str, pos, pos+3)
	-- Sum the bytes from least to most significant weight.
	return byte0 + byte1*256 + byte2*256^2 + byte3*256^3
end
-- Build a number from the two 32-bit halves of a binary64 bit pattern.
-- @tparam number lo The low 32 bits.
-- @tparam number hi The high 32 bits (sign bit, 11 exponent bits, top 20 mantissa bits).
-- @treturn number|nil The decoded number, or nil if the exponent field is
-- all ones and the mantissa is non-zero.
local function construct_double(lo, hi)
	-- 1 if the top bit of `hi` is clear, -1 if it is set.
	local sign = (-1)^brshift(hi, 31)
	local biased_exp = band(brshift(hi, 20), 0x7ff)
	-- Mantissa scaled to a fraction in [0, 1).
	local fraction = (band(hi, 0xfffff)+lo/2^32)/0x100000
	if biased_exp == 0x7ff then
		if fraction == 0 then
			return sign*math.huge
		end
		return nil
	end
	if biased_exp == 0 then
		-- No implicit leading one in this case.
		return sign*math.ldexp(fraction, -1022)
	end
	return sign*math.ldexp(1+fraction, biased_exp-1023)
end
-- Split a bit field into named flags.
-- @tparam number val The bit field.
-- @tparam table spec A map from flag name to bit mask.
-- @treturn table A map holding, for every flag whose mask overlaps `val`,
-- the flag name mapped to the (non-zero) masked value.
local function readflags(val, spec)
	local present = {}
	for name, mask in pairs(spec) do
		local masked = band(val, mask)
		-- Assigning nil leaves the key absent, so only set flags are recorded.
		present[name] = (masked ~= 0) and masked or nil
	end
	return present
end
-- Read a variable-length unsigned integer: each byte contributes its low
-- seven bits (least significant group first), and a byte below 128 ends
-- the number.
-- @tparam string str The string to read from.
-- @tparam number pos The 1-based position of the first byte.
-- @treturn number The decoded value.
-- @treturn number The position right after the last byte consumed.
local function read_lj_uleb128(str, pos)
	local total = 0
	local idx = 0
	while true do
		local byte = read_u8(str, pos+idx)
		if byte >= 128 then
			-- Continuation byte: accumulate its payload and move on.
			total = total + (byte%128)*128^idx
			idx = idx + 1
		else
			return total+byte*128^idx, pos+idx+1
		end
	end
end
-- Read a variable-length unsigned integer whose first byte has its lowest
-- bit dropped, leaving six payload bits and a continuation flag in that byte.
-- @tparam string str The string to read from.
-- @tparam number pos The 1-based position of the first byte.
-- @treturn number The decoded value.
-- @treturn number The position right after the last byte consumed.
local function read_lj_uleb128_33(str, pos)
	local head = brshift(read_u8(str, pos), 1)
	if head < 64 then
		-- No continuation flag: the value fits in the first byte.
		return head, pos+1
	end
	local rest, nextpos = read_lj_uleb128(str, pos+1)
	return head%64+64*rest, nextpos
end
-- Read a number stored as two consecutive ULEB128 values (low 32 bits
-- first, then high 32 bits).
-- @tparam string str The string to read from.
-- @tparam number pos The 1-based position of the first byte.
-- @treturn[1] number The decoded number.
-- @treturn[1] number The position right after the number.
-- @treturn[2] nil If `construct_double` rejects the bit pattern.
-- @treturn[2] string A message describing the rejected bit pattern.
local function read_lj_double(str, pos)
	local lo, after_lo = read_lj_uleb128(str, pos)
	local hi, after_hi = read_lj_uleb128(str, after_lo)
	local number = construct_double(lo, hi)
	if number then
		return number, after_hi
	end
	return nil, ("Bad double: 0x%08x%08x"):format(hi, lo)
end
-- Read a numeric constant that is either an integer or a double.
-- The lowest bit of the first byte selects the representation: 0 for an
-- integer, 1 for a double whose low word is followed by a ULEB128 high word.
-- @tparam string str The string to read from.
-- @tparam number pos The 1-based position of the first byte.
-- @treturn[1] number The decoded number.
-- @treturn[1] number The position right after the number.
-- @treturn[2] nil If the double has a bit pattern rejected by `construct_double`.
-- @treturn[2] string A message describing the rejected bit pattern.
local function read_lj_int_or_double(str, pos)
	local is_double = read_u8(str, pos) % 2 ~= 0
	local lo, p2 = read_lj_uleb128_33(str, pos)
	if not is_double then
		-- Integer constants are written as the 32-bit two's complement
		-- pattern, so values with bit 31 set are negative numbers.
		if lo >= 2^31 then
			lo = lo - 2^32
		end
		return lo, p2
	end
	local hi, p3 = read_lj_uleb128(str, p2)
	local d = construct_double(lo, hi)
	if not d then
		return nil, ("Bad double: 0x%08x%08x"):format(hi, lo)
	end
	return d, p3
end
-- Instruction table indexed by opcode number (starting at 0).
-- Every entry is {name, type of operand A, type of operand B, type of
-- operand C or D}; a nil type means that the operand is not used.
-- Instructions without a B operand use the combined 16-bit D operand.
local lj_bcdef = {
	-- Comparison ops
	[0] = {"ISLT", "var", nil, "var"},
	{"ISGE", "var", nil, "var"},
	{"ISLE", "var", nil, "var"},
	{"ISGT", "var", nil, "var"},
	--
	{"ISEQV", "var", nil, "var"},
	{"ISNEV", "var", nil, "var"},
	{"ISEQS", "var", nil, "str"},
	{"ISNES", "var", nil, "str"},
	{"ISEQN", "var", nil, "num"},
	{"ISNEN", "var", nil, "num"},
	{"ISEQP", "var", nil, "pri"},
	{"ISNEP", "var", nil, "pri"},
	-- Unary test and copy ops
	{"ISTC", "dst", nil, "var"},
	{"ISFC", "dst", nil, "var"},
	{"IST", nil, nil, "var"},
	{"ISF", nil, nil, "var"},
	{"ISTYPE", "var", nil, "lit"},
	{"ISNUM", "var", nil, "lit"},
	-- Unary ops
	{"MOV", "dst", nil, "var"},
	{"NOT", "dst", nil, "var"},
	{"UNM", "dst", nil, "var"},
	{"LEN", "dst", nil, "var"},
}
-- Binary arithmetic ops: all VN variants first, then NV, then VV.
for _, var in ipairs {
	{"VN", "dst", "var", "num"},
	{"NV", "dst", "var", "num"},
	-- Both operands of the VV variants are variable slots, not constants.
	{"VV", "dst", "var", "var"},
} do
	for _, ins in ipairs {"ADD", "SUB", "MUL", "DIV", "MOD"} do
		table.insert(lj_bcdef, {ins..var[1], var[2], var[3], var[4]})
	end
end
for _, ent in ipairs {
	{"POW", "dst", "var", "var"},
	{"CAT", "dst", "rbase", "rbase"},
	-- Constant ops
	{"KSTR", "dst", nil, "str"},
	{"KCDATA", "dst", nil, "cdata"},
	{"KSHORT", "dst", nil, "lits"},
	{"KNUM", "dst", nil, "num"},
	{"KPRI", "dst", nil, "pri"},
	{"KNIL", "base", nil, "base"},
	-- Upvalue and function ops
	{"UGET", "dst", nil, "uv"},
	{"USETV", "uv", nil, "var"},
	{"USETS", "uv", nil, "str"},
	{"USETN", "uv", nil, "num"},
	{"USETP", "uv", nil, "pri"},
	{"UCLO", "rbase", nil, "jump"},
	{"FNEW", "dst", nil, "func"},
	-- Table ops
	{"TNEW", "dst", nil, "lit"},
	{"TDUP", "dst", nil, "tab"},
	{"GGET", "dst", nil, "str"},
	{"GSET", "var", nil, "str"},
	{"TGETV", "dst", "var", "var"},
	{"TGETS", "dst", "var", "str"},
	{"TGETB", "dst", "var", "lit"},
	{"TGETR", "dst", "var", "var"},
	{"TSETV", "var", "var", "var"},
	{"TSETS", "var", "var", "str"},
	{"TSETB", "var", "var", "lit"},
	{"TSETM", "base", nil, "num"},
	{"TSETR", "var", "var", "var"},
	-- Calls and vararg handling
	{"CALLM", "base", "lit", "lit"},
	{"CALL", "base", "lit", "lit"},
	{"CALLMT", "base", nil, "lit"},
	{"CALLT", "base", nil, "lit"},
	{"ITERC", "base", "lit", "lit"},
	{"ITERN", "base", "lit", "lit"},
	{"VARG", "base", "lit", "lit"},
	{"ISNEXT", "base", nil, "jump"},
	-- Returns
	{"RETM", "base", nil, "lit"},
	{"RET", "rbase", nil, "lit"},
	{"RET0", "rbase", nil, "lit"},
	{"RET1", "rbase", nil, "lit"},
	-- Loops and branches
	{"FORI", "base", nil, "jump"},
	{"JFORI", "base", nil, "jump"},
	--
	{"FORL", "base", nil, "jump"},
	{"IFORL", "base", nil, "jump"},
	{"JFORL", "base", nil, "lit"},
	--
	{"ITERL", "base", nil, "jump"},
	{"IITERL", "base", nil, "jump"},
	{"JITERL", "base", nil, "lit"},
	--
	{"LOOP", "rbase", nil, "jump"},
	{"ILOOP", "rbase", nil, "jump"},
	{"JLOOP", "rbase", nil, "lit"},
	--
	{"JMP", "rbase", nil, "jump"},
	-- Function headers
	{"FUNCF", "rbase", nil, nil},
	{"IFUNCF", "rbase", nil, nil},
	{"JFUNCF", "rbase", nil, "lit"},
	{"FUNCV", "rbase", nil, nil},
	{"IFUNCV", "rbase", nil, nil},
	{"JFUNCV", "rbase", nil, "lit"},
	{"FUNCC", "rbase", nil, nil},
	{"FUNCCW", "rbase", nil, nil},
} do
	table.insert(lj_bcdef, ent)
end
-- Read a length-prefixed byte string: a ULEB128 length followed by that
-- many bytes.
-- @tparam string dump The string to read from.
-- @tparam number pos The 1-based position of the length prefix.
-- @treturn string The payload (the empty string if the length is zero).
-- @treturn number The position right after the payload.
local function lj_read_nbytes(dump, pos)
	local len, start = read_lj_uleb128(dump, pos)
	if len ~= 0 then
		return string.sub(dump, start, start+len-1), start+len
	end
	return "", start
end
-- Decode the instruction list of a prototype.
-- Each instruction is a 32-bit little-endian word: the lowest byte is the
-- opcode, the next byte is operand A, and the upper 16 bits hold either
-- operand D or the pair C (lower byte) and B (upper byte).
-- @tparam table phead The prototype header; `numbc` instructions are read.
-- @tparam number pos The 1-based position of the first instruction in `pstr`.
-- @tparam string pstr The prototype data.
-- @return[1] A list of instructions, each of the form
-- `{name, A, B, C}` or `{name, A, D}` where operands are `{type=, value=}` tables.
-- @treturn[1] number The position right after the instructions.
-- @treturn[2] nil If an unknown opcode is encountered.
-- @treturn[2] string A message naming the unknown opcode.
local function lj_parse_bytecode(phead, pos, pstr, _)
	local inslist = {}
	for k = 1, phead.numbc do
		local w = read_u32le(pstr, pos+4*k-4)
		local opcode = w%256
		local op = lj_bcdef[opcode]
		if not op then
			-- Report the opcode number; `op` itself is nil on this path.
			return nil, ("Invalid opcode: %02X"):format(opcode)
		end
		local ins = {op[1], {type = op[2], value = math.floor(w/256)%256}}
		if op[3] then
			ins[3] = {type = op[3], value = math.floor(w/256^3)%256}
			ins[4] = {type = op[4], value = math.floor(w/256^2)%256}
		else
			ins[3] = {type = op[4], value = math.floor(w/256^2)}
		end
		inslist[k] = ins
	end
	return inslist, pos+4*phead.numbc
end
-- Read the upvalue list of a prototype: `numuv` consecutive 16-bit
-- little-endian values.
-- @tparam table phead The prototype header.
-- @tparam number pos The 1-based position of the first entry in `pstr`.
-- @tparam string pstr The prototype data.
-- @treturn table The list of raw 16-bit upvalue entries.
-- @treturn number The position right after the list.
local function lj_parse_uv(phead, pos, pstr, _)
	local count = phead.numuv
	local entries = {}
	for i = 1, count do
		entries[i] = read_u16le(pstr, pos+2*(i-1))
	end
	return entries, pos+2*count
end
-- Type tags of constants stored inside template tables.
-- Tags greater than or equal to `str` denote a string whose length is the
-- tag minus `str` (see lj_parse_ktabk).
local lj_ktab_type = {
	["nil"] = 0,
	["false"] = 1,
	["true"] = 2,
	["int"] = 3,
	["num"] = 4,
	["str"] = 5,
}
-- Read a single constant (key or value) of a template table.
-- @tparam number pos The 1-based position of the type tag in `pstr`.
-- @tparam string pstr The prototype data.
-- @return[1] The constant: a string, a number, a boolean or nil.
-- @treturn[1] number The position right after the constant.
-- @treturn[2] nil If the constant cannot be decoded.
-- @treturn[2] string A message indicating the error.
-- Note that a nil constant and an error both yield nil as the first value;
-- callers tell them apart by the type of the second value.
local function lj_parse_ktabk(_, pos, pstr, _)
	local tp, p2 = read_lj_uleb128(pstr, pos)
	if tp >= lj_ktab_type.str then
		local len = tp - lj_ktab_type.str
		return string.sub(pstr, p2, p2+len-1), p2+len
	elseif tp == lj_ktab_type.int then
		local v, p3 = read_lj_uleb128(pstr, p2)
		-- Integers are stored as the 32-bit two's complement pattern, so
		-- values with bit 31 set are negative numbers.
		if v >= 2^31 then
			v = v - 2^32
		end
		return v, p3
	elseif tp == lj_ktab_type.num then
		return read_lj_double(pstr, p2)
	elseif tp == lj_ktab_type["nil"] then
		return nil, p2
	elseif tp == lj_ktab_type["true"] then
		return true, p2
	elseif tp == lj_ktab_type["false"] then
		return false, p2
	end
	return nil, ("Bad KTABK constant type %d"):format(tp)
end
-- Read a template table: two ULEB128 counts (array items, hash pairs),
-- then the array items (stored under keys 0..narr-1), then the key/value
-- pairs of the hash part.
-- @tparam table phead The prototype header (passed through to lj_parse_ktabk).
-- @tparam number pos The 1-based position of the table in `pstr`.
-- @tparam string pstr The prototype data.
-- @tparam table bcflags The dump flags (passed through to lj_parse_ktabk).
-- @treturn[1] table The decoded table.
-- @treturn[1] number The position right after the table.
-- @treturn[2] nil If a constant cannot be decoded or a hash key is nil.
-- @treturn[2] string A message indicating the error.
local function lj_parse_ktab(phead, pos, pstr, bcflags)
	local result = {}
	local narr, after_narr = read_lj_uleb128(pstr, pos)
	local nhash, cursor = read_lj_uleb128(pstr, after_narr)
	for idx = 0, narr-1 do
		local item, nxt = lj_parse_ktabk(phead, cursor, pstr, bcflags)
		-- A nil item is an error only if no position came back with it.
		if item == nil and type(nxt) ~= "number" then
			return nil, nxt
		end
		result[idx] = item
		cursor = nxt
	end
	for _ = 1, nhash do
		local key, after_key = lj_parse_ktabk(phead, cursor, pstr, bcflags)
		if key == nil then
			if type(after_key) == "number" then
				return nil, "Table index is nil"
			end
			return nil, after_key
		end
		local val, after_val = lj_parse_ktabk(phead, after_key, pstr, bcflags)
		if val == nil and type(after_val) ~= "number" then
			return nil, after_val
		end
		result[key] = val
		cursor = after_val
	end
	return result, cursor
end
-- Type tags of the collectable constants handled by lj_parse_kgc.
-- Tags greater than or equal to `str` denote a string whose length is the
-- tag minus `str`.
local lj_kgc_type = {
	["child"] = 0,
	["tab"] = 1,
	["str"] = 5,
}
-- Read the collectable constants of a prototype.
-- The constants are stored under the keys numkgc-1 down to 0, in the order
-- in which they appear in the dump. Strings and tables are stored as such;
-- a child prototype is stored as a number taken from the top of the
-- prototype stack tracked in `bcflags.top`.
-- @tparam table phead The prototype header.
-- @tparam number pos The 1-based position of the first constant in `pstr`.
-- @tparam string pstr The prototype data.
-- @tparam table bcflags The dump flags; `top` is decremented for every child.
-- @treturn[1] table The constants.
-- @treturn[1] number The position right after the constants.
-- @treturn[2] nil If a constant cannot be decoded.
-- @treturn[2] string A message indicating the error.
local function lj_parse_kgc(phead, pos, pstr, bcflags)
	local constants = {}
	for slot = phead.numkgc-1, 0, -1 do
		local tag, body = read_lj_uleb128(pstr, pos)
		if tag == lj_kgc_type.child then
			-- Pop the most recently parsed prototype that is still unclaimed.
			local child = bcflags.top - 1
			if child < 0 then
				return nil, "Child stack underflow"
			end
			bcflags.top = child
			constants[slot] = child
			pos = body
		elseif tag == lj_kgc_type.tab then
			local tbl, after = lj_parse_ktab(phead, body, pstr, bcflags)
			if not tbl then
				return nil, after
			end
			constants[slot] = tbl
			pos = after
		elseif tag >= lj_kgc_type.str then
			local len = tag-lj_kgc_type.str
			constants[slot] = string.sub(pstr, body, body+len-1)
			pos = body+len
		else
			return nil, ("Bad constant type %d"):format(tag)
		end
	end
	return constants, pos
end
-- Read the numeric constants of a prototype.
-- @tparam table phead The prototype header; `numkn` constants are read.
-- @tparam number pos The 1-based position of the first constant in `pstr`.
-- @tparam string pstr The prototype data.
-- @tparam table bcflags The dump flags (unused).
-- @treturn[1] table The constants, stored under the keys 0..numkn-1.
-- @treturn[1] number The position right after the constants.
-- @treturn[2] nil If a constant cannot be decoded.
-- @treturn[2] string A message indicating the error.
local function lj_parse_kn(phead, pos, pstr, bcflags)
	local constants = {}
	local cursor = pos
	for slot = 0, phead.numkn-1 do
		local value, nxt = read_lj_int_or_double(pstr, cursor)
		if not value then
			return nil, nxt
		end
		constants[slot] = value
		cursor = nxt
	end
	return constants, cursor
end
-- Parse the part of a prototype that follows its header: instructions,
-- upvalues, collectable constants and numeric constants, in that order.
-- @tparam table phead The already parsed prototype header.
-- @tparam string pstr The prototype data starting right after the header.
-- @tparam table bcflags The dump flags.
-- @treturn[1] table A table with the fields `header`, `bytecode`, `uv`,
-- `kgc` and `kn`.
-- @treturn[2] nil If any section cannot be parsed.
-- @treturn[2] string A message indicating the error.
local function lj_parse_pdata_body(phead, pstr, bcflags)
	local bytecode, after_bc = lj_parse_bytecode(phead, 1, pstr, bcflags)
	if not bytecode then
		return nil, after_bc
	end
	local uv, after_uv = lj_parse_uv(phead, after_bc, pstr, bcflags)
	if not uv then
		return nil, after_uv
	end
	local kgc, after_kgc = lj_parse_kgc(phead, after_uv, pstr, bcflags)
	if not kgc then
		return nil, after_kgc
	end
	local kn, after_kn = lj_parse_kn(phead, after_kgc, pstr, bcflags)
	if not kn then
		return nil, after_kn
	end
	return {
		header = phead,
		bytecode = bytecode,
		uv = uv,
		kgc = kgc,
		kn = kn,
	}
end
-- Flag specification (name -> bit mask) for the flags byte of a prototype.
-- No flags are defined here, so the `flags` field of a parsed prototype
-- header is always an empty table.
local lj_proto_flags = {}
-- Parse a single prototype.
-- The header consists of four single bytes (flags, parameter count, frame
-- size, upvalue count) followed by ULEB128 counts of collectable
-- constants, numeric constants and instructions. Unless the dump is
-- stripped, the length of the debug data follows and, if that is non-zero,
-- the first line and the number of lines.
-- @tparam string pstr The prototype data (without its length prefix).
-- @tparam table bcflags The dump flags.
-- @treturn[1] table The parsed prototype (see lj_parse_pdata_body).
-- @treturn[2] nil If the prototype cannot be parsed.
-- @treturn[2] string A message indicating the error.
local function lj_parse_proto(pstr, bcflags)
	local phead = {
		flags = readflags(read_u8(pstr, 1), lj_proto_flags),
		numparams = read_u8(pstr, 2),
		framesize = read_u8(pstr, 3),
		numuv = read_u8(pstr, 4),
	}
	local pos
	phead.numkgc, pos = read_lj_uleb128(pstr, 5)
	phead.numkn, pos = read_lj_uleb128(pstr, pos)
	phead.numbc, pos = read_lj_uleb128(pstr, pos)
	if not bcflags.strip then
		phead.debuglen, pos = read_lj_uleb128(pstr, pos)
		if phead.debuglen > 0 then
			phead.firstline, pos = read_lj_uleb128(pstr, pos)
			phead.numline, pos = read_lj_uleb128(pstr, pos)
		end
	end
	-- The body parser works with positions relative to the end of the header.
	local body = string.sub(pstr, pos)
	return lj_parse_pdata_body(phead, body, bcflags)
end
-- Flag specification (name -> bit mask) for the flags field of a dump header.
local lj_bcdump_flags = {
	be = 0x01,
	strip = 0x02,
	ffi = 0x04,
	fr2 = 0x08,
}
-- Parse a version 2 LuaJIT bytecode dump.
-- @tparam string dump The dump without the signature and the version byte.
-- @treturn[1] table A table with the fields `chunkname` (nil if the dump
-- is stripped or the name is empty) and `prototypes` (a list of parsed
-- prototypes in the order in which they appear in the dump).
-- @treturn[2] nil If the dump is not supported or cannot be parsed.
-- @treturn[2] string A message indicating the error.
local function parse_lj2(dump)
	local rawflags, pos = read_lj_uleb128(dump, 1)
	local flags = readflags(rawflags, lj_bcdump_flags)
	if flags.ffi then
		return nil, "LuaJIT bytecode dump with FFI is not supported"
	end
	if flags.be then
		return nil, "Big-endian LuaJIT bytecode is not supported"
	end
	local chunkname
	if not flags.strip then
		local cname
		cname, pos = lj_read_nbytes(dump, pos)
		if cname ~= "" then
			chunkname = cname
		end
	end
	local prototypes = {}
	-- Number of prototypes parsed so far that no parent has claimed as a child.
	flags.top = 0
	local proto
	proto, pos = lj_read_nbytes(dump, pos)
	-- A zero-length prototype terminates the list.
	while proto ~= "" do
		local pdata, err = lj_parse_proto(proto, flags)
		if pdata == nil then
			return nil, err
		end
		flags.top = flags.top + 1
		prototypes[#prototypes+1] = pdata
		proto, pos = lj_read_nbytes(dump, pos)
	end
	flags.top = nil
	return {
		chunkname = chunkname,
		prototypes = prototypes,
	}
end
-- Parse a LuaJIT bytecode dump of any supported version.
-- @tparam string dump The dump without the signature; the first byte is
-- the version number.
-- @treturn[1] table The parsed dump (see parse_lj2).
-- @treturn[2] nil If the version is not 2 or the dump cannot be parsed.
-- @treturn[2] string A message indicating the error.
local function parse_lj(dump)
	if string.byte(dump) ~= 2 then
		return nil, "Unsupported LuaJIT bytecode version"
	end
	return parse_lj2(string.sub(dump, 2))
end
-- Prepend a status value to a list of results, but only if the first
-- result is not nil.
-- @param st The status value to prepend.
-- @param ... The results.
-- @return `st` followed by the results if the first result is not nil,
-- otherwise the results unchanged (typically nil and an error message).
local function ensure_result(st, ...)
	if select("#", ...) == 0 or (...) == nil then
		return ...
	end
	return st, ...
end
-- Convert the results of `pcall(parse_lj, ...)` into the results of bc.parse.
local function lj_parse_result(ok, ...)
	if not ok then
		-- The parser raised an error, e.g. because the dump is truncated.
		return nil, "Malformed LuaJIT bytecode dump: " .. tostring((...))
	end
	return ensure_result("luajit", ...)
end
--- Try to parse a bytecode dump.
-- Errors raised while dumping the function or while reading the dump are
-- caught and reported through the return values.
-- @tparam string|function dump The bytecode input or the function to read.
-- @return[1] "luajit" If `dump` is valid LuaJIT bytecode.
-- @treturn[1] ... Data parsed from the bytecode.
-- @treturn[2] nil If the dump cannot be parsed.
-- @treturn[2] string A message indicating the error.
function bc.parse(dump)
	local tp = type(dump)
	if tp == "function" then
		-- string.dump raises an error for functions that cannot be dumped.
		local ok, res = pcall(string.dump, dump)
		if not ok then
			return nil, tostring(res)
		end
		dump = res
	elseif tp ~= "string" then
		return nil, "Invalid bytecode dump type"
	end
	local header = string.sub(dump, 1, 3)
	if header == "\27LJ" then
		return lj_parse_result(pcall(parse_lj, string.sub(dump, 4)))
	end
	return nil, "Unsupported bytecode dump format"
end
-- Characters with a dedicated escape sequence.
local escape_string_table = {
	["\n"] = [[\n]],
	["\r"] = [[\r]],
	["\""] = [[\"]],
	["\\"] = [[\\]],
}
-- Escape a string so that it can be shown between double quotes.
-- Double quotes, backslashes, control characters and bytes from 127 on are
-- replaced; bytes without a dedicated escape sequence are written as a
-- three-digit decimal escape so that a following digit cannot be mistaken
-- for a part of the escape.
-- @tparam string str The string to escape.
-- @treturn string The escaped string.
local function escape_string(str)
	-- The parentheses drop the substitution count returned by gsub.
	return (string.gsub(str, "[%z\1-\31\"\\\127-\255]", function(c)
		return escape_string_table[c] or ([[\%03d]]):format(string.byte(c))
	end))
end
-- Format a single instruction operand.
-- @tparam table proto The prototype that the instruction belongs to.
-- @tparam number line The 1-based index of the instruction.
-- @tparam table value The operand (`{type=, value=}`).
-- @treturn string The formatted operand.
-- @treturn[opt] string A description of the constant that the operand
-- refers to (for string, function and number constants that can be resolved).
local function lj_value_tostring(proto, line, value)
	local kind, operand = value.type, value.value
	local plain = ("%3d"):format(operand)
	if kind == nil then
		return " "
	end
	if kind == "jump" then
		-- Jump operands are offsets biased by 32767 relative to this instruction.
		return ("=> %04d"):format(line+operand-32767)
	end
	if kind == "lits" then
		-- Signed 16-bit literal.
		if operand >= 32768 then
			return ("%3d"):format(operand-65536)
		end
		return plain
	end
	local note
	if kind == "str" then
		local ref = proto.kgc[operand]
		if type(ref) == "string" then
			note = ([["%s"]]):format(escape_string(ref))
		end
	elseif kind == "func" then
		local ref = proto.kgc[operand]
		if type(ref) == "number" then
			note = ("BYTECODE %d"):format(ref)
		end
	elseif kind == "num" then
		local num = proto.kn[operand]
		if type(num) == "number" then
			note = tostring(num)
		end
	end
	if note then
		return plain, note
	end
	return plain
end
-- Format a parsed prototype as text: a heading, the collectable
-- constants, the numeric constants and one line per instruction.
-- @tparam number index The 1-based index of the prototype in the dump.
-- @tparam table proto The parsed prototype.
-- @treturn string The formatted prototype.
local function lj_proto_tostring(index, proto)
	local out = {("-- BYTECODE -- %d"):format(index-1)}
	-- Collect the instructions that some jump points at.
	local is_target = {}
	for ln, ins in ipairs(proto.bytecode) do
		local d = ins[3]
		if d and d.type == "jump" then
			is_target[ln+d.value-32767] = true
		end
	end
	-- The constant lists start at key 0, so the length operator yields the
	-- last key; an absent key 0 means that the list is empty.
	local last_kgc = #proto.kgc
	if proto.kgc[0] == nil then
		last_kgc = -1
	end
	for id = 0, last_kgc do
		local const = proto.kgc[id]
		local kind = type(const)
		local text
		if kind == "string" then
			text = ([["%s"]]):format(escape_string(const))
		elseif kind == "number" then
			text = ("BYTECODE %d"):format(const)
		elseif kind == "table" then
			text = "[table]"
		else
			text = "???"
		end
		out[#out+1] = ("%-7s %6d %s"):format("KGC", id, text)
	end
	local last_kn = #proto.kn
	if proto.kn[0] == nil then
		last_kn = -1
	end
	for id = 0, last_kn do
		out[#out+1] = ("%-7s %6d %s"):format("KN", id, proto.kn[id])
	end
	for ln, ins in ipairs(proto.bytecode) do
		local a_text = lj_value_tostring(proto, ln, ins[2])
		local parts = {
			("%04d"):format(ln),
			is_target[ln] and "=>" or " ",
			("%-6s"):format(ins[1]),
			a_text,
		}
		if ins[4] then
			-- Three-operand form: A B C.
			local b_text, b_note = lj_value_tostring(proto, ln, ins[3])
			local c_text, c_note = lj_value_tostring(proto, ln, ins[4])
			parts[#parts+1] = ("%3d %3d"):format(b_text, c_text)
			if b_note or c_note then
				parts[#parts+1] = ";"
				if b_note then
					parts[#parts+1] = b_note
				end
				if c_note then
					parts[#parts+1] = c_note
				end
			end
		else
			-- Two-operand form: A D.
			local d_text, d_note = lj_value_tostring(proto, ln, ins[3])
			parts[#parts+1] = ("%-7s"):format(d_text)
			if d_note then
				parts[#parts+1] = ("; %s"):format(d_note)
			end
		end
		out[#out+1] = table.concat(parts, " ")
	end
	return table.concat(out, "\n")
end
--- Try to format a bytecode dump.
-- @tparam string|function dump The bytecode input or the function to read.
-- @return[1] "luajit" If `dump` is valid LuaJIT bytecode.
-- @treturn[1] string A string describing the bytecode dump. The format is similar to that of `luajit -bl`
-- @treturn[2] nil If the dump cannot be parsed.
-- @treturn[2] string A message indicating the error.
function bc.tostring(dump)
	local tp, data = bc.parse(dump)
	if tp ~= "luajit" then
		-- On failure `data` holds the error message from bc.parse.
		return nil, data
	end
	local sections = {}
	for index, proto in ipairs(data.prototypes) do
		sections[index] = lj_proto_tostring(index, proto)
	end
	return tp, table.concat(sections, "\n\n")
end
if true then -- debugging only
	-- Chat command that compiles and runs the given Lua code and shows the
	-- formatted bytecode of the value that the code returns.
	-- Restricted to players with the `server` privilege.
	minetest.register_chatcommand("atdoc_format_function", {
		params = "<lua code>",
		description = "Execute the given lua code and dump the resulting function",
		privs = {server = true},
		func = function(_, param)
			local f, err = loadstring(param)
			if not f then
				return false, err
			end
			-- Run the chunk; its first return value is what gets formatted.
			local st, val = pcall(f)
			if not st then
				return false, val
			end
			local tp, desc = bc.tostring(val)
			if not tp then
				return false, desc
			end
			return true, desc
		end,
	})
end
return bc