diff --git a/src/analyze.cpp b/src/analyze.cpp index eb3fb4cd6..4c58e27fc 100644 --- a/src/analyze.cpp +++ b/src/analyze.cpp @@ -3014,20 +3014,52 @@ static TypeTableEntry *analyze_bool_bin_op_expr(CodeGen *g, ImportTableEntry *im TypeTableEntry *resolved_type = resolve_peer_type_compatibility(g, import, context, node, op_nodes, op_types, 2); - bool type_can_gt_lt_cmp = (resolved_type->id == TypeTableEntryIdNumLitFloat || - resolved_type->id == TypeTableEntryIdNumLitInt || - resolved_type->id == TypeTableEntryIdFloat || - resolved_type->id == TypeTableEntryIdInt); + bool is_equality_cmp = (bin_op_type == BinOpTypeCmpEq || bin_op_type == BinOpTypeCmpNotEq); - if (resolved_type->id == TypeTableEntryIdInvalid) { - return g->builtin_types.entry_invalid; - } else if (bin_op_type != BinOpTypeCmpEq && - bin_op_type != BinOpTypeCmpNotEq && - !type_can_gt_lt_cmp) - { - add_node_error(g, node, - buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name))); - return g->builtin_types.entry_invalid; + switch (resolved_type->id) { + case TypeTableEntryIdInvalid: + return g->builtin_types.entry_invalid; + + case TypeTableEntryIdNumLitFloat: + case TypeTableEntryIdNumLitInt: + case TypeTableEntryIdInt: + case TypeTableEntryIdFloat: + break; + + case TypeTableEntryIdBool: + case TypeTableEntryIdMetaType: + case TypeTableEntryIdVoid: + case TypeTableEntryIdPointer: + case TypeTableEntryIdPureError: + case TypeTableEntryIdFn: + case TypeTableEntryIdTypeDecl: + case TypeTableEntryIdNamespace: + case TypeTableEntryIdGenericFn: + if (!is_equality_cmp) { + add_node_error(g, node, + buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name))); + return g->builtin_types.entry_invalid; + } + break; + + case TypeTableEntryIdEnum: + if (!is_equality_cmp || resolved_type->data.enumeration.gen_field_count != 0) { + add_node_error(g, node, + buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name))); + return g->builtin_types.entry_invalid; + } + break; + + case TypeTableEntryIdUnreachable: + case TypeTableEntryIdArray: + case TypeTableEntryIdStruct: + case TypeTableEntryIdUndefLit: + case TypeTableEntryIdMaybe: + case TypeTableEntryIdErrorUnion: + case TypeTableEntryIdUnion: + add_node_error(g, node, + buf_sprintf("operator not allowed for type '%s'", buf_ptr(&resolved_type->name))); + return g->builtin_types.entry_invalid; } ConstExprValue *op1_val = &get_resolved_expr(*op1)->const_val; diff --git a/test/run_tests.cpp b/test/run_tests.cpp index d8f3c9a41..09b86dc7a 100644 --- a/test/run_tests.cpp +++ b/test/run_tests.cpp @@ -1218,6 +1218,21 @@ fn test_a_thing() { bad_fn_call(); } )SOURCE", 1, ".tmp_source.zig:6:5: error: use of undeclared identifier 'bad_fn_call'"); + + add_compile_fail_case("illegal comparison of types", R"SOURCE( +fn bad_eql_1(a: []u8, b: []u8) -> bool { + a == b +} +enum EnumWithData { + One, + Two: i32, +} +fn bad_eql_2(a: EnumWithData, b: EnumWithData) -> bool { + a == b +} + )SOURCE", 2, + ".tmp_source.zig:3:7: error: operator not allowed for type '[]u8'", + ".tmp_source.zig:10:7: error: operator not allowed for type 'EnumWithData'"); } //////////////////////////////////////////////////////////////////////////////