const std = @import("../std.zig");
const builtin = @import("builtin");
const assert = std.debug.assert;
const testing = std.testing;
const Loop = std.event.Loop;

/// Many producer, many consumer, thread-safe, runtime configurable buffer size.
/// When buffer is empty, consumers suspend and are resumed by producers.
/// When buffer is full, producers suspend and are resumed by consumers.
pub fn Channel(comptime T: type) type {
    return struct {
        /// Frames suspended in `get` / `getOrNull`, waiting for an item.
        getters: std.atomic.Queue(GetNode),
        /// Secondary index of the `getOrNull` entries in `getters`, so that
        /// `dispatch` can resume them with `null` when nothing is available.
        or_null_queue: std.atomic.Queue(*std.atomic.Queue(GetNode).Node),
        /// Frames suspended in `put`, each carrying the item to deliver.
        putters: std.atomic.Queue(PutNode),
        /// Number of getters announced to `dispatch`. Incremented only after
        /// the node has been pushed onto `getters`.
        get_count: usize,
        /// Number of putters announced to `dispatch`. Incremented only after
        /// the node has been pushed onto `putters`.
        put_count: usize,
        /// True while some caller is executing the body of `dispatch`.
        dispatch_lock: bool,
        /// Set by every `dispatch` call so that the current lock holder
        /// performs another pass before giving up the lock.
        need_dispatch: bool,

        // simple fixed size ring buffer
        buffer_nodes: []T,
        /// Write cursor; only ever used modulo `buffer_nodes.len`.
        buffer_index: usize,
        /// Number of items currently stored in `buffer_nodes`.
        buffer_len: usize,

        const SelfChannel = @This();

        // Queue entry describing one suspended getter.
        const GetNode = struct {
            tick_node: *Loop.NextTickNode,
            data: Data,

            const Data = union(enum) {
                Normal: Normal,
                OrNull: OrNull,
            };

            // Result slot of a `get` call.
            const Normal = struct {
                ptr: *T,
            };

            // Result slot of a `getOrNull` call plus its `or_null_queue` entry.
            const OrNull = struct {
                ptr: *?T,
                or_null: *std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node,
            };
        };

        // Queue entry describing one suspended putter.
        const PutNode = struct {
            data: T,
            tick_node: *Loop.NextTickNode,
        };

        const global_event_loop = Loop.instance orelse
            @compileError("std.event.Channel currently only works with event-based I/O");

        /// Call `deinit` to free resources when done.
        /// `buffer` must live until `deinit` is called.
        /// `buffer.len` must be zero or a power of two.
        /// For a zero length buffer, use `[0]T{}`.
        /// TODO https://github.com/ziglang/zig/issues/2765
        pub fn init(self: *SelfChannel, buffer: []T) void {
            // The ring buffer implementation only works with power of 2 buffer sizes
            // because of relying on subtracting across zero. With any other length the
            // wrapped index lands on the wrong slot. For example (0 -% 1) % 10 == 5
            assert(buffer.len == 0 or @popCount(usize, buffer.len) == 1);

            self.* = SelfChannel{
                .buffer_len = 0,
                .buffer_nodes = buffer,
                .buffer_index = 0,
                .dispatch_lock = false,
                .need_dispatch = false,
                .getters = std.atomic.Queue(GetNode).init(),
                .putters = std.atomic.Queue(PutNode).init(),
                .or_null_queue = std.atomic.Queue(*std.atomic.Queue(GetNode).Node).init(),
                .get_count = 0,
                .put_count = 0,
            };
        }

        /// Must be called when all calls to put and get have suspended and no more calls occur.
        /// This can be omitted if caller can guarantee that the suspended putters and getters
        /// do not need to be run to completion. Note that this may leave awaiters hanging.
        ///
        /// Resumes every frame still queued in `getters` and then in `putters`,
        /// and finally sets the channel itself to `undefined`.
        pub fn deinit(self: *SelfChannel) void {
            while (self.getters.get()) |get_node| {
                resume get_node.data.tick_node.data;
            }
            while (self.putters.get()) |put_node| {
                resume put_node.data.tick_node.data;
            }
            self.* = undefined;
        }

        /// Puts a data item in the channel. The function returns when the value has been added to the
        /// buffer, or in the case of a zero size buffer, when the item has been retrieved by a getter.
        /// Or when the channel is destroyed.
        pub fn put(self: *SelfChannel, data: T) void {
            // Both nodes live in this frame; they remain valid while the frame is suspended.
            var my_tick_node = Loop.NextTickNode.init(@frame());
            var queue_node = std.atomic.Queue(PutNode).Node.init(PutNode{
                .tick_node = &my_tick_node,
                .data = data,
            });

            suspend {
                // Enqueue before bumping the counter so that `dispatch` never
                // observes a count larger than the number of queued nodes.
                self.putters.put(&queue_node);
                _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);

                self.dispatch();
            }
        }

        /// await this function to get an item from the channel. If the buffer is empty, the frame will
        /// complete when the next item is put in the channel.
        pub fn get(self: *SelfChannel) callconv(.Async) T {
            // TODO https://github.com/ziglang/zig/issues/2765
            var result: T = undefined;
            var my_tick_node = Loop.NextTickNode.init(@frame());
            var queue_node = std.atomic.Queue(GetNode).Node.init(GetNode{
                .tick_node = &my_tick_node,
                .data = GetNode.Data{
                    .Normal = GetNode.Normal{ .ptr = &result },
                },
            });

            suspend {
                // Enqueue before bumping the counter so that `dispatch` never
                // observes a count larger than the number of queued nodes.
                self.getters.put(&queue_node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);

                self.dispatch();
            }
            // `dispatch` wrote the item through `ptr` before scheduling this frame.
            return result;
        }

        //pub async fn select(comptime EnumUnion: type, channels: ...) EnumUnion {
        //    assert(@memberCount(EnumUnion) == channels.len); // enum union and channels mismatch
        //    assert(channels.len != 0); // enum unions cannot have 0 fields
        //    if (channels.len == 1) {
        //        const result = await (async channels[0].get() catch unreachable);
        //        return @unionInit(EnumUnion, @memberName(EnumUnion, 0), result);
        //    }
        //}

        /// Get an item from the channel. If the buffer is empty and there are no
        /// puts waiting, this returns `null`.
        pub fn getOrNull(self: *SelfChannel) ?T {
            // TODO integrate this function with named return values
            // so we can get rid of this extra result copy
            var result: ?T = null;
            var my_tick_node = Loop.NextTickNode.init(@frame());
            // The two queue nodes point at each other, so the payload of
            // `or_null_node` is filled in once `queue_node` exists.
            var or_null_node = std.atomic.Queue(*std.atomic.Queue(GetNode).Node).Node.init(undefined);
            var queue_node = std.atomic.Queue(GetNode).Node.init(GetNode{
                .tick_node = &my_tick_node,
                .data = GetNode.Data{
                    .OrNull = GetNode.OrNull{
                        .ptr = &result,
                        .or_null = &or_null_node,
                    },
                },
            });
            or_null_node.data = &queue_node;

            suspend {
                self.getters.put(&queue_node);
                _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                self.or_null_queue.put(&or_null_node);

                self.dispatch();
            }
            // Still `null` unless `dispatch` stored an item through `ptr`.
            return result;
        }

        /// Matches waiting getters with buffered items and waiting putters, and
        /// moves waiting putters into free buffer slots. Only one caller at a time
        /// runs the body (guarded by `dispatch_lock`); every other caller merely
        /// sets `need_dispatch` and returns, relying on the lock holder to loop.
        fn dispatch(self: *SelfChannel) void {
            // set the "need dispatch" flag
            @atomicStore(bool, &self.need_dispatch, true, .SeqCst);

            lock: while (true) {
                // set the lock flag
                if (@atomicRmw(bool, &self.dispatch_lock, .Xchg, true, .SeqCst)) return;

                // clear the need_dispatch flag since we're about to do it
                @atomicStore(bool, &self.need_dispatch, false, .SeqCst);

                while (true) {
                    // `@atomicRmw` yields the value from before the subtraction, so the
                    // locals hold the number of announced waiters while the shared
                    // counters are temporarily one too low.
                    // later we correct these extra subtractions
                    var get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                    var put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);

                    // transfer self.buffer to self.getters
                    //
                    // Running out of getters ends only this stage. Waiting putters
                    // must still be moved into the free buffer slots below, otherwise
                    // a `put` would stay suspended although the buffer is not full.
                    while (self.buffer_len != 0 and get_count != 0) {
                        const get_node = &self.getters.get().?.data;
                        // Oldest buffered item; the subtraction may wrap, which is why
                        // the buffer length has to be a power of two.
                        const oldest = self.buffer_nodes[(self.buffer_index -% self.buffer_len) % self.buffer_nodes.len];
                        switch (get_node.data) {
                            GetNode.Data.Normal => |info| {
                                info.ptr.* = oldest;
                            },
                            GetNode.Data.OrNull => |info| {
                                // Satisfied, so it must not be resumed with null later.
                                _ = self.or_null_queue.remove(info.or_null);
                                info.ptr.* = oldest;
                            },
                        }
                        global_event_loop.onNextTick(get_node.tick_node);
                        self.buffer_len -= 1;

                        get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                    }

                    // direct transfer self.putters to self.getters
                    while (get_count != 0 and put_count != 0) {
                        const get_node = &self.getters.get().?.data;
                        const put_node = &self.putters.get().?.data;

                        switch (get_node.data) {
                            GetNode.Data.Normal => |info| {
                                info.ptr.* = put_node.data;
                            },
                            GetNode.Data.OrNull => |info| {
                                _ = self.or_null_queue.remove(info.or_null);
                                info.ptr.* = put_node.data;
                            },
                        }
                        global_event_loop.onNextTick(get_node.tick_node);
                        global_event_loop.onNextTick(put_node.tick_node);

                        get_count = @atomicRmw(usize, &self.get_count, .Sub, 1, .SeqCst);
                        put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                    }

                    // transfer self.putters to self.buffer
                    while (self.buffer_len != self.buffer_nodes.len and put_count != 0) {
                        const put_node = &self.putters.get().?.data;

                        self.buffer_nodes[self.buffer_index % self.buffer_nodes.len] = put_node.data;
                        global_event_loop.onNextTick(put_node.tick_node);
                        self.buffer_index +%= 1;
                        self.buffer_len += 1;

                        put_count = @atomicRmw(usize, &self.put_count, .Sub, 1, .SeqCst);
                    }

                    // undo the extra subtractions
                    _ = @atomicRmw(usize, &self.get_count, .Add, 1, .SeqCst);
                    _ = @atomicRmw(usize, &self.put_count, .Add, 1, .SeqCst);

                    // All the "get or null" functions should resume now.
                    var remove_count: usize = 0;
                    while (self.or_null_queue.get()) |or_null_node| {
                        remove_count += @boolToInt(self.getters.remove(or_null_node.data));
                        global_event_loop.onNextTick(or_null_node.data.data.tick_node);
                    }
                    if (remove_count != 0) {
                        _ = @atomicRmw(usize, &self.get_count, .Sub, remove_count, .SeqCst);
                    }

                    // clear need-dispatch flag
                    if (@atomicRmw(bool, &self.need_dispatch, .Xchg, false, .SeqCst)) continue;

                    // Release the lock; it must have been held by us.
                    const was_locked = @atomicRmw(bool, &self.dispatch_lock, .Xchg, false, .SeqCst);
                    assert(was_locked);

                    // we have to check again now that we unlocked
                    if (@atomicLoad(bool, &self.need_dispatch, .SeqCst)) continue :lock;

                    return;
                }
            }
        }
    };
}

