diff --git a/cuckoo_internal_test.go b/cuckoo_internal_test.go index fde14ee..f2b24a3 100644 --- a/cuckoo_internal_test.go +++ b/cuckoo_internal_test.go @@ -25,6 +25,6 @@ func TestLoad(t *testing.T) { for i := range 16 { err := table.Put(i, true) assert.NoError(err) - assert.Equal(float64(table.Size())/float64(table.Capacity()), table.load()) + assert.Equal(float64(table.Size())/float64(table.TotalCapacity()), table.load()) } } diff --git a/cuckoo_test.go b/cuckoo_test.go index 3f9c7ec..26bcf71 100644 --- a/cuckoo_test.go +++ b/cuckoo_test.go @@ -64,7 +64,7 @@ func TestStartingCapacity(t *testing.T) { assert := assert.New(t) table := cuckoo.NewTable[int, bool](cuckoo.Capacity(64)) - assert.Equal(uint64(128), table.Capacity()) + assert.Equal(uint64(128), table.TotalCapacity()) } func TestResizeCapacity(t *testing.T) { @@ -74,12 +74,12 @@ func TestResizeCapacity(t *testing.T) { cuckoo.GrowthFactor(2), ) - for table.Capacity() == 16 { + for table.TotalCapacity() == 16 { err := table.Put(rand.Int(), true) assert.NoError(err) } - assert.Equal(uint64(32), table.Capacity()) + assert.Equal(uint64(32), table.TotalCapacity()) } func TestPutMany(t *testing.T) { @@ -128,3 +128,16 @@ func TestRemove(t *testing.T) { assert.True(table.Has(0)) } + +func TestDropItem(t *testing.T) { + assert := assert.New(t) + key, value := 0, true + table := cuckoo.NewTable[int, bool]() + (table.Put(key, value)) + + err := table.Drop(key) + + assert.NoError(err) + assert.Equal(0, table.Size()) + assert.False(table.Has(key)) +} diff --git a/table.go b/table.go index bc56a17..945e102 100644 --- a/table.go +++ b/table.go @@ -16,9 +16,9 @@ type Table[K, V any] struct { minLoadFactor float64 } -// Capacity returns the number of slots allocated for the [Table]. To get the +// TotalCapacity returns the number of slots allocated for the [Table]. To get the // number of slots filled, look at [Table.Size]. -func (t Table[K, V]) Capacity() uint64 { +func (t Table[K, V]) TotalCapacity() uint64 { return t.bucketA.capacity + t.bucketB.capacity } @@ -32,21 +32,21 @@ func log2(n uint64) (m int) { } func (t Table[K, V]) maxEvictions() int { - return 3 * log2(t.Capacity()) + return 3 * log2(t.TotalCapacity()) } func (t Table[K, V]) load() float64 { - return float64(t.Size()) / float64(t.Capacity()) + return float64(t.Size()) / float64(t.TotalCapacity()) } -func (t *Table[K, V]) resize() error { +func (t *Table[K, V]) resize(capacity uint64) error { entries := make([]entry[K, V], 0, t.Size()) for k, v := range t.Entries() { entries = append(entries, entry[K, V]{k, v}) } - t.bucketA.resize(t.growthFactor * t.bucketA.capacity) - t.bucketB.resize(t.growthFactor * t.bucketB.capacity) + t.bucketA.resize(capacity) + t.bucketB.resize(capacity) for _, entry := range entries { if err := t.Put(entry.key, entry.value); err != nil { @@ -99,10 +99,10 @@ func (t *Table[K, V]) Put(key K, value V) (err error) { } if t.load() < t.minLoadFactor { - return fmt.Errorf("bad hash: resize on load %d/%d = %f", t.Size(), t.Capacity(), t.load()) + return fmt.Errorf("bad hash: resize on load %d/%d = %f", t.Size(), t.TotalCapacity(), t.load()) } - if err := t.resize(); err != nil { + if err := t.resize(t.growthFactor * t.bucketA.capacity); err != nil { return err } @@ -116,7 +116,7 @@ func (t Table[K, V]) Drop(key K) (err error) { t.bucketB.drop(key) if t.load() < t.minLoadFactor { - return t.resize() + return t.resize(t.bucketA.capacity / t.growthFactor) } return nil