Add mem.timingSafeEql() for constant-time array comparison

This is a trivial implementation that just does a or[xor] loop.

However, this pattern is used by virtually all crypto libraries and
in practice, even without assembly barriers, LLVM never turns it into
code with conditional jumps, even if one of the parameters is constant.

This has been verified to still be the case with LLVM 11.0.0.
master
Frank Denis 2020-08-23 01:36:37 +02:00
parent 03ae77b8b0
commit bd07154242
7 changed files with 97 additions and 25 deletions

View File

@ -130,6 +130,8 @@ pub const nacl = struct {
pub const SealedBox = salsa20.SealedBox; pub const SealedBox = salsa20.SealedBox;
}; };
pub const utils = @import("crypto/utils.zig");
const std = @import("std.zig"); const std = @import("std.zig");
pub const randomBytes = std.os.getrandom; pub const randomBytes = std.os.getrandom;

View File

@ -11,6 +11,7 @@ const math = std.math;
const mem = std.mem; const mem = std.mem;
const debug = std.debug; const debug = std.debug;
const testing = std.testing; const testing = std.testing;
const utils = std.crypto.utils;
const salt_length: usize = 16; const salt_length: usize = 16;
const salt_str_length: usize = 22; const salt_str_length: usize = 22;
@ -226,7 +227,7 @@ fn strHashInternal(password: []const u8, rounds_log: u6, salt: [salt_length]u8)
state.expand0(passwordZ); state.expand0(passwordZ);
state.expand0(salt[0..]); state.expand0(salt[0..]);
} }
mem.secureZero(u8, &password_buf); utils.secureZero(u8, &password_buf);
var cdata = [6]u32{ 0x4f727068, 0x65616e42, 0x65686f6c, 0x64657253, 0x63727944, 0x6f756274 }; // "OrpheanBeholderScryDoubt" var cdata = [6]u32{ 0x4f727068, 0x65616e42, 0x65686f6c, 0x64657253, 0x63727944, 0x6f756274 }; // "OrpheanBeholderScryDoubt"
k = 0; k = 0;

View File

