diff --git a/src/analyze.cpp b/src/analyze.cpp index f574c8d4b..756f12251 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -214,9 +214,12 @@ TypeTableEntry *get_maybe_type(CodeGen *g, TypeTableEntry *child_type) { buf_resize(&entry->name, 0); buf_appendf(&entry->name, "?%s", buf_ptr(&child_type->name)); - if (child_type->id == TypeTableEntryIdPointer) { + if (child_type->id == TypeTableEntryIdPointer || + child_type->id == TypeTableEntryIdFn) + { // this is an optimization but also is necessary for calling C // functions where all pointers are maybe pointers + // function types are technically pointers entry->size_in_bits = child_type->size_in_bits; entry->align_in_bits = child_type->align_in_bits; entry->type_ref = child_type->type_ref; @@ -5384,7 +5387,8 @@ bool handle_is_ptr(TypeTableEntry *type_entry) { case TypeTableEntryIdEnum: return type_entry->data.enumeration.gen_field_count != 0; case TypeTableEntryIdMaybe: - return type_entry->data.maybe.child_type->id != TypeTableEntryIdPointer; + return type_entry->data.maybe.child_type->id != TypeTableEntryIdPointer && + type_entry->data.maybe.child_type->id != TypeTableEntryIdFn; case TypeTableEntryIdTypeDecl: return handle_is_ptr(type_entry->data.type_decl.canonical_type); } diff --git a/src/codegen.cpp b/src/codegen.cpp index a5a91b55d..01f3fe60c 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -398,7 +398,9 @@ static LLVMValueRef gen_cast_expr(CodeGen *g, AstNode *node) { TypeTableEntry *child_type = wanted_type->data.maybe.child_type; - if (child_type->id == TypeTableEntryIdPointer) { + if (child_type->id == TypeTableEntryIdPointer || + child_type->id == TypeTableEntryIdFn) + { return expr_val; } else { add_debug_source_node(g, node); @@ -1274,7 +1276,9 @@ static LLVMValueRef gen_unwrap_maybe(CodeGen *g, AstNode *node, LLVMValueRef may TypeTableEntry *type_entry = get_expr_type(node); assert(type_entry->id == TypeTableEntryIdMaybe); TypeTableEntry *child_type = type_entry->data.maybe.child_type; - if (child_type->id == TypeTableEntryIdPointer) { + if (child_type->id == TypeTableEntryIdPointer || + child_type->id == TypeTableEntryIdFn) + { return maybe_struct_ref; } else { add_debug_source_node(g, node); @@ -1301,7 +1305,9 @@ static LLVMValueRef gen_unwrap_maybe_expr(CodeGen *g, AstNode *node) { TypeTableEntry *child_type = maybe_type->data.maybe.child_type; LLVMValueRef cond_value; - if (child_type->id == TypeTableEntryIdPointer) { + if (child_type->id == TypeTableEntryIdPointer || + child_type->id == TypeTableEntryIdFn) + { cond_value = LLVMBuildICmp(g->builder, LLVMIntNE, maybe_struct_ref, LLVMConstNull(child_type->type_ref), ""); } else { @@ -1651,7 +1657,9 @@ static LLVMValueRef gen_if_var_expr(CodeGen *g, AstNode *node) { assert(expr_type->id == TypeTableEntryIdMaybe); TypeTableEntry *child_type = expr_type->data.maybe.child_type; LLVMValueRef cond_value; - if (child_type->id == TypeTableEntryIdPointer) { + if (child_type->id == TypeTableEntryIdPointer || + child_type->id == TypeTableEntryIdFn) + { cond_value = LLVMBuildICmp(g->builder, LLVMIntNE, init_val, LLVMConstNull(child_type->type_ref), ""); } else { add_debug_source_node(g, node); @@ -2377,7 +2385,9 @@ static LLVMValueRef gen_const_val(CodeGen *g, TypeTableEntry *type_entry, ConstE case TypeTableEntryIdMaybe: { TypeTableEntry *child_type = type_entry->data.maybe.child_type; - if (child_type->id == TypeTableEntryIdPointer) { + if (child_type->id == TypeTableEntryIdPointer || + child_type->id == TypeTableEntryIdFn) + { if (const_val->data.x_maybe) { return gen_const_val(g, child_type, const_val->data.x_maybe); } else { diff --git a/test/run_tests.cpp b/test/run_tests.cpp index c873ae1b0..7cc8b0c56 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1447,6 +1447,42 @@ pub fn main(args: [][]u8) -> %void { f(false); } )SOURCE", "a\nb\n"); + + + add_simple_case("expose function pointer to C land", R"SOURCE( +#link("c") +export executable "test"; + +c_import { + @c_include("stdlib.h"); +} + +export fn compare_fn(a: ?&const c_void, b: ?&const c_void) -> c_int { + const a_int = (&i32)(a ?? unreachable{}); + const b_int = (&i32)(b ?? unreachable{}); + if (*a_int < *b_int) { + -1 + } else if (*a_int > *b_int) { + 1 + } else { + 0 + } +} + +export fn main(args: c_int, argv: &&u8) -> c_int { + var array = []i32 { 1, 7, 3, 2, 0, 9, 4, 8, 6, 5 }; + + qsort((&c_void)(array.ptr), c_ulong(array.len), @sizeof(i32), compare_fn); + + for (item, array, i) { + if (item != i) { + abort(); + } + } + + return 0; +} + )SOURCE", ""); }