[contrib] Add preprocessor hardwiring to freestanding.py

dev
Nick Terrell 2020-08-10 23:09:59 -07:00
parent 79ded1b4a9
commit 1c3cb2c05c
1 changed files with 406 additions and 14 deletions

View File

@ -17,6 +17,7 @@ import shutil
import sys import sys
from typing import Optional from typing import Optional
INCLUDED_SUBDIRS = ["common", "compress", "decompress"] INCLUDED_SUBDIRS = ["common", "compress", "decompress"]
SKIPPED_FILES = [ SKIPPED_FILES = [
@ -46,18 +47,360 @@ class FileLines(object):
f.write("".join(self.lines)) f.write("".join(self.lines))
class PartialPreprocessor(object):
    """
    Looks for simple ifdefs and ifndefs and replaces them.
    Handles && and ||.
    Has fancy logic to handle translating elifs to ifs.
    Only looks for macros in the first part of the expression with no
    parens.
    Does not handle multi-line macros (only looks in first line).

    Only the first macro and the first `&&` / `||` operator of a condition
    are examined; the remainder of the expression is never parsed.
    """

    # Comparisons understood in `#if defined(X) && X <cmp> <integer>`.
    _COMPARATORS = {
        '<': lambda lhs, rhs: lhs < rhs,
        '<=': lambda lhs, rhs: lhs <= rhs,
        '==': lambda lhs, rhs: lhs == rhs,
        '!=': lambda lhs, rhs: lhs != rhs,
        '>=': lambda lhs, rhs: lhs >= rhs,
        '>': lambda lhs, rhs: lhs > rhs,
    }

    def __init__(self, defs: [(str, Optional[str])], replaces: [(str, str)], undefs: [str]):
        """
        defs: (macro, value) pairs that are treated as defined.
        replaces: (macro, value) pairs that are treated as defined; the first
            block hardwired for such a macro additionally gets
            `#define macro value` emitted in its place.
        undefs: macro names that are treated as undefined.
        """
        MACRO_GROUP = r"(?P<macro>[a-zA-Z_][a-zA-Z_0-9]*)"
        ELIF_GROUP = r"(?P<elif>el)?"
        OP_GROUP = r"(?P<op>&&|\|\|)?"

        self._defs = dict(defs)
        self._replaces = dict(replaces)
        # A replaced macro also counts as defined.
        self._defs.update(self._replaces)
        self._undefs = set(undefs)

        self._define = re.compile(r"\s*#\s*define")
        self._if = re.compile(r"\s*#\s*if")
        self._elif = re.compile(r"\s*#\s*(?P<elif>el)if")
        self._else = re.compile(r"\s*#\s*(?P<else>else)")
        self._endif = re.compile(r"\s*#\s*endif")

        # Any run of whitespace may separate `#ifdef` from the macro name.
        self._ifdef = re.compile(fr"\s*#\s*if(?P<not>n)?def\s+{MACRO_GROUP}\s*")
        self._if_defined = re.compile(
            fr"\s*#\s*{ELIF_GROUP}if\s+(?P<not>!)?\s*defined\s*\(\s*{MACRO_GROUP}\s*\)\s*{OP_GROUP}"
        )
        self._if_defined_value = re.compile(
            fr"\s*#\s*if\s+defined\s*\(\s*{MACRO_GROUP}\s*\)\s*"
            fr"(?P<op>&&)\s*"
            fr"(?P<openp>\()?\s*"
            fr"(?P<macro2>[a-zA-Z_][a-zA-Z_0-9]*)\s*"
            fr"(?P<cmp>[=><!]+)\s*"
            fr"(?P<value>[0-9]*)\s*"
            fr"(?P<closep>\))?\s*"
        )

        self._c_comment = re.compile(r"/\*.*?\*/")
        self._cpp_comment = re.compile(r"//")

    def _log(self, *args, **kwargs):
        """Forward a progress message to print()."""
        print(*args, **kwargs)

    @staticmethod
    def _chomp(line):
        """Return `line` without its trailing newline (used for logging)."""
        return line.rstrip("\n")

    def _strip_comments(self, line):
        """Return `line` with single-line C and C++ comments removed."""
        # First strip c-style comments (may include //)
        while True:
            m = self._c_comment.search(line)
            if m is None:
                break
            line = line[:m.start()] + line[m.end():]

        # Then strip cpp-style comments
        m = self._cpp_comment.search(line)
        if m is not None:
            line = line[:m.start()]

        return line

    def _fixup_indentation(self, macro, replace: [str]):
        """
        Dedent the preprocessor lines that survive the removal of an if block.

        The common number of spaces (after the leading `#` when every line
        starts with one) is stripped. The list is returned untouched when it
        is empty, when it is a single line that is not a `#define`, or when
        any line is not a preprocessor directive.
        """
        if len(replace) == 0:
            return replace
        if len(replace) == 1 and self._define.match(replace[0]) is None:
            # If there is only one line, only replace defines
            return replace

        all_pound = all(line.startswith('#') for line in replace)
        if all_pound:
            # Measure the indentation that follows the '#'.
            replace = [line[1:] for line in replace]

        min_spaces = len(replace[0])
        for line in replace:
            spaces = 0
            for i, c in enumerate(line):
                if c != ' ':
                    # Non-preprocessor line ==> skip the fixup
                    if not all_pound and c != '#':
                        return replace
                    spaces = i
                    break
            min_spaces = min(min_spaces, spaces)

        replace = [line[min_spaces:] for line in replace]

        if all_pound:
            replace = ["#" + line for line in replace]

        return replace

    def _handle_if_block(self, macro, idx, is_true, prepend):
        """
        Remove the #if or #elif block starting on this line.

        macro: name of the macro being hardwired (used for logging).
        idx: index into self._inlines of the #if / #elif line.
        is_true: whether the condition on that line evaluates to true.
        prepend: lines to emit before anything kept from the block.

        Returns (next_idx, replace): the index of the first input line that
        was not consumed, and the lines to emit in place of the consumed ones.
        Raises RuntimeError if the block has no matching #endif.
        """
        REMOVE_ONE = 0
        KEEP_ONE = 1
        REMOVE_REST = 2

        state = KEEP_ONE if is_true else REMOVE_ONE

        line = self._inlines[idx]
        is_if = self._if.match(line) is not None
        elif_match = self._elif.match(line)
        assert is_if or elif_match is not None
        # A true #elif still depends on the (unresolved) branches before it.
        true_elif = is_true and not is_if
        depth = 0

        start_idx = idx

        idx += 1
        replace = list(prepend)
        if true_elif:
            # `#elif <true>` is equivalent to `#else`: keep the directive so
            # the body stays separate from the preceding branch.
            replace.append(line[:elif_match.start('elif')] + "else\n")
        finished = False
        while idx < len(self._inlines):
            line = self._inlines[idx]
            # Nested if statement
            if self._if.match(line):
                depth += 1
                idx += 1
                continue
            # We're inside a nested statement
            if depth > 0:
                if self._endif.match(line):
                    depth -= 1
                idx += 1
                continue

            # We're at the original depth

            # Looking only for an endif.
            # We've found a true statement, but haven't
            # completely elided the if block, so we just
            # remove the remainder.
            if state == REMOVE_REST:
                if self._endif.match(line):
                    if is_if:
                        # Remove the endif because we took the first if
                        idx += 1
                    finished = True
                    break
                idx += 1
                continue

            if state == KEEP_ONE:
                if self._endif.match(line):
                    replace += self._inlines[start_idx + 1:idx]
                    if is_if:
                        # The endif belongs to the #if we are removing. For
                        # an #elif it still closes the enclosing #if.
                        idx += 1
                    finished = True
                    break
                if self._elif.match(line) or self._else.match(line):
                    replace += self._inlines[start_idx + 1:idx]
                    state = REMOVE_REST
                idx += 1
                continue

            if state == REMOVE_ONE:
                m = self._elif.match(line)
                if m is not None:
                    if is_if:
                        # The removed #if is followed by an #elif, which
                        # becomes the new #if.
                        idx += 1
                        b = m.start('elif')
                        e = m.end('elif')
                        assert e - b == 2
                        replace.append(line[:b] + line[e:])
                    finished = True
                    break
                m = self._else.match(line)
                if m is not None:
                    if is_if:
                        # Keep the whole #else body up to the matching
                        # #endif, skipping over nested conditionals.
                        idx += 1
                        nested = 0
                        while idx < len(self._inlines):
                            body_line = self._inlines[idx]
                            if self._if.match(body_line):
                                nested += 1
                            elif self._endif.match(body_line):
                                if nested == 0:
                                    break
                                nested -= 1
                            replace.append(body_line)
                            idx += 1
                        else:
                            raise RuntimeError("Unterminated if block!")
                        idx += 1
                    finished = True
                    break
                if self._endif.match(line):
                    if is_if:
                        # Remove the endif because no other elifs
                        idx += 1
                    finished = True
                    break
                idx += 1
                continue
        if not finished:
            raise RuntimeError("Unterminated if block!")

        if not true_elif:
            # The body of a true #elif stays nested inside the enclosing
            # conditional, so its indentation is left alone.
            replace = self._fixup_indentation(macro, replace)

        self._log(f"\tHardwiring {macro}")
        if start_idx > 0:
            self._log(f"\t\t {self._chomp(self._inlines[start_idx - 1])}")
        for x in range(start_idx, idx):
            self._log(f"\t\t- {self._chomp(self._inlines[x])}")
        for line in replace:
            self._log(f"\t\t+ {self._chomp(line)}")
        if idx < len(self._inlines):
            self._log(f"\t\t {self._chomp(self._inlines[idx])}")

        return idx, replace

    def _preprocess_once(self):
        """
        Run one hardwiring pass over self._inlines.

        Returns (changed, outlines): whether any line was rewritten, and the
        resulting list of lines.
        """
        outlines = []
        idx = 0
        changed = False
        while idx < len(self._inlines):
            line = self._inlines[idx]
            sline = self._strip_comments(line)
            m = self._ifdef.fullmatch(sline)
            if m is None:
                m = self._if_defined_value.fullmatch(sline)
            if m is None:
                m = self._if_defined.match(sline)
            if m is None:
                outlines.append(line)
                idx += 1
                continue

            groups = m.groupdict()
            macro = groups['macro']
            ifdef = groups.get('not') is None
            op = groups.get('op')
            macro2 = groups.get('macro2')
            cmp = groups.get('cmp')
            value = groups.get('value')
            openp = groups.get('openp')
            closep = groups.get('closep')

            if not (macro in self._defs or macro in self._undefs):
                outlines.append(line)
                idx += 1
                continue

            defined = macro in self._defs
            is_true = (ifdef == defined)
            resolved = True
            if op is not None:
                # `false && ...` and `true || ...` are decided by the first
                # operand alone; anything else needs the rest of the line.
                if op == '&&':
                    resolved = not is_true
                else:
                    assert op == '||'
                    resolved = is_true

            if macro2 is not None and not resolved:
                assert ifdef and defined and op == '&&' and cmp is not None
                # If the statement is true, but we have a single value check,
                # then check the value.
                defined_value = self._defs[macro]
                are_ints = True
                try:
                    defined_value = int(defined_value)
                    value = int(value)
                except (TypeError, ValueError):
                    are_ints = False
                compare = self._COMPARATORS.get(cmp)
                if (
                    macro == macro2 and
                    ((openp is None) == (closep is None)) and
                    are_ints and
                    compare is not None
                ):
                    resolved = True
                    is_true = compare(defined_value, value)

            if op is not None and not resolved:
                # Remove the first op in the line + spaces
                needle = re.compile(
                    fr"(?P<if>\s*#\s*(el)?if\s+).*?(?P<op>{re.escape(op)}\s*)"
                )
                match = needle.match(line)
                assert match is not None
                newline = line[:match.end('if')] + line[match.end('op'):]

                self._log(f"\tHardwiring partially resolved {macro}")
                self._log(f"\t\t- {self._chomp(line)}")
                self._log(f"\t\t+ {self._chomp(newline)}")

                outlines.append(newline)
                idx += 1
                # The shortened condition may itself be resolvable, so make
                # sure another pass looks at it.
                changed = True
                continue

            # Skip any statements we cannot fully compute
            if not resolved:
                outlines.append(line)
                idx += 1
                continue

            prepend = []
            if macro in self._replaces:
                assert not ifdef
                assert op is None
                # Only the first block seen for a replaced macro receives the
                # definition; afterwards the macro is simply "defined".
                value = self._replaces.pop(macro)
                prepend = [f"#define {macro} {value}\n"]

            idx, replace = self._handle_if_block(macro, idx, is_true, prepend)
            outlines += replace
            changed = True

        return changed, outlines

    def preprocess(self, filename):
        """Hardwire the configured macros in `filename`, rewriting it in place."""
        with open(filename, 'r') as f:
            self._inlines = f.readlines()
        changed = True
        # Repeat until a pass makes no change: each pass can expose new
        # resolvable conditions (e.g. an #elif that was turned into an #if).
        while changed:
            changed, outlines = self._preprocess_once()
            self._inlines = outlines
        with open(filename, 'w') as f:
            f.write(''.join(self._inlines))
class Freestanding(object): class Freestanding(object):
def __init__( def __init__(
self,zstd_deps: str, source_lib: str, output_lib: str, self,zstd_deps: str, source_lib: str, output_lib: str,
external_xxhash: bool, rewritten_includes: [(str, str)], external_xxhash: bool, xxh64_state: Optional[str],
defs: [(str, Optional[str])], undefs: [str], excludes: [str] xxh64_prefix: Optional[str], rewritten_includes: [(str, str)],
defs: [(str, Optional[str])], replaces: [(str, str)],
undefs: [str], excludes: [str]
): ):
self._zstd_deps = zstd_deps self._zstd_deps = zstd_deps
self._src_lib = source_lib self._src_lib = source_lib
self._dst_lib = output_lib self._dst_lib = output_lib
self._external_xxhash = external_xxhash self._external_xxhash = external_xxhash
self._xxh64_state = xxh64_state
self._xxh64_prefix = xxh64_prefix
self._rewritten_includes = rewritten_includes self._rewritten_includes = rewritten_includes
self._defs = defs self._defs = defs
self._replaces = replaces
self._undefs = undefs self._undefs = undefs
self._excludes = excludes self._excludes = excludes
@ -121,14 +464,10 @@ class Freestanding(object):
file = FileLines(filepath) file = FileLines(filepath)
def _hardwire_defines(self):
    """
    Hardwire the configured macros in every file of the output library.

    One PartialPreprocessor is shared by all files, so a `replaces` entry
    is consumed by the first matching block across the whole library.
    """
    self._log("Hardwiring macros")
    partial_preprocessor = PartialPreprocessor(self._defs, self._replaces, self._undefs)
    for filepath in self._dst_lib_file_paths():
        partial_preprocessor.preprocess(filepath)
def _remove_excludes(self): def _remove_excludes(self):
self._log("Removing excluded sections") self._log("Removing excluded sections")
@ -180,15 +519,47 @@ class Freestanding(object):
for original, rewritten in self._rewritten_includes: for original, rewritten in self._rewritten_includes:
self._rewrite_include(original, rewritten) self._rewrite_include(original, rewritten)
def _replace_xxh64_prefix(self):
    """
    Rename XXH64 identifiers in every file of the output library.

    `XXH64_state_t` is replaced with self._xxh64_state (when set) and the
    `XXH64` part of `XXH64_*` identifiers with self._xxh64_prefix (when
    set). Each option is applied independently of the other. Does nothing
    when neither is set.
    """
    if self._xxh64_prefix is None and self._xxh64_state is None:
        return
    replacements = []
    # The state type goes first so that the prefix rule cannot rewrite
    # `XXH64_state_t` before it is seen.
    if self._xxh64_state is not None:
        self._log(f"Replacing XXH64 state type with {self._xxh64_state}")
        replacements.append(
            (re.compile(r"(?<!\w)XXH64_state_t(?!\w)"), self._xxh64_state)
        )
    if self._xxh64_prefix is not None:
        self._log(f"Replacing XXH64 prefix with {self._xxh64_prefix}")
        replacements.append(
            (re.compile(r"(?<!\w)XXH64(?=_)"), self._xxh64_prefix)
        )
    for filepath in self._dst_lib_file_paths():
        file = FileLines(filepath)
        for i, original in enumerate(file.lines):
            line = original
            for regex, replacement in replacements:
                # Single pass per pattern: replaced text is never re-scanned,
                # so a replacement that contains the pattern cannot loop.
                # The callable keeps backslashes in the replacement literal.
                line = regex.sub(lambda _match, text=replacement: text, line)
            if line != original:
                old_text = original.rstrip("\n")
                new_text = line.rstrip("\n")
                self._log(f"\t- {old_text}")
                self._log(f"\t+ {new_text}")
                file.lines[i] = line
        file.write()
def go(self):
    """
    Build the freestanding library by running every step in order.

    The source is copied first; macro hardwiring, exclusion removal,
    include rewriting and XXH64 renaming then run in that order.
    """
    self._copy_source_lib()
    self._copy_zstd_deps()
    self._hardwire_defines()
    self._remove_excludes()
    self._rewrite_includes()
    self._replace_xxh64_prefix()
def parse_defines(defines: [str]) -> [(str, Optional[str])]: def parse_optional_pair(defines: [str]) -> [(str, Optional[str])]:
output = [] output = []
for define in defines: for define in defines:
parsed = define.split('=') parsed = define.split('=')
@ -201,7 +572,7 @@ def parse_defines(defines: [str]) -> [(str, Optional[str])]:
return output return output
def parse_rewritten_includes(rewritten_includes: [str]) -> [(str, str)]: def parse_pair(rewritten_includes: [str]) -> [(str, str)]:
output = [] output = []
for rewritten_include in rewritten_includes: for rewritten_include in rewritten_includes:
parsed = rewritten_include.split('=') parsed = rewritten_include.split('=')
@ -219,9 +590,12 @@ def main(name, args):
parser.add_argument("--source-lib", default="../../lib", help="Location of the zstd library") parser.add_argument("--source-lib", default="../../lib", help="Location of the zstd library")
parser.add_argument("--output-lib", default="./freestanding_lib", help="Where to output the freestanding zstd library") parser.add_argument("--output-lib", default="./freestanding_lib", help="Where to output the freestanding zstd library")
parser.add_argument("--xxhash", default=None, help="Alternate external xxhash include e.g. --xxhash='<xxhash.h>'. If set xxhash is not included.") parser.add_argument("--xxhash", default=None, help="Alternate external xxhash include e.g. --xxhash='<xxhash.h>'. If set xxhash is not included.")
parser.add_argument("--xxh64-state", default=None, help="Alternate XXH64 state type (excluding _) e.g. --xxh64-state='struct xxh64_state'")
parser.add_argument("--xxh64-prefix", default=None, help="Alternate XXH64 function prefix (excluding _) e.g. --xxh64-prefix=xxh64")
parser.add_argument("--rewrite-include", default=[], dest="rewritten_includes", action="append", help="Rewrite an include REGEX=NEW (e.g. '<stddef\\.h>=<linux/types.h>')") parser.add_argument("--rewrite-include", default=[], dest="rewritten_includes", action="append", help="Rewrite an include REGEX=NEW (e.g. '<stddef\\.h>=<linux/types.h>')")
parser.add_argument("-D", "--define", default=[], dest="defs", action="append", help="Pre-define this macro (can be passed multiple times)") parser.add_argument("-D", "--define", default=[], dest="defs", action="append", help="Pre-define this macro (can be passed multiple times)")
parser.add_argument("-U", "--undefine", default=[], dest="undefs", action="append", help="Pre-undefine this macro (can be passed mutliple times)") parser.add_argument("-U", "--undefine", default=[], dest="undefs", action="append", help="Pre-undefine this macro (can be passed mutliple times)")
parser.add_argument("-R", "--replace", default=[], dest="replaces", action="append", help="Pre-define this macro and replace the first ifndef block with its definition")
parser.add_argument("-E", "--exclude", default=[], dest="excludes", action="append", help="Exclude all lines between 'BEGIN <EXCLUDE>' and 'END <EXCLUDE>'") parser.add_argument("-E", "--exclude", default=[], dest="excludes", action="append", help="Exclude all lines between 'BEGIN <EXCLUDE>' and 'END <EXCLUDE>'")
args = parser.parse_args(args) args = parser.parse_args(args)
@ -229,22 +603,37 @@ def main(name, args):
if "ZSTD_MULTITHREAD" not in args.undefs: if "ZSTD_MULTITHREAD" not in args.undefs:
args.undefs.append("ZSTD_MULTITHREAD") args.undefs.append("ZSTD_MULTITHREAD")
args.defs = parse_defines(args.defs) args.defs = parse_optional_pair(args.defs)
for name, _ in args.defs: for name, _ in args.defs:
if name in args.undefs: if name in args.undefs:
raise RuntimeError(f"{name} is both defined and undefined!") raise RuntimeError(f"{name} is both defined and undefined!")
args.rewritten_includes = parse_rewritten_includes(args.rewritten_includes) args.replaces = parse_pair(args.replaces)
for name, _ in args.replaces:
if name in args.undefs or name in args.defs:
raise RuntimeError(f"{name} is both replaced and (un)defined!")
args.rewritten_includes = parse_pair(args.rewritten_includes)
external_xxhash = False external_xxhash = False
if args.xxhash is not None: if args.xxhash is not None:
external_xxhash = True external_xxhash = True
args.rewritten_includes.append(('"(\\.\\./common/)?xxhash.h"', args.xxhash)) args.rewritten_includes.append(('"(\\.\\./common/)?xxhash.h"', args.xxhash))
if args.xxh64_prefix is not None:
if not external_xxhash:
raise RuntimeError("--xxh64-prefix may only be used with --xxhash provided")
if args.xxh64_state is not None:
if not external_xxhash:
raise RuntimeError("--xxh64-state may only be used with --xxhash provided")
print(args.zstd_deps) print(args.zstd_deps)
print(args.output_lib) print(args.output_lib)
print(args.source_lib) print(args.source_lib)
print(args.xxhash) print(args.xxhash)
print(args.xxh64_state)
print(args.xxh64_prefix)
print(args.rewritten_includes) print(args.rewritten_includes)
print(args.defs) print(args.defs)
print(args.undefs) print(args.undefs)
@ -254,8 +643,11 @@ def main(name, args):
args.source_lib, args.source_lib,
args.output_lib, args.output_lib,
external_xxhash, external_xxhash,
args.xxh64_state,
args.xxh64_prefix,
args.rewritten_includes, args.rewritten_includes,
args.defs, args.defs,
args.replaces,
args.undefs, args.undefs,
args.excludes args.excludes
).go() ).go()