1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
use std::{
    hash::{BuildHasher, Hash},
    sync::Arc,
};

use crate::cht::SegmentedHashMap;

use async_lock::{Mutex, MutexGuard};
use triomphe::Arc as TrioArc;

/// Number of segments the per-key lock table is created with (`1 << 6` = 64).
const LOCK_MAP_NUM_SEGMENTS: usize = 1 << 6;

/// Table mapping each shared key to the `TrioArc`-shared async mutex handed out
/// for it; `S` is the hasher used by the underlying segmented map.
type LockMap<K, S> = SegmentedHashMap<Arc<K>, TrioArc<Mutex<()>>, S>;

// The trait bounds have to be declared on the struct itself (not only on its
// impls) because the `Drop` impl needs them, and a `Drop` impl must use exactly
// the same bounds as the type it is implemented for.
/// A handle to the mutex registered for one key in a `KeyLockMap`.
///
/// Dropping the handle may remove the key's entry from the lock table (see the
/// `Drop` impl for this type).
pub(crate) struct KeyLock<'a, K: Eq + Hash, S: BuildHasher> {
    /// The lock table this handle's mutex is registered in.
    map: &'a LockMap<K, S>,
    /// Hash of `key` as computed for `map`.
    hash: u64,
    /// The key this handle refers to.
    key: Arc<K>,
    /// This handle's reference to the key's shared mutex; its reference count
    /// is inspected on drop.
    lock: TrioArc<Mutex<()>>,
}

impl<'a, K, S> Drop for KeyLock<'a, K, S>
where
    K: Eq + Hash,
    S: BuildHasher,
{
    /// Removes this key's entry from the lock table when this handle appears to
    /// be the last one sharing the mutex.
    fn drop(&mut self) {
        // A count of 2 presumably means only the table's copy plus `self.lock`
        // remain. A higher count means another `KeyLock` still shares the
        // mutex, so the entry must be kept. This read is just a fast path; the
        // count is checked again by the `remove_if` predicate below.
        //
        // NOTE(review): `drop` runs before the `self.lock` field is released.
        // If two handles sharing the same mutex are dropped concurrently, both
        // can observe a count of 3 and skip removal. That leaves an entry with
        // no handles until the key is locked and released again.
        if TrioArc::count(&self.lock) > 2 {
            return;
        }

        let my_key = &self.key;
        self.map.remove_if(
            self.hash,
            |candidate| candidate == my_key,
            |_, shared_lock| TrioArc::count(shared_lock) <= 2,
        );
    }
}

impl<'a, K, S> KeyLock<'a, K, S>
where
    K: Eq + Hash,
    S: BuildHasher,
{
    /// Builds a handle for `key` (whose hash in `map` is `hash`) around the
    /// already-shared mutex `lock`. The key `Arc` is cloned, not moved.
    fn new(map: &'a LockMap<K, S>, key: &Arc<K>, hash: u64, lock: TrioArc<Mutex<()>>) -> Self {
        let owned_key = Arc::clone(key);
        Self {
            lock,
            hash,
            key: owned_key,
            map,
        }
    }

    /// Waits for and acquires this key's mutex. The lock is held for as long
    /// as the returned guard is alive.
    pub(crate) async fn lock(&self) -> MutexGuard<'_, ()> {
        self.lock.lock().await
    }
}

/// A table of per-key async mutexes; a key's mutex is created on demand by
/// `key_lock`.
pub(crate) struct KeyLockMap<K, S> {
    /// Key -> shared mutex entries (built with `LOCK_MAP_NUM_SEGMENTS`
    /// segments by `with_hasher`).
    locks: LockMap<K, S>,
}

impl<K, S> KeyLockMap<K, S>
where
    K: Eq + Hash,
    S: BuildHasher,
{
    /// Creates a lock table with `LOCK_MAP_NUM_SEGMENTS` segments that hashes
    /// keys with `hasher`.
    pub(crate) fn with_hasher(hasher: S) -> Self {
        let locks = SegmentedHashMap::with_num_segments_and_hasher(LOCK_MAP_NUM_SEGMENTS, hasher);
        Self { locks }
    }

    /// Returns a handle to the mutex for `key`. If the table has no mutex for
    /// this key yet, a new one is created and registered.
    ///
    /// The mutex is not acquired here; call `lock().await` on the returned
    /// handle to acquire it.
    pub(crate) fn key_lock(&self, key: &Arc<K>) -> KeyLock<'_, K, S> {
        let hash = self.locks.hash(key);
        // Optimistically build a fresh mutex. If one is already registered
        // for this key, the table returns that one and the fresh mutex is
        // simply dropped.
        let fresh = TrioArc::new(Mutex::new(()));
        let shared = self
            .locks
            .insert_if_not_present(Arc::clone(key), hash, TrioArc::clone(&fresh))
            .unwrap_or(fresh);
        KeyLock::new(&self.locks, key, hash, shared)
    }
}

#[cfg(test)]
impl<K, S> KeyLockMap<K, S> {
    /// Test helper: reports whether the lock table currently holds no entries.
    pub(crate) fn is_empty(&self) -> bool {
        let entry_count = self.locks.len();
        entry_count == 0
    }
}