summaryrefslogblamecommitdiffstats
path: root/src/state_store.rs
blob: ac90e403a62d800742f03089cdbc220ecda2cd8f (plain) (tree)
1
2
3
4
5
6
7
8
9


                                                                             
                   




                                
                                                            

































































































                                                                                                
                                                                                                    







                                                                  
                                               
                                                           
                                                              
                             
                            






                                                  
                    

                                                          
                                                           






                                                  
             












                                                        
                        






                                                            
                                    





















                                                                              
                          







                                                                              
                                 






















































                                                                                    
use std::{collections::HashMap, ops::{Deref, DerefMut}, sync::{Arc, RwLock}};

use leptos::prelude::*;
use tracing::debug;

// TODO: get rid of this
/// Shared map from keys to reactive state.
///
/// Each entry pairs the `ArcRwSignal<V>` that holds the value with a count of
/// the listeners handed out for that key. The map lives behind an
/// `Arc<RwLock<..>>`, so every handle to the same store sees the same entries.
#[derive(Debug)]
pub struct ArcStateStore<K, V> {
    store: Arc<RwLock<StoreMap<K, V>>>,
}

/// Backing map: key -> (signal holding the value, listener count).
type StoreMap<K, V> = HashMap<K, (ArcRwSignal<V>, usize)>;

/// Two stores compare equal exactly when they are handles to the same
/// underlying map (pointer identity); their contents are never inspected.
impl<K, V> PartialEq for ArcStateStore<K, V> {
    fn eq(&self, other: &Self) -> bool {
        Arc::as_ptr(&self.store) == Arc::as_ptr(&other.store)
    }
}

// Pointer identity is reflexive, so `Eq` holds with no bounds on `K` or `V`.
impl<K, V> Eq for ArcStateStore<K, V> {}

/// Cloning yields another handle to the same shared map; no entries are copied.
impl<K, V> Clone for ArcStateStore<K, V> {
    fn clone(&self) -> Self {
        let store = self.store.clone();
        Self { store }
    }
}

impl<K, V> ArcStateStore<K, V> {
    pub fn new() -> Self {
        Self {
            store: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}

/// Arena-allocated, `Copy` handle to an [`ArcStateStore`].
///
/// `S` selects the arena storage: `SyncStorage` by default, or `LocalStorage`
/// when `K`/`V` are not `Send + Sync`.
#[derive(Debug)]
pub struct StateStore<K, V, S = SyncStorage> {
    inner: ArenaItem<ArcStateStore<K, V>, S>,
}

impl<K, V, S> Dispose for StateStore<K, V, S> {
    /// Disposes the arena slot holding the underlying store.
    fn dispose(self) {
        let Self { inner } = self;
        inner.dispose();
    }
}

impl<K, V> StateStore<K, V>
where
    K: Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    pub fn new() -> Self {
        Self::new_with_storage()
    }
}

impl<K, V, S> StateStore<K, V, S>
where
    K: 'static,
    V: 'static,
    S: Storage<ArcStateStore<K, V>>,
{
    /// Allocates an empty [`ArcStateStore`] in the arena using storage backend `S`.
    pub fn new_with_storage() -> Self {
        let empty = ArcStateStore::new();
        let inner = ArenaItem::new_with_storage(empty);
        Self { inner }
    }
}

impl<K, V> StateStore<K, V, LocalStorage>
where
    K: 'static,
    V: 'static,
{
    pub fn new_local() -> Self {
        Self::new_with_storage()
    }
}

impl<K, V> From<ArcStateStore<K, V>> for StateStore<K, V>
where
    K: Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    /// Moves an existing shared store into the arena behind a `SyncStorage` handle.
    fn from(value: ArcStateStore<K, V>) -> Self {
        let inner = ArenaItem::new_with_storage(value);
        Self { inner }
    }
}

impl<K, V> FromLocal<ArcStateStore<K, V>> for StateStore<K, V, LocalStorage>
where
    K: 'static,
    V: 'static,
{
    /// Moves an existing shared store into the arena behind a `LocalStorage` handle.
    fn from_local(value: ArcStateStore<K, V>) -> Self {
        let inner = ArenaItem::new_with_storage(value);
        Self { inner }
    }
}

/// `StateStore` is only an arena handle, so cloning is a plain copy.
impl<K, V, S> Clone for StateStore<K, V, S> {
    fn clone(&self) -> Self {
        *self
    }
}

// No bounds on K, V or S: the only field is an `ArenaItem` handle, not the data.
impl<K, V, S> Copy for StateStore<K, V, S> {}