@ -10,6 +10,7 @@ const std = @import("../std.zig");
const assert = std.debug.assert; const assert = std.debug.assert;
const math = std.math; const math = std.math;
const mem = std.mem; const mem = std.mem;
const utils = std.crypto.utils;
/// GHASH is a universal hash function that features multiplication /// GHASH is a universal hash function that features multiplication
/// by a fixed parameter within a Galois field. /// by a fixed parameter within a Galois field.
@ -305,7 +306,7 @@ pub const Ghash = struct {
mem.writeIntBig(u64, out[0..8], st.y1); mem.writeIntBig(u64, out[0..8], st.y1);
mem.writeIntBig(u64, out[8..16], st.y0); mem.writeIntBig(u64, out[8..16], st.y0);
mem.secureZero(u8, @ptrCast([*]u8, st)[0..@sizeOf(Ghash)]); utils.secureZero(u8, @ptrCast([*]u8, st)[0..@sizeOf(Ghash)]);
} }
pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [key_length]u8) void { pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [key_length]u8) void {

View File

@ -4,6 +4,7 @@
// The MIT license requires this copyright notice to be included in all copies // The MIT license requires this copyright notice to be included in all copies
// and substantial portions of the software. // and substantial portions of the software.
const std = @import("../std.zig"); const std = @import("../std.zig");
const utils = std.crypto.utils;
const mem = std.mem; const mem = std.mem;
pub const Poly1305 = struct { pub const Poly1305 = struct {
@ -195,7 +196,7 @@ pub const Poly1305 = struct {
mem.writeIntLittle(u64, out[0..8], st.h[0]); mem.writeIntLittle(u64, out[0..8], st.h[0]);
mem.writeIntLittle(u64, out[8..16], st.h[1]); mem.writeIntLittle(u64, out[8..16], st.h[1]);
std.mem.secureZero(u8, @ptrCast([*]u8, st)[0..@sizeOf(Poly1305)]); utils.secureZero(u8, @ptrCast([*]u8, st)[0..@sizeOf(Poly1305)]);
} }
pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [key_length]u8) void { pub fn create(out: *[mac_length]u8, msg: []const u8, key: *const [key_length]u8) void {

View File

@ -9,6 +9,7 @@ const crypto = std.crypto;
const debug = std.debug; const debug = std.debug;
const math = std.math; const math = std.math;
const mem = std.mem; const mem = std.mem;
const utils = std.crypto.utils;
const Vector = std.meta.Vector; const Vector = std.meta.Vector;
const Poly1305 = crypto.onetimeauth.Poly1305; const Poly1305 = crypto.onetimeauth.Poly1305;
@ -414,7 +415,7 @@ pub const XSalsa20Poly1305 = struct {
acc |= computedTag[i] ^ tag[i]; acc |= computedTag[i] ^ tag[i];
} }
if (acc != 0) { if (acc != 0) {
mem.secureZero(u8, &computedTag); utils.secureZero(u8, &computedTag);
return error.AuthenticationFailed; return error.AuthenticationFailed;
} }
mem.copy(u8, m[0..mlen0], block0[32..][0..mlen0]); mem.copy(u8, m[0..mlen0], block0[32..][0..mlen0]);
@ -532,7 +533,7 @@ pub const SealedBox = struct {
const nonce = createNonce(ekp.public_key, public_key); const nonce = createNonce(ekp.public_key, public_key);
mem.copy(u8, c[0..public_length], ekp.public_key[0..]); mem.copy(u8, c[0..public_length], ekp.public_key[0..]);
try Box.seal(c[Box.public_length..], m, nonce, public_key, ekp.secret_key); try Box.seal(c[Box.public_length..], m, nonce, public_key, ekp.secret_key);
mem.secureZero(u8, ekp.secret_key[0..]); utils.secureZero(u8, ekp.secret_key[0..]);
} }
/// Decrypt a message using a key pair. /// Decrypt a message using a key pair.

86
lib/std/crypto/utils.zig Normal file
View File

@ -0,0 +1,86 @@
const std = @import("../std.zig");
const mem = std.mem;
const testing = std.testing;
/// Compares two arrays in constant time (for a given length) and returns whether they are equal.
/// This function was designed to compare short cryptographic secrets (MACs, signatures).
/// For all other applications, use mem.eql() instead.
pub fn timingSafeEql(comptime T: type, a: T, b: T) bool {
switch (@typeInfo(T)) {
.Array => |info| {
const C = info.child;
if (@typeInfo(C) != .Int) {
@compileError("Elements to be compared must be integers");
}
var acc = @as(C, 0);
for (a) |x, i| {
acc |= x ^ b[i];
}
comptime const s = @typeInfo(C).Int.bits;
comptime const Cu = std.meta.Int(.unsigned, s);
comptime const Cext = std.meta.Int(.unsigned, s + 1);
return @bitCast(bool, @truncate(u1, (@as(Cext, @bitCast(Cu, acc)) -% 1) >> s));
},
.Vector => |info| {
const C = info.child;
if (@typeInfo(C) != .Int) {
@compileError("Elements to be compared must be integers");
}
const z = a ^ b;
var acc = @as(C, 0);
var i: usize = 0;
while (i < info.len) : (i += 1) {
acc |= z[i];
}
comptime const s = @typeInfo(C).Int.bits;
comptime const Cu = std.meta.Int(.unsigned, s);
comptime const Cext = std.meta.Int(.unsigned, s + 1);
return @bitCast(bool, @truncate(u1, (@as(Cext, @bitCast(Cu, acc)) -% 1) >> s));
},
else => {
@compileError("Only arrays and vectors can be compared");
},
}
}
/// Sets a slice to zeroes.
/// Prevents the store from being optimized out.
pub fn secureZero(comptime T: type, s: []T) void {
// NOTE: We do not use a volatile slice cast here since LLVM cannot
// see that it can be replaced by a memset.
const ptr = @ptrCast([*]volatile u8, s.ptr);
const length = s.len * @sizeOf(T);
@memset(ptr, 0, length);
}
test "crypto.utils.timingSafeEql" {
var a: [100]u8 = undefined;
var b: [100]u8 = undefined;
try std.crypto.randomBytes(a[0..]);
try std.crypto.randomBytes(b[0..]);
testing.expect(!timingSafeEql([100]u8, a, b));
mem.copy(u8, a[0..], b[0..]);
testing.expect(timingSafeEql([100]u8, a, b));
}
test "crypto.utils.timingSafeEql (vectors)" {
var a: [100]u8 = undefined;
var b: [100]u8 = undefined;
try std.crypto.randomBytes(a[0..]);
try std.crypto.randomBytes(b[0..]);
const v1: std.meta.Vector(100, u8) = a;
const v2: std.meta.Vector(100, u8) = b;
testing.expect(!timingSafeEql(std.meta.Vector(100, u8), v1, v2));
const v3: std.meta.Vector(100, u8) = a;
testing.expect(timingSafeEql(std.meta.Vector(100, u8), v1, v3));
}
test "crypto.utils.secureZero" {
var a = [_]u8{0xfe} ** 8;
var b = [_]u8{0xfe} ** 8;
mem.set(u8, a[0..], 0);
secureZero(u8, b[0..]);
testing.expectEqualSlices(u8, a[0..], b[0..]);
}

View File

@ -342,26 +342,6 @@ test "mem.zeroes" {
testing.expectEqual(@as(u8, 0), c.a); testing.expectEqual(@as(u8, 0), c.a);
} }
/// Sets a slice to zeroes.
/// Prevents the store from being optimized out.
pub fn secureZero(comptime T: type, s: []T) void {
// NOTE: We do not use a volatile slice cast here since LLVM cannot
// see that it can be replaced by a memset.
const ptr = @ptrCast([*]volatile u8, s.ptr);
const length = s.len * @sizeOf(T);
@memset(ptr, 0, length);
}
test "mem.secureZero" {
var a = [_]u8{0xfe} ** 8;
var b = [_]u8{0xfe} ** 8;
set(u8, a[0..], 0);
secureZero(u8, b[0..]);
testing.expectEqualSlices(u8, a[0..], b[0..]);
}
/// Initializes all fields of the struct with their default value, or zero values if no default value is present. /// Initializes all fields of the struct with their default value, or zero values if no default value is present.
/// If the field is present in the provided initial values, it will have that value instead. /// If the field is present in the provided initial values, it will have that value instead.
/// Structs are initialized recursively. /// Structs are initialized recursively.