diff --git a/std/unicode.zig b/std/unicode.zig index df62e9162..81bbc2aab 100644 --- a/std/unicode.zig +++ b/std/unicode.zig @@ -1,4 +1,5 @@ const std = @import("./index.zig"); +const debug = std.debug; /// Given the first byte of a UTF-8 codepoint, /// returns a number 1-4 indicating the total length of the codepoint in bytes. @@ -25,8 +26,8 @@ pub fn utf8Decode(bytes: []const u8) !u32 { }; } pub fn utf8Decode2(bytes: []const u8) !u32 { - std.debug.assert(bytes.len == 2); - std.debug.assert(bytes[0] & 0b11100000 == 0b11000000); + debug.assert(bytes.len == 2); + debug.assert(bytes[0] & 0b11100000 == 0b11000000); var value: u32 = bytes[0] & 0b00011111; if (bytes[1] & 0b11000000 != 0b10000000) return error.Utf8ExpectedContinuation; @@ -38,8 +39,8 @@ pub fn utf8Decode2(bytes: []const u8) !u32 { return value; } pub fn utf8Decode3(bytes: []const u8) !u32 { - std.debug.assert(bytes.len == 3); - std.debug.assert(bytes[0] & 0b11110000 == 0b11100000); + debug.assert(bytes.len == 3); + debug.assert(bytes[0] & 0b11110000 == 0b11100000); var value: u32 = bytes[0] & 0b00001111; if (bytes[1] & 0b11000000 != 0b10000000) return error.Utf8ExpectedContinuation; @@ -56,8 +57,8 @@ pub fn utf8Decode3(bytes: []const u8) !u32 { return value; } pub fn utf8Decode4(bytes: []const u8) !u32 { - std.debug.assert(bytes.len == 4); - std.debug.assert(bytes[0] & 0b11111000 == 0b11110000); + debug.assert(bytes.len == 4); + debug.assert(bytes[0] & 0b11111000 == 0b11110000); var value: u32 = bytes[0] & 0b00000111; if (bytes[1] & 0b11000000 != 0b10000000) return error.Utf8ExpectedContinuation; @@ -78,6 +79,136 @@ pub fn utf8Decode4(bytes: []const u8) !u32 { return value; } +pub fn utf8ValidateSlice(s: []const u8) bool { + var i: usize = 0; + while (i < s.len) { + if (utf8ByteSequenceLength(s[i])) |cp_len| { + if (i + cp_len > s.len) { + return false; + } + + if (utf8Decode(s[i..i+cp_len])) |_| {} else |_| { return false; } + i += cp_len; + } else |err| { + return false; + } + } + return true; +} + +const Utf8View = struct { + bytes: []const u8, + + pub fn init(s: []const u8) !Utf8View { + if (!utf8ValidateSlice(s)) { + return error.InvalidUtf8; + } + + return initUnchecked(s); + } + + pub fn initUnchecked(s: []const u8) Utf8View { + return Utf8View { + .bytes = s, + }; + } + + pub fn initComptime(comptime s: []const u8) Utf8View { + if (comptime init(s)) |r| { + return r; + } else |err| switch (err) { + error.InvalidUtf8 => { + @compileError("invalid utf8"); + unreachable; + } + } + } + + pub fn Iterator(s: &const Utf8View) Utf8Iterator { + return Utf8Iterator { + .bytes = s.bytes, + .i = 0, + }; + } +}; + +const Utf8Iterator = struct { + bytes: []const u8, + i: usize, + + pub fn nextCodepointSlice(it: &Utf8Iterator) ?[]const u8 { + if (it.i >= it.bytes.len) { + return null; + } + + const cp_len = utf8ByteSequenceLength(it.bytes[it.i]) catch unreachable; + + it.i += cp_len; + return it.bytes[it.i-cp_len..it.i]; + } + + pub fn nextCodepoint(it: &Utf8Iterator) ?u32 { + const slice = it.nextCodepointSlice() ?? return null; + + const r = switch (slice.len) { + 1 => u32(slice[0]), + 2 => utf8Decode2(slice), + 3 => utf8Decode3(slice), + 4 => utf8Decode4(slice), + else => unreachable, + }; + + return r catch unreachable; + } +}; + +test "utf8 iterator on ascii" { + const s = Utf8View.initComptime("abc"); + + var it1 = s.Iterator(); + debug.assert(std.mem.eql(u8, "a", ??it1.nextCodepointSlice())); + debug.assert(std.mem.eql(u8, "b", ??it1.nextCodepointSlice())); + debug.assert(std.mem.eql(u8, "c", ??it1.nextCodepointSlice())); + debug.assert(it1.nextCodepointSlice() == null); + + var it2 = s.Iterator(); + debug.assert(??it2.nextCodepoint() == 'a'); + debug.assert(??it2.nextCodepoint() == 'b'); + debug.assert(??it2.nextCodepoint() == 'c'); + debug.assert(it2.nextCodepoint() == null); +} + +test "utf8 view bad" { + // Compile-time error. + // const s3 = Utf8View.initComptime("\xfe\xf2"); + + const s = Utf8View.init("hel\xadlo"); + if (s) |_| { unreachable; } else |err| { debug.assert(err == error.InvalidUtf8); } +} + +test "utf8 view ok" { + const s = Utf8View.initComptime("東京市"); + + var it1 = s.Iterator(); + debug.assert(std.mem.eql(u8, "東", ??it1.nextCodepointSlice())); + debug.assert(std.mem.eql(u8, "京", ??it1.nextCodepointSlice())); + debug.assert(std.mem.eql(u8, "市", ??it1.nextCodepointSlice())); + debug.assert(it1.nextCodepointSlice() == null); + + var it2 = s.Iterator(); + debug.assert(??it2.nextCodepoint() == 0x6771); + debug.assert(??it2.nextCodepoint() == 0x4eac); + debug.assert(??it2.nextCodepoint() == 0x5e02); + debug.assert(it2.nextCodepoint() == null); +} + +test "bad utf8 slice" { + debug.assert(utf8ValidateSlice("abc")); + debug.assert(!utf8ValidateSlice("abc\xc0")); + debug.assert(!utf8ValidateSlice("abc\xc0abc")); + debug.assert(utf8ValidateSlice("abc\xdf\xbf")); +} + test "valid utf8" { testValid("\x00", 0x0); testValid("\x20", 0x20); @@ -145,17 +276,17 @@ fn testError(bytes: []const u8, expected_err: error) void { if (testDecode(bytes)) |_| { unreachable; } else |err| { - std.debug.assert(err == expected_err); + debug.assert(err == expected_err); } } fn testValid(bytes: []const u8, expected_codepoint: u32) void { - std.debug.assert((testDecode(bytes) catch unreachable) == expected_codepoint); + debug.assert((testDecode(bytes) catch unreachable) == expected_codepoint); } fn testDecode(bytes: []const u8) !u32 { const length = try utf8ByteSequenceLength(bytes[0]); if (bytes.len < length) return error.UnexpectedEof; - std.debug.assert(bytes.len == length); + debug.assert(bytes.len == length); return utf8Decode(bytes); }