add memcpy and memset intrinsics
This commit is contained in:
parent
bdca82ea66
commit
6d9119fcd9
@ -1944,8 +1944,11 @@ static TypeTableEntry *analyze_while_expr(CodeGen *g, ImportTableEntry *import,
|
||||
if (resolved_type->id != TypeTableEntryIdInvalid) {
|
||||
assert(resolved_type->id == TypeTableEntryIdBool);
|
||||
bool constant_cond_value = number_literal.data.x_uint;
|
||||
if (constant_cond_value && !node->codegen_node->data.while_node.contains_break) {
|
||||
expr_return_type = g->builtin_types.entry_unreachable;
|
||||
if (constant_cond_value) {
|
||||
node->codegen_node->data.while_node.condition_always_true = true;
|
||||
if (!node->codegen_node->data.while_node.contains_break) {
|
||||
expr_return_type = g->builtin_types.entry_unreachable;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -2085,13 +2088,74 @@ static TypeTableEntry *analyze_builtin_fn_call_expr(CodeGen *g, ImportTableEntry
|
||||
builtin_fn->param_count, actual_param_count));
|
||||
}
|
||||
|
||||
for (int i = 0; i < actual_param_count; i += 1) {
|
||||
AstNode *child = node->data.fn_call_expr.params.at(i);
|
||||
TypeTableEntry *expected_param_type = builtin_fn->param_types[i];
|
||||
analyze_expression(g, import, context, expected_param_type, child);
|
||||
}
|
||||
switch (builtin_fn->id) {
|
||||
case BuiltinFnIdInvalid:
|
||||
zig_unreachable();
|
||||
case BuiltinFnIdArithmeticWithOverflow:
|
||||
for (int i = 0; i < actual_param_count; i += 1) {
|
||||
AstNode *child = node->data.fn_call_expr.params.at(i);
|
||||
TypeTableEntry *expected_param_type = builtin_fn->param_types[i];
|
||||
analyze_expression(g, import, context, expected_param_type, child);
|
||||
}
|
||||
return builtin_fn->return_type;
|
||||
case BuiltinFnIdMemcpy:
|
||||
{
|
||||
AstNode *dest_node = node->data.fn_call_expr.params.at(0);
|
||||
AstNode *src_node = node->data.fn_call_expr.params.at(1);
|
||||
AstNode *len_node = node->data.fn_call_expr.params.at(2);
|
||||
TypeTableEntry *dest_type = analyze_expression(g, import, context, nullptr, dest_node);
|
||||
TypeTableEntry *src_type = analyze_expression(g, import, context, nullptr, src_node);
|
||||
analyze_expression(g, import, context, builtin_fn->param_types[2], len_node);
|
||||
|
||||
return builtin_fn->return_type;
|
||||
if (dest_type->id != TypeTableEntryIdInvalid &&
|
||||
dest_type->id != TypeTableEntryIdPointer)
|
||||
{
|
||||
add_node_error(g, dest_node,
|
||||
buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&dest_type->name)));
|
||||
}
|
||||
|
||||
if (src_type->id != TypeTableEntryIdInvalid &&
|
||||
src_type->id != TypeTableEntryIdPointer)
|
||||
{
|
||||
add_node_error(g, src_node,
|
||||
buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&src_type->name)));
|
||||
}
|
||||
|
||||
if (dest_type->id == TypeTableEntryIdPointer &&
|
||||
src_type->id == TypeTableEntryIdPointer)
|
||||
{
|
||||
uint64_t dest_align_bits = dest_type->data.pointer.child_type->align_in_bits;
|
||||
uint64_t src_align_bits = src_type->data.pointer.child_type->align_in_bits;
|
||||
if (dest_align_bits != src_align_bits) {
|
||||
add_node_error(g, dest_node, buf_sprintf(
|
||||
"misaligned memcpy, '%s' has alignment '%" PRIu64 ", '%s' has alignment %" PRIu64,
|
||||
buf_ptr(&dest_type->name), dest_align_bits / 8,
|
||||
buf_ptr(&src_type->name), src_align_bits / 8));
|
||||
}
|
||||
}
|
||||
|
||||
return builtin_fn->return_type;
|
||||
}
|
||||
case BuiltinFnIdMemset:
|
||||
{
|
||||
AstNode *dest_node = node->data.fn_call_expr.params.at(0);
|
||||
AstNode *char_node = node->data.fn_call_expr.params.at(1);
|
||||
AstNode *len_node = node->data.fn_call_expr.params.at(2);
|
||||
TypeTableEntry *dest_type = analyze_expression(g, import, context, nullptr, dest_node);
|
||||
analyze_expression(g, import, context, builtin_fn->param_types[1], char_node);
|
||||
analyze_expression(g, import, context, builtin_fn->param_types[2], len_node);
|
||||
|
||||
if (dest_type->id != TypeTableEntryIdInvalid &&
|
||||
dest_type->id != TypeTableEntryIdPointer)
|
||||
{
|
||||
add_node_error(g, dest_node,
|
||||
buf_sprintf("expected pointer argument, got '%s'", buf_ptr(&dest_type->name)));
|
||||
}
|
||||
|
||||
return builtin_fn->return_type;
|
||||
}
|
||||
}
|
||||
zig_unreachable();
|
||||
} else {
|
||||
add_node_error(g, node,
|
||||
buf_sprintf("invalid builtin function: '%s'", buf_ptr(name)));
|
||||
|
@ -151,6 +151,8 @@ struct FnTableEntry {
|
||||
enum BuiltinFnId {
|
||||
BuiltinFnIdInvalid,
|
||||
BuiltinFnIdArithmeticWithOverflow,
|
||||
BuiltinFnIdMemcpy,
|
||||
BuiltinFnIdMemset,
|
||||
};
|
||||
|
||||
struct BuiltinFnEntry {
|
||||
@ -354,6 +356,7 @@ struct ImportNode {
|
||||
};
|
||||
|
||||
struct WhileNode {
|
||||
bool condition_always_true;
|
||||
bool contains_break;
|
||||
};
|
||||
|
||||
|
142
src/codegen.cpp
142
src/codegen.cpp
@ -171,6 +171,67 @@ static LLVMValueRef gen_builtin_fn_call_expr(CodeGen *g, AstNode *node) {
|
||||
|
||||
return overflow_bit;
|
||||
}
|
||||
case BuiltinFnIdMemcpy:
|
||||
{
|
||||
int fn_call_param_count = node->data.fn_call_expr.params.length;
|
||||
assert(fn_call_param_count == 3);
|
||||
|
||||
AstNode *dest_node = node->data.fn_call_expr.params.at(0);
|
||||
TypeTableEntry *dest_type = get_expr_type(dest_node);
|
||||
|
||||
LLVMValueRef dest_ptr = gen_expr(g, dest_node);
|
||||
LLVMValueRef src_ptr = gen_expr(g, node->data.fn_call_expr.params.at(1));
|
||||
LLVMValueRef len_val = gen_expr(g, node->data.fn_call_expr.params.at(2));
|
||||
|
||||
LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
|
||||
|
||||
add_debug_source_node(g, node);
|
||||
LLVMValueRef dest_ptr_casted = LLVMBuildBitCast(g->builder, dest_ptr, ptr_u8, "");
|
||||
LLVMValueRef src_ptr_casted = LLVMBuildBitCast(g->builder, src_ptr, ptr_u8, "");
|
||||
|
||||
uint64_t align_in_bytes = dest_type->data.pointer.child_type->align_in_bits / 8;
|
||||
|
||||
LLVMValueRef params[] = {
|
||||
dest_ptr_casted, // dest pointer
|
||||
src_ptr_casted, // source pointer
|
||||
len_val, // byte count
|
||||
LLVMConstInt(LLVMInt32Type(), align_in_bytes, false), // align in bytes
|
||||
LLVMConstNull(LLVMInt1Type()), // is volatile
|
||||
};
|
||||
|
||||
LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 5, "");
|
||||
return nullptr;
|
||||
}
|
||||
case BuiltinFnIdMemset:
|
||||
{
|
||||
int fn_call_param_count = node->data.fn_call_expr.params.length;
|
||||
assert(fn_call_param_count == 3);
|
||||
|
||||
AstNode *dest_node = node->data.fn_call_expr.params.at(0);
|
||||
TypeTableEntry *dest_type = get_expr_type(dest_node);
|
||||
|
||||
LLVMValueRef dest_ptr = gen_expr(g, dest_node);
|
||||
LLVMValueRef char_val = gen_expr(g, node->data.fn_call_expr.params.at(1));
|
||||
LLVMValueRef len_val = gen_expr(g, node->data.fn_call_expr.params.at(2));
|
||||
|
||||
LLVMTypeRef ptr_u8 = LLVMPointerType(LLVMInt8Type(), 0);
|
||||
|
||||
add_debug_source_node(g, node);
|
||||
LLVMValueRef dest_ptr_casted = LLVMBuildBitCast(g->builder, dest_ptr, ptr_u8, "");
|
||||
|
||||
uint64_t align_in_bytes = dest_type->data.pointer.child_type->align_in_bits / 8;
|
||||
|
||||
LLVMValueRef params[] = {
|
||||
dest_ptr_casted, // dest pointer
|
||||
char_val, // source pointer
|
||||
len_val, // byte count
|
||||
LLVMConstInt(LLVMInt32Type(), align_in_bytes, false), // align in bytes
|
||||
LLVMConstNull(LLVMInt1Type()), // is volatile
|
||||
};
|
||||
|
||||
LLVMBuildCall(g->builder, builtin_fn->fn_val, params, 5, "");
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
zig_unreachable();
|
||||
}
|
||||
@ -1376,23 +1437,35 @@ static LLVMValueRef gen_while_expr(CodeGen *g, AstNode *node) {
|
||||
assert(node->data.while_expr.condition);
|
||||
assert(node->data.while_expr.body);
|
||||
|
||||
if (get_expr_type(node)->id == TypeTableEntryIdUnreachable) {
|
||||
// generate a forever loop. guarantees no break statements
|
||||
bool condition_always_true = node->codegen_node->data.while_node.condition_always_true;
|
||||
bool contains_break = node->codegen_node->data.while_node.contains_break;
|
||||
if (condition_always_true) {
|
||||
// generate a forever loop
|
||||
|
||||
LLVMBasicBlockRef body_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileBody");
|
||||
LLVMBasicBlockRef end_block = nullptr;
|
||||
if (contains_break) {
|
||||
end_block = LLVMAppendBasicBlock(g->cur_fn->fn_value, "WhileEnd");
|
||||
}
|
||||
|
||||
add_debug_source_node(g, node);
|
||||
LLVMBuildBr(g->builder, body_block);
|
||||
|
||||
LLVMPositionBuilderAtEnd(g->builder, body_block);
|
||||
g->break_block_stack.append(end_block);
|
||||
g->continue_block_stack.append(body_block);
|
||||
gen_expr(g, node->data.while_expr.body);
|
||||
g->break_block_stack.pop();
|
||||
g->continue_block_stack.pop();
|
||||
|
||||
if (get_expr_type(node->data.while_expr.body)->id != TypeTableEntryIdUnreachable) {
|
||||
add_debug_source_node(g, node);
|
||||
LLVMBuildBr(g->builder, body_block);
|
||||
}
|
||||
|
||||
if (contains_break) {
|
||||
LLVMPositionBuilderAtEnd(g->builder, end_block);
|
||||
}
|
||||
} else {
|
||||
// generate a normal while loop
|
||||
|
||||
@ -1755,20 +1828,6 @@ static LLVMAttribute to_llvm_fn_attr(FnAttrId attr_id) {
|
||||
static void do_code_gen(CodeGen *g) {
|
||||
assert(!g->errors.length);
|
||||
|
||||
{
|
||||
LLVMTypeRef param_types[] = {
|
||||
LLVMPointerType(LLVMInt8Type(), 0),
|
||||
LLVMPointerType(LLVMInt8Type(), 0),
|
||||
LLVMIntType(g->pointer_size_bytes * 8),
|
||||
LLVMInt32Type(),
|
||||
LLVMInt1Type(),
|
||||
};
|
||||
LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false);
|
||||
Buf *name = buf_sprintf("llvm.memcpy.p0i8.p0i8.i%d", g->pointer_size_bytes * 8);
|
||||
g->memcpy_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
|
||||
assert(LLVMGetIntrinsicID(g->memcpy_fn_val));
|
||||
}
|
||||
|
||||
// Generate module level variables
|
||||
for (int i = 0; i < g->global_vars.length; i += 1) {
|
||||
VariableTableEntry *var = g->global_vars.at(i);
|
||||
@ -2267,6 +2326,57 @@ static void define_builtin_fns(CodeGen *g) {
|
||||
define_builtin_fns_int(g, g->builtin_types.entry_i16);
|
||||
define_builtin_fns_int(g, g->builtin_types.entry_i32);
|
||||
define_builtin_fns_int(g, g->builtin_types.entry_i64);
|
||||
{
|
||||
BuiltinFnEntry *builtin_fn = allocate<BuiltinFnEntry>(1);
|
||||
buf_init_from_str(&builtin_fn->name, "memcpy");
|
||||
builtin_fn->id = BuiltinFnIdMemcpy;
|
||||
builtin_fn->return_type = g->builtin_types.entry_void;
|
||||
builtin_fn->param_count = 3;
|
||||
builtin_fn->param_types = allocate<TypeTableEntry *>(builtin_fn->param_count);
|
||||
builtin_fn->param_types[0] = nullptr; // manually checked later
|
||||
builtin_fn->param_types[1] = nullptr; // manually checked later
|
||||
builtin_fn->param_types[2] = g->builtin_types.entry_usize;
|
||||
|
||||
LLVMTypeRef param_types[] = {
|
||||
LLVMPointerType(LLVMInt8Type(), 0),
|
||||
LLVMPointerType(LLVMInt8Type(), 0),
|
||||
LLVMIntType(g->pointer_size_bytes * 8),
|
||||
LLVMInt32Type(),
|
||||
LLVMInt1Type(),
|
||||
};
|
||||
LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false);
|
||||
Buf *name = buf_sprintf("llvm.memcpy.p0i8.p0i8.i%d", g->pointer_size_bytes * 8);
|
||||
g->memcpy_fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
|
||||
builtin_fn->fn_val = g->memcpy_fn_val;
|
||||
assert(LLVMGetIntrinsicID(g->memcpy_fn_val));
|
||||
|
||||
g->builtin_fn_table.put(&builtin_fn->name, builtin_fn);
|
||||
}
|
||||
{
|
||||
BuiltinFnEntry *builtin_fn = allocate<BuiltinFnEntry>(1);
|
||||
buf_init_from_str(&builtin_fn->name, "memset");
|
||||
builtin_fn->id = BuiltinFnIdMemset;
|
||||
builtin_fn->return_type = g->builtin_types.entry_void;
|
||||
builtin_fn->param_count = 3;
|
||||
builtin_fn->param_types = allocate<TypeTableEntry *>(builtin_fn->param_count);
|
||||
builtin_fn->param_types[0] = nullptr; // manually checked later
|
||||
builtin_fn->param_types[1] = g->builtin_types.entry_u8;
|
||||
builtin_fn->param_types[2] = g->builtin_types.entry_usize;
|
||||
|
||||
LLVMTypeRef param_types[] = {
|
||||
LLVMPointerType(LLVMInt8Type(), 0),
|
||||
LLVMInt8Type(),
|
||||
LLVMIntType(g->pointer_size_bytes * 8),
|
||||
LLVMInt32Type(),
|
||||
LLVMInt1Type(),
|
||||
};
|
||||
LLVMTypeRef fn_type = LLVMFunctionType(LLVMVoidType(), param_types, 5, false);
|
||||
Buf *name = buf_sprintf("llvm.memset.p0i8.i%d", g->pointer_size_bytes * 8);
|
||||
builtin_fn->fn_val = LLVMAddFunction(g->module, buf_ptr(name), fn_type);
|
||||
assert(LLVMGetIntrinsicID(builtin_fn->fn_val));
|
||||
|
||||
g->builtin_fn_table.put(&builtin_fn->name, builtin_fn);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
@ -118,13 +118,7 @@ fn buf_print_u64(out_buf: []u8, x: u64) -> usize {
|
||||
|
||||
const len = buf.len - index;
|
||||
|
||||
// TODO memcpy intrinsic
|
||||
// @memcpy(out_buf, buf, len);
|
||||
var i: usize = 0;
|
||||
while (i < len) {
|
||||
out_buf[i] = buf[index + i];
|
||||
i += 1;
|
||||
}
|
||||
@memcpy(out_buf.ptr, &buf[index], len);
|
||||
|
||||
return len;
|
||||
}
|
||||
|
@ -973,6 +973,24 @@ pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
|
||||
return 0;
|
||||
}
|
||||
)SOURCE", "OK\n");
|
||||
|
||||
add_simple_case("memcpy and memset intrinsics", R"SOURCE(
|
||||
use "std.zig";
|
||||
pub fn main(argc: isize, argv: &&u8, env: &&u8) -> i32 {
|
||||
var foo : [20]u8;
|
||||
var bar : [20]u8;
|
||||
|
||||
@memset(foo.ptr, 'A', foo.len);
|
||||
@memcpy(bar.ptr, foo.ptr, bar.len);
|
||||
|
||||
if (bar[11] != 'A') {
|
||||
print_str("BAD\n");
|
||||
}
|
||||
|
||||
print_str("OK\n");
|
||||
return 0;
|
||||
}
|
||||
)SOURCE", "OK\n");
|
||||
}
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user