test "std.event.Channel" {
    if (!std.io.is_async) return error.SkipZigTest;

    // https://github.com/ziglang/zig/issues/1908
    if (builtin.single_threaded) return error.SkipZigTest;

    // https://github.com/ziglang/zig/issues/3251
    if (builtin.os.tag == .freebsd) return error.SkipZigTest;

    var channel: Channel(i32) = undefined;
    channel.init(&[0]i32{});
    defer channel.deinit();

    var handle = async testChannelGetter(&channel);
    var putter = async testChannelPutter(&channel);

    await handle;
    await putter;
}

test "std.event.Channel wraparound" {
    // TODO provide a way to run tests in evented I/O mode
    if (!std.io.is_async) return error.SkipZigTest;

    const channel_size = 2;

    var buf: [channel_size]i32 = undefined;
    var channel: Channel(i32) = undefined;
    channel.init(&buf);
    defer channel.deinit();

    // add items to channel and pull them out until
    // the buffer wraps around, make sure it doesn't crash.
    channel.put(5);
    testing.expectEqual(@as(i32, 5), channel.get());
    channel.put(6);
    testing.expectEqual(@as(i32, 6), channel.get());
    channel.put(7);
    testing.expectEqual(@as(i32, 7), channel.get());
}

// Consumer half of the "std.event.Channel" test: expects the two values sent
// by `testChannelPutter`, then checks `getOrNull` both without and with a
// pending put.
fn testChannelGetter(channel: *Channel(i32)) callconv(.Async) void {
    const value1 = channel.get();
    testing.expect(value1 == 1234);

    const value2 = channel.get();
    testing.expect(value2 == 4567);

    const value3 = channel.getOrNull();
    testing.expect(value3 == null);

    var last_put = async testPut(channel, 4444);
    const value4 = channel.getOrNull();
    testing.expect(value4.? == 4444);
    await last_put;
}

// Producer half of the "std.event.Channel" test.
fn testChannelPutter(channel: *Channel(i32)) callconv(.Async) void {
    channel.put(1234);
    channel.put(4567);
}

// Wraps a single `put` so the caller can start it with `async`.
fn testPut(channel: *Channel(i32), value: i32) callconv(.Async) void {
    channel.put(value);
}