fix await used in an expression generating bad LLVM

master
Andrew Kelley 2019-09-06 16:17:39 -04:00
parent 9423d382fb
commit 9ca8d9e21a
No known key found for this signature in database
GPG Key ID: 7C5F548F728501A9
4 changed files with 113 additions and 44 deletions

View File

@ -4232,31 +4232,40 @@ static Error analyze_callee_async(CodeGen *g, ZigFn *fn, ZigFn *callee, AstNode
{
if (modifier == CallModifierNoAsync)
return ErrorNone;
if (callee->type_entry->data.fn.fn_type_id.cc != CallingConventionUnspecified)
return ErrorNone;
if (callee->anal_state == FnAnalStateReady) {
analyze_fn_body(g, callee);
if (callee->anal_state == FnAnalStateInvalid) {
return ErrorSemanticAnalyzeFail;
}
bool callee_is_async = false;
switch (callee->type_entry->data.fn.fn_type_id.cc) {
case CallingConventionUnspecified:
break;
case CallingConventionAsync:
callee_is_async = true;
break;
default:
return ErrorNone;
}
bool callee_is_async;
if (callee->anal_state == FnAnalStateComplete) {
analyze_fn_async(g, callee, true);
if (callee->anal_state == FnAnalStateInvalid) {
return ErrorSemanticAnalyzeFail;
if (!callee_is_async) {
if (callee->anal_state == FnAnalStateReady) {
analyze_fn_body(g, callee);
if (callee->anal_state == FnAnalStateInvalid) {
return ErrorSemanticAnalyzeFail;
}
}
callee_is_async = fn_is_async(callee);
} else {
// If it's already been determined, use that value. Otherwise
// assume non-async, emit an error later if it turned out to be async.
if (callee->inferred_async_node == nullptr ||
callee->inferred_async_node == inferred_async_checking)
{
callee->assumed_non_async = call_node;
callee_is_async = false;
if (callee->anal_state == FnAnalStateComplete) {
analyze_fn_async(g, callee, true);
if (callee->anal_state == FnAnalStateInvalid) {
return ErrorSemanticAnalyzeFail;
}
callee_is_async = fn_is_async(callee);
} else {
callee_is_async = callee->inferred_async_node != inferred_async_none;
// If it's already been determined, use that value. Otherwise
// assume non-async, emit an error later if it turned out to be async.
if (callee->inferred_async_node == nullptr ||
callee->inferred_async_node == inferred_async_checking)
{
callee->assumed_non_async = call_node;
callee_is_async = false;
} else {
callee_is_async = callee->inferred_async_node != inferred_async_none;
}
}
}
if (callee_is_async) {
@ -4333,6 +4342,8 @@ static void analyze_fn_async(CodeGen *g, ZigFn *fn, bool resolve_frame) {
}
for (size_t i = 0; i < fn->await_list.length; i += 1) {
IrInstructionAwaitGen *await = fn->await_list.at(i);
// TODO If this is a noasync await, it doesn't count
// https://github.com/ziglang/zig/issues/3157
switch (analyze_callee_async(g, fn, await->target_fn, await->base.source_node, must_not_be_async,
CallModifierNone))
{
@ -5771,15 +5782,39 @@ static Error resolve_async_frame(CodeGen *g, ZigType *frame_type) {
if (!fn_is_async(callee))
continue;
IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
alloca_gen->base.id = IrInstructionIdAllocaGen;
alloca_gen->base.source_node = call->base.source_node;
alloca_gen->base.scope = call->base.scope;
alloca_gen->base.value.type = get_pointer_to_type(g, callee_frame_type, false);
alloca_gen->base.ref_count = 1;
alloca_gen->name_hint = "";
fn->alloca_gen_list.append(alloca_gen);
call->frame_result_loc = &alloca_gen->base;
call->frame_result_loc = ir_create_alloca(g, call->base.scope, call->base.source_node, fn,
callee_frame_type, "");
}
// Since this frame is async, an await might represent a suspend point, and
// therefore need to spill.
for (size_t i = 0; i < fn->await_list.length; i += 1) {
IrInstructionAwaitGen *await = fn->await_list.at(i);
// TODO If this is a noasync await, it doesn't need to spill
// https://github.com/ziglang/zig/issues/3157
if (await->result_loc != nullptr) {
// If there's a result location, that is the spill
continue;
}
if (!type_has_bits(await->base.value.type))
continue;
if (await->base.value.special != ConstValSpecialRuntime)
continue;
if (await->base.ref_count == 0)
continue;
if (await->target_fn != nullptr) {
// we might not need to suspend
analyze_fn_async(g, await->target_fn, false);
if (await->target_fn->anal_state == FnAnalStateInvalid) {
frame_type->data.frame.locals_struct = g->builtin_types.entry_invalid;
return ErrorSemanticAnalyzeFail;
}
if (!fn_is_async(await->target_fn)) {
// This await does not represent a suspend point. No spill needed.
continue;
}
}
await->result_loc = ir_create_alloca(g, await->base.scope, await->base.source_node, fn,
await->base.value.type, "");
}
FnTypeId *fn_type_id = &fn_type->data.fn.fn_type_id;
ZigType *ptr_return_type = get_pointer_to_type(g, fn_type_id->return_type, false);
@ -8505,3 +8540,18 @@ void src_assert(bool ok, AstNode *source_node) {
const char *msg = "assertion failed. This is a bug in the Zig compiler.";
stage2_panic(msg, strlen(msg));
}
IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
ZigType *var_type, const char *name_hint)
{
IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
alloca_gen->base.id = IrInstructionIdAllocaGen;
alloca_gen->base.source_node = source_node;
alloca_gen->base.scope = scope;
alloca_gen->base.value.type = get_pointer_to_type(g, var_type, false);
alloca_gen->base.ref_count = 1;
alloca_gen->name_hint = name_hint;
fn->alloca_gen_list.append(alloca_gen);
return &alloca_gen->base;
}

View File

@ -258,4 +258,8 @@ ZigType *resolve_struct_field_type(CodeGen *g, TypeStructField *struct_field);
void add_async_error_notes(CodeGen *g, ErrorMsg *msg, ZigFn *fn);
IrInstruction *ir_create_alloca(CodeGen *g, Scope *scope, AstNode *source_node, ZigFn *fn,
ZigType *var_type, const char *name_hint);
#endif

View File

@ -1661,6 +1661,14 @@ static LLVMValueRef ir_llvm_value(CodeGen *g, IrInstruction *instruction) {
if (!type_has_bits(instruction->value.type))
return nullptr;
if (!instruction->llvm_value) {
if (instruction->id == IrInstructionIdAwaitGen) {
IrInstructionAwaitGen *await = reinterpret_cast<IrInstructionAwaitGen*>(instruction);
if (await->result_loc != nullptr) {
instruction->llvm_value = get_handle_value(g, ir_llvm_value(g, await->result_loc),
await->result_loc->value.type->data.pointer.child_type, await->result_loc->value.type);
return instruction->llvm_value;
}
}
src_assert(instruction->value.special != ConstValSpecialRuntime, instruction->source_node);
assert(instruction->value.type);
render_const_val(g, &instruction->value, "");
@ -5645,7 +5653,6 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
// At this point resuming the function will continue from resume_bb.
// This code is as if it is running inside the suspend block.
// supply the awaiter return pointer
if (type_has_bits(result_type)) {
LLVMValueRef awaiter_ret_ptr_ptr = LLVMBuildStructGEP(g->builder, target_frame_ptr, frame_ret_start + 1, "");
@ -5703,9 +5710,8 @@ static LLVMValueRef ir_render_await(CodeGen *g, IrExecutable *executable, IrInst
LLVMBuildBr(g->builder, end_bb);
LLVMPositionBuilderAtEnd(g->builder, end_bb);
if (type_has_bits(result_type) && result_loc != nullptr) {
return get_handle_value(g, result_loc, result_type, ptr_result_type);
}
// Rely on the spill for the llvm_value to be populated.
// See the implementation of ir_llvm_value.
return nullptr;
}
@ -7153,15 +7159,8 @@ static void do_code_gen(CodeGen *g) {
if (call->frame_result_loc != nullptr)
continue;
ZigType *callee_frame_type = get_fn_frame_type(g, call->fn_entry);
IrInstructionAllocaGen *alloca_gen = allocate<IrInstructionAllocaGen>(1);
alloca_gen->base.id = IrInstructionIdAllocaGen;
alloca_gen->base.source_node = call->base.source_node;
alloca_gen->base.scope = call->base.scope;
alloca_gen->base.value.type = get_pointer_to_type(g, callee_frame_type, false);
alloca_gen->base.ref_count = 1;
alloca_gen->name_hint = "";
fn_table_entry->alloca_gen_list.append(alloca_gen);
call->frame_result_loc = &alloca_gen->base;
call->frame_result_loc = ir_create_alloca(g, call->base.scope, call->base.source_node,
fn_table_entry, callee_frame_type, "");
}
// allocate temporary stack data
for (size_t alloca_i = 0; alloca_i < fn_table_entry->alloca_gen_list.length; alloca_i += 1) {

View File

@ -1108,3 +1108,19 @@ test "noasync function call" {
};
S.doTheTest();
}
test "await used in expression and awaiting fn with no suspend but async calling convention" {
const S = struct {
fn atest() void {
var f1 = async add(1, 2);
var f2 = async add(3, 4);
const sum = (await f1) + (await f2);
expect(sum == 10);
}
async fn add(a: i32, b: i32) i32 {
return a + b;
}
};
_ = async S.atest();
}