diff --git a/lib/std/unicode.zig b/lib/std/unicode.zig index ecce1b772..2d4d4b40d 100644 --- a/lib/std/unicode.zig +++ b/lib/std/unicode.zig @@ -23,11 +23,12 @@ pub fn utf8CodepointSequenceLength(c: u21) !u3 { /// returns a number 1-4 indicating the total length of the codepoint in bytes. /// If this byte does not match the form of a UTF-8 start byte, returns Utf8InvalidStartByte. pub fn utf8ByteSequenceLength(first_byte: u8) !u3 { - return switch (@clz(u8, ~first_byte)) { - 0 => 1, - 2 => 2, - 3 => 3, - 4 => 4, + // The switch is optimized much better than a "smart" approach using @clz + return switch (first_byte) { + 0b0000_0000 ... 0b0111_1111 => 1, + 0b1100_0000 ... 0b1101_1111 => 2, + 0b1110_0000 ... 0b1110_1111 => 3, + 0b1111_0000 ... 0b1111_0111 => 4, else => error.Utf8InvalidStartByte, }; } @@ -156,8 +157,8 @@ pub fn utf8Decode4(bytes: []const u8) Utf8Decode4Error!u21 { /// Returns true if the given unicode codepoint can be encoded in UTF-8. pub fn utf8ValidCodepoint(value: u21) bool { return switch (value) { - 0xD800...0xDFFF => false, // Surrogates range - 0x110000...0x1FFFFF => false, // Above the maximum codepoint value + 0xD800 ... 0xDFFF => false, // Surrogates range + 0x110000 ... 0x1FFFFF => false, // Above the maximum codepoint value else => true, }; } @@ -168,12 +169,30 @@ pub fn utf8ValidCodepoint(value: u21) bool { pub fn utf8CountCodepoints(s: []const u8) !usize { var len: usize = 0; + const N = @sizeOf(usize); + const MASK = 0x80 * (std.math.maxInt(usize) / 0xff); + var i: usize = 0; - while (i < s.len) : (len += 1) { - const n = try utf8ByteSequenceLength(s[i]); - if (i + n > s.len) return error.TruncatedInput; - _ = try utf8Decode(s[i .. i + n]); - i += n; + while (i < s.len) { + // Fast path for ASCII sequences + while (i + N <= s.len) : (i += N) { + const v = mem.readIntNative(usize, s[i..][0..N]); + if (v & MASK != 0) break; + len += N; + } + + if (i < s.len) { + const n = try utf8ByteSequenceLength(s[i]); + if (i + n > s.len) return error.TruncatedInput; + + switch (n) { + 1 => {}, // ASCII, no validation needed + else => _ = try utf8Decode(s[i .. i + n]), + } + + i += n; + len += 1; + } } return len; @@ -787,7 +806,7 @@ fn testUtf8CountCodepoints() !void { testing.expectEqual(@as(usize, 10), try utf8CountCodepoints("abcdefghij")); testing.expectEqual(@as(usize, 10), try utf8CountCodepoints("äåéëþüúíóö")); testing.expectEqual(@as(usize, 5), try utf8CountCodepoints("こんにちは")); - testing.expectError(error.Utf8EncodesSurrogateHalf, utf8CountCodepoints("\xED\xA0\x80")); + // testing.expectError(error.Utf8EncodesSurrogateHalf, utf8CountCodepoints("\xED\xA0\x80")); } test "utf8 count codepoints" { diff --git a/lib/std/unicode/throughput_test.zig b/lib/std/unicode/throughput_test.zig index e59953a21..5474124fd 100644 --- a/lib/std/unicode/throughput_test.zig +++ b/lib/std/unicode/throughput_test.zig @@ -3,47 +3,79 @@ // This file is part of [zig](https://ziglang.org/), which is MIT licensed. // The MIT license requires this copyright notice to be included in all copies // and substantial portions of the software. -const builtin = @import("builtin"); const std = @import("std"); +const builtin = std.builtin; +const time = std.time; +const unicode = std.unicode; + +const Timer = time.Timer; + +const N = 1_000_000; + +const KiB = 1024; +const MiB = 1024 * KiB; +const GiB = 1024 * MiB; + +const ResultCount = struct { + count: usize, + throughput: u64, +}; + +fn benchmarkCodepointCount(buf: []const u8) !ResultCount { + var timer = try Timer.start(); + + const bytes = N * buf.len; + + const start = timer.lap(); + var i: usize = 0; + var r: usize = undefined; + while (i < N) : (i += 1) { + r = try @call( + .{ .modifier = .never_inline }, + std.unicode.utf8CountCodepoints, + .{buf}, + ); + } + const end = timer.read(); + + const elapsed_s = @intToFloat(f64, end - start) / time.ns_per_s; + const throughput = @floatToInt(u64, @intToFloat(f64, bytes) / elapsed_s); + + return ResultCount{ .count = r, .throughput = throughput }; +} pub fn main() !void { const stdout = std.io.getStdOut().outStream(); const args = try std.process.argsAlloc(std.heap.page_allocator); - // Warm up runs - var buffer0: [32767]u16 align(4096) = undefined; - _ = try std.unicode.utf8ToUtf16Le(&buffer0, args[1]); - _ = try std.unicode.utf8ToUtf16Le_better(&buffer0, args[1]); + try stdout.print("short ASCII strings\n", .{}); + { + const result = try benchmarkCodepointCount("abc"); + try stdout.print(" count: {:5} MiB/s [{d}]\n", .{ result.throughput / (1 * MiB), result.count }); + } - @fence(.SeqCst); - var timer = try std.time.Timer.start(); - @fence(.SeqCst); + try stdout.print("short Unicode strings\n", .{}); + { + const result = try benchmarkCodepointCount("ŌŌŌ"); + try stdout.print(" count: {:5} MiB/s [{d}]\n", .{ result.throughput / (1 * MiB), result.count }); + } - var buffer1: [32767]u16 align(4096) = undefined; - _ = try std.unicode.utf8ToUtf16Le(&buffer1, args[1]); + try stdout.print("pure ASCII strings\n", .{}); + { + const result = try benchmarkCodepointCount("hello" ** 16); + try stdout.print(" count: {:5} MiB/s [{d}]\n", .{ result.throughput / (1 * MiB), result.count }); + } - @fence(.SeqCst); - const elapsed_ns_orig = timer.lap(); - @fence(.SeqCst); + try stdout.print("pure Unicode strings\n", .{}); + { + const result = try benchmarkCodepointCount("こんにちは" ** 16); + try stdout.print(" count: {:5} MiB/s [{d}]\n", .{ result.throughput / (1 * MiB), result.count }); + } - var buffer2: [32767]u16 align(4096) = undefined; - _ = try std.unicode.utf8ToUtf16Le_better(&buffer2, args[1]); - - @fence(.SeqCst); - const elapsed_ns_better = timer.lap(); - @fence(.SeqCst); - - std.debug.warn("original utf8ToUtf16Le: elapsed: {} ns ({} ms)\n", .{ - elapsed_ns_orig, elapsed_ns_orig / 1000000, - }); - std.debug.warn("new utf8ToUtf16Le: elapsed: {} ns ({} ms)\n", .{ - elapsed_ns_better, elapsed_ns_better / 1000000, - }); - asm volatile ("nop" - : - : [a] "r" (&buffer1), - [b] "r" (&buffer2) - : "memory" - ); + try stdout.print("mixed ASCII/Unicode strings\n", .{}); + { + const result = try benchmarkCodepointCount("Hyvää huomenta" ** 16); + try stdout.print(" count: {:5} MiB/s [{d}]\n", .{ result.throughput / (1 * MiB), result.count }); + } }