From 1dc6751721a2fe9990ea8fab4eadc95a29f53304 Mon Sep 17 00:00:00 2001 From: Andrew Kelley Date: Thu, 4 Apr 2019 22:07:15 -0400 Subject: [PATCH] fix NaN comparing equal to itself This was broken both in comptime code and in runtime code. closes #1174 --- CMakeLists.txt | 1 + src/bigfloat.cpp | 4 ++++ src/bigfloat.hpp | 1 + src/codegen.cpp | 2 +- src/ir.cpp | 25 +++++++++++++++++++++++++ test/stage1/behavior/math.zig | 22 ++++++++++++++++++++++ 6 files changed, 54 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c58029100..e04dc79b1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -302,6 +302,7 @@ set(EMBEDDED_SOFTFLOAT_SOURCES "${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f16_add.c" "${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f16_div.c" "${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f16_eq.c" + "${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f16_isSignalingNaN.c" "${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f16_lt.c" "${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f16_mul.c" "${CMAKE_SOURCE_DIR}/deps/SoftFloat-3e/source/f16_rem.c" diff --git a/src/bigfloat.cpp b/src/bigfloat.cpp index d746f1b68..a2a3a3b69 100644 --- a/src/bigfloat.cpp +++ b/src/bigfloat.cpp @@ -190,3 +190,7 @@ bool bigfloat_has_fraction(const BigFloat *bigfloat) { void bigfloat_sqrt(BigFloat *dest, const BigFloat *op) { f128M_sqrt(&op->value, &dest->value); } + +bool bigfloat_is_nan(const BigFloat *op) { + return f128M_isSignalingNaN(&op->value); +} diff --git a/src/bigfloat.hpp b/src/bigfloat.hpp index 176e860ac..3ed6624fd 100644 --- a/src/bigfloat.hpp +++ b/src/bigfloat.hpp @@ -48,6 +48,7 @@ void bigfloat_sqrt(BigFloat *dest, const BigFloat *op); void bigfloat_append_buf(Buf *buf, const BigFloat *op); Cmp bigfloat_cmp(const BigFloat *op1, const BigFloat *op2); +bool bigfloat_is_nan(const BigFloat *op); // convenience functions Cmp bigfloat_cmp_zero(const BigFloat *bigfloat); diff --git a/src/codegen.cpp b/src/codegen.cpp index 3b3a97dbf..568344fc0 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -1852,7 +1852,7 @@ static LLVMRealPredicate cmp_op_to_real_predicate(IrBinOp cmp_op) { case IrBinOpCmpEq: return LLVMRealOEQ; case IrBinOpCmpNotEq: - return LLVMRealONE; + return LLVMRealUNE; case IrBinOpCmpLessThan: return LLVMRealOLT; case IrBinOpCmpGreaterThan: diff --git a/src/ir.cpp b/src/ir.cpp index ea7a27f31..3f02d18e8 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -17,6 +17,7 @@ #include "util.hpp" #include +#include struct IrExecContext { ZigList mem_slot_list; @@ -8242,6 +8243,27 @@ static void float_init_float(ConstExprValue *dest_val, ConstExprValue *src_val) } } +static bool float_is_nan(ConstExprValue *op) { + if (op->type->id == ZigTypeIdComptimeFloat) { + return bigfloat_is_nan(&op->data.x_bigfloat); + } else if (op->type->id == ZigTypeIdFloat) { + switch (op->type->data.floating.bit_count) { + case 16: + return f16_isSignalingNaN(op->data.x_f16); + case 32: + return isnan(op->data.x_f32); + case 64: + return isnan(op->data.x_f64); + case 128: + return f128M_isSignalingNaN(&op->data.x_f128); + default: + zig_unreachable(); + } + } else { + zig_unreachable(); + } +} + static Cmp float_cmp(ConstExprValue *op1, ConstExprValue *op2) { assert(op1->type == op2->type); if (op1->type->id == ZigTypeIdComptimeFloat) { @@ -12378,6 +12400,9 @@ static IrInstruction *ir_analyze_bin_op_cmp(IrAnalyze *ira, IrInstructionBinOp * return ira->codegen->invalid_instruction; if (resolved_type->id == ZigTypeIdComptimeFloat || resolved_type->id == ZigTypeIdFloat) { + if (float_is_nan(op1_val) || float_is_nan(op2_val)) { + return ir_const_bool(ira, &bin_op_instruction->base, op_id == IrBinOpCmpNotEq); + } Cmp cmp_result = float_cmp(op1_val, op2_val); bool answer = resolve_cmp_op_id(op_id, cmp_result); return ir_const_bool(ira, &bin_op_instruction->base, answer); diff --git a/test/stage1/behavior/math.zig b/test/stage1/behavior/math.zig index 6dd495cfd..23dc6d1fe 100644 --- a/test/stage1/behavior/math.zig +++ b/test/stage1/behavior/math.zig @@ -610,3 +610,25 @@ test "vector integer addition" { S.doTheTest(); comptime S.doTheTest(); } + +test "NaN comparison" { + testNanEqNan(f16); + testNanEqNan(f32); + testNanEqNan(f64); + testNanEqNan(f128); + comptime testNanEqNan(f16); + comptime testNanEqNan(f32); + comptime testNanEqNan(f64); + comptime testNanEqNan(f128); +} + +fn testNanEqNan(comptime F: type) void { + var nan1 = std.math.nan(F); + var nan2 = std.math.nan(F); + expect(nan1 != nan2); + expect(!(nan1 == nan2)); + expect(!(nan1 > nan2)); + expect(!(nan1 >= nan2)); + expect(!(nan1 < nan2)); + expect(!(nan1 <= nan2)); +}