From 013ada1b59e50bbbab19acab0a79dae72133999a Mon Sep 17 00:00:00 2001 From: LemonBoy Date: Wed, 18 Mar 2020 09:35:44 +0100 Subject: [PATCH] std: More type checks for Thread startFn return type Closes #4756 --- lib/std/thread.zig | 51 ++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 45 insertions(+), 6 deletions(-) diff --git a/lib/std/thread.zig b/lib/std/thread.zig index b2f8a44a4..596a8f3cd 100644 --- a/lib/std/thread.zig +++ b/lib/std/thread.zig @@ -6,6 +6,8 @@ const windows = std.os.windows; const c = std.c; const assert = std.debug.assert; +const bad_startfn_ret = "expected return type of startFn to be 'u8', 'noreturn', 'void', or '!void'"; + pub const Thread = struct { data: Data, @@ -158,15 +160,34 @@ pub const Thread = struct { }; fn threadMain(raw_arg: windows.LPVOID) callconv(.C) windows.DWORD { const arg = if (@sizeOf(Context) == 0) {} else @ptrCast(*Context, @alignCast(@alignOf(Context), raw_arg)).*; + switch (@typeInfo(@TypeOf(startFn).ReturnType)) { - .Int => { - return startFn(arg); + .NoReturn => { + startFn(arg); }, .Void => { startFn(arg); return 0; }, - else => @compileError("expected return type of startFn to be 'u8', 'noreturn', 'void', or '!void'"), + .Int => |info| { + if (info.bits != 8) { + @compileError(bad_startfn_ret); + } + return startFn(arg); + }, + .ErrorUnion => |info| { + if (info.payload != void) { + @compileError(bad_startfn_ret); + } + startFn(arg) catch |err| { + std.debug.warn("error: {}\n", .{@errorName(err)}); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + }; + return 0; + }, + else => @compileError(bad_startfn_ret), } } }; @@ -202,14 +223,32 @@ pub const Thread = struct { const arg = if (@sizeOf(Context) == 0) {} else @intToPtr(*const Context, ctx_addr).*; switch (@typeInfo(@TypeOf(startFn).ReturnType)) { - .Int => { - return startFn(arg); + .NoReturn => { + startFn(arg); }, .Void => { startFn(arg); return 0; }, - else => @compileError("expected return type of startFn to be 'u8', 'noreturn', 'void', or '!void'"), + .Int => |info| { + if (info.bits != 8) { + @compileError(bad_startfn_ret); + } + return startFn(arg); + }, + .ErrorUnion => |info| { + if (info.payload != void) { + @compileError(bad_startfn_ret); + } + startFn(arg) catch |err| { + std.debug.warn("error: {}\n", .{@errorName(err)}); + if (@errorReturnTrace()) |trace| { + std.debug.dumpStackTrace(trace.*); + } + }; + return 0; + }, + else => @compileError(bad_startfn_ret), } } fn posixThreadMain(ctx: ?*c_void) callconv(.C) ?*c_void {