zig/std/hash_map.zig

562 lines
19 KiB
Zig
Raw Normal View History

const std = @import("index.zig");
const debug = std.debug;
const assert = debug.assert;
const math = std.math;
const mem = std.mem;
const Allocator = mem.Allocator;
const builtin = @import("builtin");
const want_modification_safety = builtin.mode != builtin.Mode.ReleaseFast;
const debug_u32 = if (want_modification_safety) u32 else void;
pub fn AutoHashMap(comptime K: type, comptime V: type) type {
return HashMap(K, V, getAutoHashFn(K), getAutoEqlFn(K));
}
2018-05-30 13:09:11 -07:00
pub fn HashMap(comptime K: type, comptime V: type, comptime hash: fn (key: K) u32, comptime eql: fn (a: K, b: K) bool) type {
return struct {
2016-12-18 16:40:26 -08:00
entries: []Entry,
size: usize,
max_distance_from_start_index: usize,
allocator: *Allocator,
2016-12-18 16:40:26 -08:00
// this is used to detect bugs where a hashtable is edited while an iterator is running.
modification_count: debug_u32,
const Self = @This();
2016-12-18 16:40:26 -08:00
pub const KV = struct {
2016-12-18 16:40:26 -08:00
key: K,
value: V,
};
const Entry = struct {
used: bool,
distance_from_start_index: usize,
kv: KV,
};
pub const GetOrPutResult = struct {
kv: *KV,
found_existing: bool,
};
pub const Iterator = struct {
hm: *const Self,
2016-12-18 16:40:26 -08:00
// how many items have we returned
count: usize,
// iterator through the entry array
index: usize,
// used to detect concurrent modification
initial_modification_count: debug_u32,
pub fn next(it: *Iterator) ?*KV {
2016-12-18 16:40:26 -08:00
if (want_modification_safety) {
assert(it.initial_modification_count == it.hm.modification_count); // concurrent modification
}
2016-12-18 16:40:26 -08:00
if (it.count >= it.hm.size) return null;
while (it.index < it.hm.entries.len) : (it.index += 1) {
2016-12-18 16:40:26 -08:00
const entry = &it.hm.entries[it.index];
if (entry.used) {
it.index += 1;
it.count += 1;
return &entry.kv;
2016-12-18 16:40:26 -08:00
}
}
unreachable; // no next item
}
2018-05-03 06:54:33 -07:00
// Reset the iterator to the initial index
pub fn reset(it: *Iterator) void {
2018-05-03 06:54:33 -07:00
it.count = 0;
it.index = 0;
// Resetting the modification count too
it.initial_modification_count = it.hm.modification_count;
}
2016-12-18 16:40:26 -08:00
};
pub fn init(allocator: *Allocator) Self {
return Self{
.entries = []Entry{},
.allocator = allocator,
.size = 0,
.max_distance_from_start_index = 0,
.modification_count = if (want_modification_safety) 0 else {},
};
}
pub fn deinit(hm: Self) void {
hm.allocator.free(hm.entries);
}
pub fn clear(hm: *Self) void {
2016-12-18 16:40:26 -08:00
for (hm.entries) |*entry| {
entry.used = false;
}
2016-12-18 16:40:26 -08:00
hm.size = 0;
hm.max_distance_from_start_index = 0;
hm.incrementModificationCount();
}
pub fn count(self: Self) usize {
return self.size;
2018-05-03 06:54:33 -07:00
}
/// If key exists this function cannot fail.
/// If there is an existing item with `key`, then the result
/// kv pointer points to it, and found_existing is true.
/// Otherwise, puts a new item with undefined value, and
/// the kv pointer points to it. Caller should then initialize
/// the data.
pub fn getOrPut(self: *Self, key: K) !GetOrPutResult {
// TODO this implementation can be improved - we should only
// have to hash once and find the entry once.
if (self.get(key)) |kv| {
return GetOrPutResult{
.kv = kv,
.found_existing = true,
};
}
self.incrementModificationCount();
try self.ensureCapacity();
const put_result = self.internalPut(key);
assert(put_result.old_kv == null);
return GetOrPutResult{
.kv = &put_result.new_entry.kv,
.found_existing = false,
};
}
pub fn getOrPutValue(self: *Self, key: K, value: V) !*KV {
const res = try self.getOrPut(key);
if (!res.found_existing)
res.kv.value = value;
return res.kv;
}
fn ensureCapacity(self: *Self) !void {
if (self.entries.len == 0) {
return self.initCapacity(16);
}
2016-12-30 22:31:23 -08:00
// if we get too full (60%), double the capacity
if (self.size * 5 >= self.entries.len * 3) {
const old_entries = self.entries;
try self.initCapacity(self.entries.len * 2);
2016-12-18 16:40:26 -08:00
// dump all of the old elements into the new table
for (old_entries) |*old_entry| {
if (old_entry.used) {
self.internalPut(old_entry.kv.key).new_entry.kv.value = old_entry.kv.value;
2016-12-18 16:40:26 -08:00
}
}
self.allocator.free(old_entries);
2016-12-18 16:40:26 -08:00
}
}
/// Returns the kv pair that was already there.
pub fn put(self: *Self, key: K, value: V) !?KV {
self.incrementModificationCount();
try self.ensureCapacity();
const put_result = self.internalPut(key);
put_result.new_entry.kv.value = value;
return put_result.old_kv;
2016-12-18 16:40:26 -08:00
}
pub fn get(hm: *const Self, key: K) ?*KV {
if (hm.entries.len == 0) {
return null;
}
2016-12-18 16:40:26 -08:00
return hm.internalGet(key);
}
pub fn contains(hm: *const Self, key: K) bool {
return hm.get(key) != null;
}
pub fn remove(hm: *Self, key: K) ?*KV {
2018-04-10 21:32:42 -07:00
if (hm.entries.len == 0) return null;
2016-12-18 16:40:26 -08:00
hm.incrementModificationCount();
const start_index = hm.keyToIndex(key);
{
var roll_over: usize = 0;
while (roll_over <= hm.max_distance_from_start_index) : (roll_over += 1) {
const index = (start_index + roll_over) % hm.entries.len;
var entry = &hm.entries[index];
if (!entry.used) return null;
if (!eql(entry.kv.key, key)) continue;
while (roll_over < hm.entries.len) : (roll_over += 1) {
const next_index = (start_index + roll_over + 1) % hm.entries.len;
const next_entry = &hm.entries[next_index];
if (!next_entry.used or next_entry.distance_from_start_index == 0) {
entry.used = false;
hm.size -= 1;
return &entry.kv;
}
entry.* = next_entry.*;
entry.distance_from_start_index -= 1;
entry = next_entry;
2016-12-18 16:40:26 -08:00
}
unreachable; // shifting everything in the table
}
}
return null;
2016-12-18 16:40:26 -08:00
}
pub fn iterator(hm: *const Self) Iterator {
return Iterator{
2016-12-18 16:40:26 -08:00
.hm = hm,
.count = 0,
.index = 0,
.initial_modification_count = hm.modification_count,
};
}
pub fn clone(self: Self) !Self {
var other = Self.init(self.allocator);
try other.initCapacity(self.entries.len);
var it = self.iterator();
while (it.next()) |entry| {
assert((try other.put(entry.key, entry.value)) == null);
}
return other;
}
fn initCapacity(hm: *Self, capacity: usize) !void {
hm.entries = try hm.allocator.alloc(Entry, capacity);
2016-12-18 16:40:26 -08:00
hm.size = 0;
hm.max_distance_from_start_index = 0;
for (hm.entries) |*entry| {
entry.used = false;
}
}
fn incrementModificationCount(hm: *Self) void {
2016-12-18 16:40:26 -08:00
if (want_modification_safety) {
hm.modification_count +%= 1;
}
}
const InternalPutResult = struct {
new_entry: *Entry,
old_kv: ?KV,
};
/// Returns a pointer to the new entry.
/// Asserts that there is enough space for the new item.
fn internalPut(self: *Self, orig_key: K) InternalPutResult {
2016-12-18 16:40:26 -08:00
var key = orig_key;
var value: V = undefined;
const start_index = self.keyToIndex(key);
2016-12-18 16:40:26 -08:00
var roll_over: usize = 0;
var distance_from_start_index: usize = 0;
var got_result_entry = false;
var result = InternalPutResult{
.new_entry = undefined,
.old_kv = null,
};
while (roll_over < self.entries.len) : ({
roll_over += 1;
distance_from_start_index += 1;
}) {
const index = (start_index + roll_over) % self.entries.len;
const entry = &self.entries[index];
2016-12-18 16:40:26 -08:00
if (entry.used and !eql(entry.kv.key, key)) {
2016-12-18 16:40:26 -08:00
if (entry.distance_from_start_index < distance_from_start_index) {
// robin hood to the rescue
const tmp = entry.*;
self.max_distance_from_start_index = math.max(self.max_distance_from_start_index, distance_from_start_index);
if (!got_result_entry) {
got_result_entry = true;
result.new_entry = entry;
}
entry.* = Entry{
2016-12-18 16:40:26 -08:00
.used = true,
.distance_from_start_index = distance_from_start_index,
.kv = KV{
.key = key,
.value = value,
},
2016-12-18 16:40:26 -08:00
};
key = tmp.kv.key;
value = tmp.kv.value;
2016-12-18 16:40:26 -08:00
distance_from_start_index = tmp.distance_from_start_index;
}
continue;
}
if (entry.used) {
result.old_kv = entry.kv;
} else {
2016-12-18 16:40:26 -08:00
// adding an entry. otherwise overwriting old value with
// same key
self.size += 1;
2016-12-18 16:40:26 -08:00
}
self.max_distance_from_start_index = math.max(distance_from_start_index, self.max_distance_from_start_index);
if (!got_result_entry) {
result.new_entry = entry;
}
entry.* = Entry{
2016-12-18 16:40:26 -08:00
.used = true,
.distance_from_start_index = distance_from_start_index,
.kv = KV{
.key = key,
.value = value,
},
2016-12-18 16:40:26 -08:00
};
return result;
2016-12-18 16:40:26 -08:00
}
unreachable; // put into a full map
}
fn internalGet(hm: Self, key: K) ?*KV {
2016-12-18 16:40:26 -08:00
const start_index = hm.keyToIndex(key);
{
var roll_over: usize = 0;
while (roll_over <= hm.max_distance_from_start_index) : (roll_over += 1) {
const index = (start_index + roll_over) % hm.entries.len;
const entry = &hm.entries[index];
if (!entry.used) return null;
if (eql(entry.kv.key, key)) return &entry.kv;
}
}
2016-12-18 16:40:26 -08:00
return null;
}
fn keyToIndex(hm: Self, key: K) usize {
2016-12-18 16:40:26 -08:00
return usize(hash(key)) % hm.entries.len;
}
};
}
test "basic hash map usage" {
2018-04-10 21:32:42 -07:00
var direct_allocator = std.heap.DirectAllocator.init();
defer direct_allocator.deinit();
var map = AutoHashMap(i32, i32).init(&direct_allocator.allocator);
defer map.deinit();
assert((try map.put(1, 11)) == null);
assert((try map.put(2, 22)) == null);
assert((try map.put(3, 33)) == null);
assert((try map.put(4, 44)) == null);
assert((try map.put(5, 55)) == null);
assert((try map.put(5, 66)).?.value == 55);
assert((try map.put(5, 55)).?.value == 66);
const gop1 = try map.getOrPut(5);
assert(gop1.found_existing == true);
assert(gop1.kv.value == 55);
gop1.kv.value = 77;
assert(map.get(5).?.value == 77);
const gop2 = try map.getOrPut(99);
assert(gop2.found_existing == false);
gop2.kv.value = 42;
assert(map.get(99).?.value == 42);
const gop3 = try map.getOrPutValue(5, 5);
assert(gop3.value == 77);
const gop4 = try map.getOrPutValue(100, 41);
assert(gop4.value == 41);
assert(map.contains(2));
assert(map.get(2).?.value == 22);
_ = map.remove(2);
assert(map.remove(2) == null);
assert(map.get(2) == null);
}
2018-05-03 06:54:33 -07:00
test "iterator hash map" {
var direct_allocator = std.heap.DirectAllocator.init();
defer direct_allocator.deinit();
var reset_map = AutoHashMap(i32, i32).init(&direct_allocator.allocator);
2018-05-03 06:54:33 -07:00
defer reset_map.deinit();
assert((try reset_map.put(1, 11)) == null);
assert((try reset_map.put(2, 22)) == null);
assert((try reset_map.put(3, 33)) == null);
2018-05-03 06:54:33 -07:00
var keys = []i32{
3,
2018-08-06 21:54:19 -07:00
2,
1,
};
var values = []i32{
33,
2018-08-06 21:54:19 -07:00
22,
11,
};
2018-05-03 06:54:33 -07:00
var it = reset_map.iterator();
var count: usize = 0;
2018-05-03 06:54:33 -07:00
while (it.next()) |next| {
assert(next.key == keys[count]);
assert(next.value == values[count]);
count += 1;
}
assert(count == 3);
assert(it.next() == null);
it.reset();
count = 0;
while (it.next()) |next| {
assert(next.key == keys[count]);
assert(next.value == values[count]);
count += 1;
if (count == 2) break;
}
it.reset();
var entry = it.next().?;
2018-05-03 06:54:33 -07:00
assert(entry.key == keys[0]);
assert(entry.value == values[0]);
}
pub fn getHashPtrAddrFn(comptime K: type) (fn (K) u32) {
return struct {
fn hash(key: K) u32 {
return getAutoHashFn(usize)(@ptrToInt(key));
}
}.hash;
}
pub fn getTrivialEqlFn(comptime K: type) (fn (K, K) bool) {
return struct {
fn eql(a: K, b: K) bool {
return a == b;
}
}.eql;
}
pub fn getAutoHashFn(comptime K: type) (fn (K) u32) {
return struct {
fn hash(key: K) u32 {
comptime var rng = comptime std.rand.DefaultPrng.init(0);
return autoHash(key, &rng.random, u32);
}
}.hash;
}
pub fn getAutoEqlFn(comptime K: type) (fn (K, K) bool) {
return struct {
fn eql(a: K, b: K) bool {
return autoEql(a, b);
}
}.eql;
}
// TODO improve these hash functions
pub fn autoHash(key: var, comptime rng: *std.rand.Random, comptime HashInt: type) HashInt {
switch (@typeInfo(@typeOf(key))) {
builtin.TypeId.NoReturn,
builtin.TypeId.Opaque,
builtin.TypeId.Undefined,
builtin.TypeId.ArgTuple,
=> @compileError("cannot hash this type"),
builtin.TypeId.Void,
builtin.TypeId.Null,
=> return 0,
builtin.TypeId.Int => |info| {
const unsigned_x = @bitCast(@IntType(false, info.bits), key);
if (info.bits <= HashInt.bit_count) {
return HashInt(unsigned_x) ^ comptime rng.scalar(HashInt);
} else {
return @truncate(HashInt, unsigned_x ^ comptime rng.scalar(@typeOf(unsigned_x)));
}
},
builtin.TypeId.Float => |info| {
return autoHash(@bitCast(@IntType(false, info.bits), key), rng);
},
builtin.TypeId.Bool => return autoHash(@boolToInt(key), rng),
builtin.TypeId.Enum => return autoHash(@enumToInt(key), rng),
builtin.TypeId.ErrorSet => return autoHash(@errorToInt(key), rng),
builtin.TypeId.Promise, builtin.TypeId.Fn => return autoHash(@ptrToInt(key), rng),
builtin.TypeId.Namespace,
builtin.TypeId.BoundFn,
builtin.TypeId.ComptimeFloat,
builtin.TypeId.ComptimeInt,
builtin.TypeId.Type,
=> return 0,
builtin.TypeId.Pointer => |info| switch (info.size) {
builtin.TypeInfo.Pointer.Size.One => @compileError("TODO auto hash for single item pointers"),
builtin.TypeInfo.Pointer.Size.Many => @compileError("TODO auto hash for many item pointers"),
builtin.TypeInfo.Pointer.Size.Slice => {
const interval = std.math.max(1, key.len / 256);
var i: usize = 0;
var h = comptime rng.scalar(HashInt);
while (i < key.len) : (i += interval) {
h ^= autoHash(key[i], rng, HashInt);
}
return h;
},
},
builtin.TypeId.Optional => @compileError("TODO auto hash for optionals"),
builtin.TypeId.Array => @compileError("TODO auto hash for arrays"),
builtin.TypeId.Vector => @compileError("TODO auto hash for vectors"),
builtin.TypeId.Struct => @compileError("TODO auto hash for structs"),
builtin.TypeId.Union => @compileError("TODO auto hash for unions"),
builtin.TypeId.ErrorUnion => @compileError("TODO auto hash for unions"),
}
}
pub fn autoEql(a: var, b: @typeOf(a)) bool {
switch (@typeInfo(@typeOf(a))) {
builtin.TypeId.NoReturn,
builtin.TypeId.Opaque,
builtin.TypeId.Undefined,
builtin.TypeId.ArgTuple,
=> @compileError("cannot test equality of this type"),
builtin.TypeId.Void,
builtin.TypeId.Null,
=> return true,
builtin.TypeId.Bool,
builtin.TypeId.Int,
builtin.TypeId.Float,
builtin.TypeId.ComptimeFloat,
builtin.TypeId.ComptimeInt,
builtin.TypeId.Namespace,
builtin.TypeId.Promise,
builtin.TypeId.Enum,
builtin.TypeId.BoundFn,
builtin.TypeId.Fn,
builtin.TypeId.ErrorSet,
builtin.TypeId.Type,
=> return a == b,
builtin.TypeId.Pointer => |info| switch (info.size) {
builtin.TypeInfo.Pointer.Size.One => @compileError("TODO auto eql for single item pointers"),
builtin.TypeInfo.Pointer.Size.Many => @compileError("TODO auto eql for many item pointers"),
builtin.TypeInfo.Pointer.Size.Slice => {
if (a.len != b.len) return false;
for (a) |a_item, i| {
if (!autoEql(a_item, b[i])) return false;
}
return true;
},
},
builtin.TypeId.Optional => @compileError("TODO auto eql for optionals"),
builtin.TypeId.Array => @compileError("TODO auto eql for arrays"),
builtin.TypeId.Struct => @compileError("TODO auto eql for structs"),
builtin.TypeId.Union => @compileError("TODO auto eql for unions"),
builtin.TypeId.ErrorUnion => @compileError("TODO auto eql for unions"),
builtin.TypeId.Vector => @compileError("TODO auto eql for vectors"),
}
}