impl<K: Eq + std::hash::Hash + Clone + std::fmt::Debug, V: Clone + std::fmt::Debug> StateStore<K, V>
where
    K: Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    /// Stores `value` under `key` and returns a [`StateListener`] bound to it.
    ///
    /// If `key` is already present, its existing signal is reused: the entry's
    /// listener count is incremented and the signal is set to `value`, so every
    /// listener already holding that signal observes the new value. Otherwise a
    /// new `ArcRwSignal` is created and registered with a listener count of 1.
    ///
    /// The map's lock is released before the signal is written, so whatever the
    /// write triggers (dropping the previous value, notifying subscribers) does
    /// not run while the map is locked.
    ///
    /// # Panics
    ///
    /// Panics if the store has been disposed or its lock is poisoned.
    pub fn store(&self, key: K, value: V) -> StateListener<K, V> {
        let (signal, deferred_write) = {
            let shared = self
                .inner
                .try_get_value()
                .expect("StateStore used after it was disposed");
            let mut map = shared.store.write().expect("StateStore lock poisoned");
            debug!("store state: {:?}", map);
            if let Some((signal, count)) = map.get_mut(&key) {
                debug!("updating old value already in store");
                *count += 1;
                // Defer the write until the map lock is released (see doc above).
                (signal.clone(), Some(value))
            } else {
                let signal = ArcRwSignal::new(value);
                map.insert(key.clone(), (signal.clone(), 1));
                debug!("inserting new value: {:?}", map);
                (signal, None)
            }
        };

        if let Some(value) = deferred_write {
            signal.set(value);
        }

        // The count was already bumped above, so the cleaner is built directly
        // rather than via `StateCleaner::clone` (which would bump it again).
        StateListener {
            value: signal,
            cleaner: StateCleaner {
                key,
                state_store: *self,
            },
        }
    }
}

impl<K, V> StateStore<K, V>
where
    K: Eq + std::hash::Hash + Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    pub fn update(&self, key: &K, value: V) {
        let store = self.inner.try_get_value().unwrap();
        let mut store = store.store.write().unwrap();
        if let Some((v, _)) = store.get_mut(key) {
            v.set(value)
        }
    }

    pub fn modify(&self, key: &K, modify: impl Fn(&mut V)) {
        let store = self.inner.try_get_value().unwrap();
        let mut store = store.store.write().unwrap();
        if let Some((v, _)) = store.get_mut(key) {
            v.update(|v| modify(v));
        }
    }

    fn remove(&self, key: &K) {
        // let store = self.inner.try_get_value().unwrap();
        // let mut store = store.store.write().unwrap();
        // if let Some((_v, count)) = store.get_mut(key) {
        //     *count -= 1;
        //     if *count == 0 {
        //         store.remove(key);
        //         debug!("dropped item from store");
        //     }
        // }
    }
}

/// A listener's handle to one entry of a [`StateStore`].
///
/// Dereferences to the entry's `ArcRwSignal<V>`, so it can be used like the
/// signal itself. It also owns a `StateCleaner`: cloning the listener increments
/// the entry's listener count, and dropping it calls `StateStore::remove`.
pub struct StateListener<K, V>
where
    K: Eq + std::hash::Hash + 'static + std::marker::Send + std::marker::Sync,
    V: 'static + std::marker::Send + std::marker::Sync,
{
    value: ArcRwSignal<V>,
    cleaner: StateCleaner<K, V>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would add a `V: Clone`
// bound, but cloning an `ArcRwSignal<V>` never clones the `V` inside it.
impl<K, V> Clone for StateListener<K, V>
where
    K: Eq + std::hash::Hash + Clone + Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        Self {
            value: self.value.clone(),
            cleaner: self.cleaner.clone(),
        }
    }
}

impl<K, V> Deref for StateListener<K, V>
where
    K: Eq + std::hash::Hash + Send + Sync,
    V: Send + Sync,
{
    /// A listener exposes the signal it holds directly.
    type Target = ArcRwSignal<V>;

    fn deref(&self) -> &ArcRwSignal<V> {
        let Self { value, .. } = self;
        value
    }
}

impl<K: std::cmp::Eq + std::hash::Hash + Send + Sync, V: Send + Sync> DerefMut
    for StateListener<K, V>
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.value
    }
}

// NOTE(review): this struct is never constructed or referenced anywhere in this
// file; it looks like an `ArcStateStore`-based counterpart of `StateCleaner`.
// Remove it once confirmed unused. Until then, the `allow` keeps builds that use
// `-D warnings` from failing on the resulting `dead_code` warning.
#[allow(dead_code)]
struct ArcStateCleaner<K, V> {
    key: K,
    state_store: ArcStateStore<K, V>,
}

/// Guard tying one listener to its entry in a [`StateStore`].
///
/// Cloning it increments the entry's listener count; dropping it calls
/// `StateStore::remove` for `key`. Its bounds must match its `Drop` impl's
/// bounds exactly, as Rust requires for `Drop` impls.
struct StateCleaner<K: Eq + std::hash::Hash + Send + Sync + 'static, V: Send + Sync + 'static> {
    key: K,
    state_store: StateStore<K, V>,
}

impl<K, V> Clone for StateCleaner<K, V>
where
    K: Eq + std::hash::Hash + Clone + Send + Sync,
    V: Send + Sync,
{
    /// Clones the guard and increments the entry's listener count, so every
    /// clone's eventual `Drop` is matched by an increment.
    ///
    /// # Panics
    ///
    /// Panics if the store has been disposed or its lock is poisoned.
    fn clone(&self) -> Self {
        {
            let shared = self
                .state_store
                .inner
                .try_get_value()
                .expect("StateStore used after it was disposed");
            let mut map = shared.store.write().expect("StateStore lock poisoned");
            if let Some((_, count)) = map.get_mut(&self.key) {
                *count += 1;
            }
        } // lock released here, before the key is cloned

        Self {
            key: self.key.clone(),
            // `StateStore` is `Copy`.
            state_store: self.state_store,
        }
    }
}

impl<K, V> Drop for StateCleaner<K, V>
where
    K: Eq + std::hash::Hash + Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    /// Reports this guard's key to `StateStore::remove`.
    fn drop(&mut self) {
        let Self { key, state_store } = self;
        state_store.remove(key);
    }
}