Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/proteus-traits/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ pub trait PreKeyStore {
type Error: ProteusErrorCode;

/// Lookup prekey by ID.
async fn prekey(&mut self, id: RawPreKeyId) -> Result<Option<RawPreKey>, Self::Error>;
async fn prekey(&self, id: RawPreKeyId) -> Result<Option<RawPreKey>, Self::Error>;

/// Remove prekey by ID.
async fn remove(&mut self, id: RawPreKeyId) -> Result<(), Self::Error>;
async fn remove(&self, id: RawPreKeyId) -> Result<(), Self::Error>;
}
13 changes: 5 additions & 8 deletions src/internal/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,11 @@ impl PreKey {
}
}

#[must_use]
pub fn gen_prekeys(start: PreKeyId, size: u16) -> Vec<PreKey> {
(1..)
.map(|i| (u32::from(start.value()) + i) % u32::from(MAX_PREKEY_ID.value()))
.map(|i| PreKey::new(PreKeyId::new(i as u16)))
.take(size as usize)
.collect()
pub fn gen_prekeys(start: PreKeyId, size: u16) -> impl Iterator<Item = PreKey> {
(1..=size).map(move |i| {
let id = (start.value() as u32 + i as u32) % MAX_PREKEY_ID.value() as u32;
PreKey::new(PreKeyId::new(id as _))
})
}

// Prekey bundle ////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -751,7 +749,6 @@ mod tests {
#[wasm_bindgen_test]
fn prekey_generation() {
let k = gen_prekeys(PreKeyId::new(0xFFFC), 5)
.iter()
.map(|k| k.key_id.value())
.collect::<Vec<_>>();
assert_eq!(vec![0xFFFD, 0xFFFE, 0, 1, 2], k)
Expand Down
170 changes: 70 additions & 100 deletions src/internal/session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,8 @@ impl<I: Borrow<IdentityKeyPair>> Session<I> {
alice_ident: &m.identity_key,
alice_base: &m.base_key,
})
.map(Some)} else {
.map(Some)
} else {
Ok(None)
}
}
Expand Down Expand Up @@ -1084,6 +1085,7 @@ mod tests {
use std::borrow::Borrow;
use std::collections::BTreeMap;
use std::fmt;
use std::sync::{Mutex, MutexGuard};
use std::vec::Vec;
use wasm_bindgen_test::wasm_bindgen_test;

Expand All @@ -1098,14 +1100,14 @@ mod tests {
}
}

#[derive(Debug)]
#[derive(Debug, Default)]
struct TestStore {
prekeys: Vec<PreKey>,
prekeys: Mutex<Vec<PreKey>>,
}

impl TestStore {
pub fn prekey_slice(&self) -> &[PreKey] {
&self.prekeys
pub fn lock(&self) -> MutexGuard<'_, Vec<PreKey>> {
self.prekeys.lock().expect("propagate any mutex poison")
}
}

Expand All @@ -1115,25 +1117,37 @@ mod tests {
type Error = DummyError;

async fn prekey(
&mut self,
&self,
id: proteus_traits::RawPreKeyId,
) -> Result<Option<proteus_traits::RawPreKey>, Self::Error> {
if let Some(prekey) = self.prekeys.iter().find(|k| k.key_id.value() == id) {
Ok(Some(prekey.serialise().unwrap()))
} else {
Ok(None)
}
self.lock()
.iter()
.find(|k| k.key_id.value() == id)
.map(|prekey| Ok(prekey.serialise().unwrap()))
.transpose()
}

async fn remove(&mut self, id: proteus_traits::RawPreKeyId) -> Result<(), Self::Error> {
self.prekeys
async fn remove(&self, id: proteus_traits::RawPreKeyId) -> Result<(), Self::Error> {
let mut guard = self.lock();
guard
.iter()
.position(|k| k.key_id.value() == id)
.map(|ix| self.prekeys.swap_remove(ix));
.map(|ix| guard.swap_remove(ix));
Ok(())
}
}

impl FromIterator<PreKey> for TestStore {
fn from_iter<I: IntoIterator<Item = PreKey>>(iter: I) -> Self {
let store = TestStore::default();
{
let mut guard = store.lock();
guard.extend(iter);
}
store
Comment thread
coriolinus marked this conversation as resolved.
}
}

#[derive(Debug, Copy, Clone, PartialEq)]
enum MsgType {
Plain,
Expand All @@ -1147,12 +1161,10 @@ mod tests {

let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();
let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), total_size as u16),
};
let mut bob_store = gen_prekeys(PreKeyId::new(0), total_size as u16).collect::<TestStore>();

let mut alices = Vec::new();
for pk in bob_store.prekey_slice() {
for pk in bob_store.lock().iter() {
let bob_bundle = PreKeyBundle::new(bob_ident.public_key.clone(), pk);
alices.push(Session::init_from_prekey::<()>(&alice_ident, bob_bundle).unwrap());
}
Expand Down Expand Up @@ -1193,14 +1205,10 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut alice_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut alice_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();
let mut bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle = PreKeyBundle::new(bob_ident.public_key.clone(), &bob_prekey);

let mut alice = Session::init_from_prekey::<()>(&alice_ident, bob_bundle).unwrap();
Expand Down Expand Up @@ -1340,14 +1348,10 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut alice_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut alice_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();
let mut bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle = PreKeyBundle::new(bob_ident.public_key.clone(), &bob_prekey);

let mut alice = Session::init_from_prekey::<()>(&alice_ident, bob_bundle).unwrap();
Expand Down Expand Up @@ -1441,11 +1445,9 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle = PreKeyBundle::new(bob_ident.public_key.clone(), &bob_prekey);

let mut alice = Session::init_from_prekey::<()>(&alice_ident, bob_bundle).unwrap();
Expand Down Expand Up @@ -1474,17 +1476,13 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut alice_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut alice_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();
let mut bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle = PreKeyBundle::new(bob_ident.public_key.clone(), &bob_prekey);

let alice_prekey = alice_store.prekey_slice().first().unwrap().clone();
let alice_prekey = alice_store.lock().first().unwrap().clone();
let alice_bundle = PreKeyBundle::new(alice_ident.public_key.clone(), &alice_prekey);

// Initial simultaneous prekey message
Expand Down Expand Up @@ -1524,17 +1522,13 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut alice_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut alice_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();
let mut bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle = PreKeyBundle::new(bob_ident.public_key.clone(), &bob_prekey);

let alice_prekey = alice_store.prekey_slice().first().unwrap().clone();
let alice_prekey = alice_store.lock().first().unwrap().clone();
let alice_bundle = PreKeyBundle::new(alice_ident.public_key.clone(), &alice_prekey);

// Initial simultaneous prekey message
Expand Down Expand Up @@ -1598,11 +1592,9 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle = PreKeyBundle::new(bob_ident.public_key, &bob_prekey);

let alice = Session::init_from_prekey::<()>(&alice_ident, bob_bundle).unwrap();
Expand All @@ -1620,14 +1612,10 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut alice_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut alice_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();
let mut bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle = PreKeyBundle::new(bob_ident.public_key.clone(), &bob_prekey);

let mut alice = Session::init_from_prekey::<()>(&alice_ident, bob_bundle).unwrap();
Expand Down Expand Up @@ -1682,11 +1670,9 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle = PreKeyBundle::new(bob_ident.public_key.clone(), &bob_prekey);

let mut alice = Session::init_from_prekey::<()>(&alice_ident, bob_bundle).unwrap();
Expand All @@ -1709,14 +1695,10 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut alice_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let mut alice_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();
let mut bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle = PreKeyBundle::new(bob_ident.public_key.clone(), &bob_prekey);

let mut alice = Session::init_from_prekey::<()>(&alice_ident, bob_bundle).unwrap();
Expand Down Expand Up @@ -1828,11 +1810,9 @@ mod tests {
let bob_ident = IdentityKeyPair::new();
let eve_ident = IdentityKeyPair::new();

let eve_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let eve_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();

let eve_prekey = eve_store.prekey_slice().first().unwrap().clone();
let eve_prekey = eve_store.lock().first().unwrap().clone();
let mut eve_bundle = PreKeyBundle::new(eve_ident.public_key.clone(), &eve_prekey);
let mut eve_bundle_signed = PreKeyBundle::signed(&eve_ident, &eve_prekey);

Expand All @@ -1846,10 +1826,8 @@ mod tests {
assert_eq!(PreKeyAuth::Invalid, eve_bundle_signed.verify());

// authentic prekey
let bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 10),
};
let bob_prekey = bob_store.prekey_slice().first().unwrap().clone();
let bob_store = gen_prekeys(PreKeyId::new(0), 10).collect::<TestStore>();
let bob_prekey = bob_store.lock().first().unwrap().clone();
let bob_bundle_signed = PreKeyBundle::signed(&bob_ident, &bob_prekey);
assert_eq!(PreKeyAuth::Valid, bob_bundle_signed.verify());
}
Expand All @@ -1860,9 +1838,7 @@ mod tests {
let alice = IdentityKeyPair::new();
let bob = IdentityKeyPair::new();

let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 500),
};
let mut bob_store = gen_prekeys(PreKeyId::new(0), 500).collect::<TestStore>();

async fn get_bob(bob: &IdentityKeyPair, i: u16, store: &mut TestStore) -> PreKeyBundle {
PreKeyBundle::new(
Expand Down Expand Up @@ -1927,13 +1903,9 @@ mod tests {
let alice = IdentityKeyPair::new();
let bob = IdentityKeyPair::new();

let mut bob_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(0), 1),
};
let mut bob_store = gen_prekeys(PreKeyId::new(0), 1).collect::<TestStore>();

let mut alice_store = TestStore {
prekeys: gen_prekeys(PreKeyId::new(1), 1),
};
let mut alice_store = gen_prekeys(PreKeyId::new(1), 1).collect::<TestStore>();

async fn get_bob(bob: &IdentityKeyPair, i: u16, store: &mut TestStore) -> PreKeyBundle {
PreKeyBundle::new(
Expand Down Expand Up @@ -2007,12 +1979,12 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut bob_store1 = TestStore {
prekeys: vec![PreKey::new(PreKeyId::new(1))],
};
let mut bob_store2 = TestStore {
prekeys: vec![PreKey::new(PreKeyId::new(1))],
};
let mut bob_store1 = [PreKey::new(PreKeyId::new(1))]
.into_iter()
.collect::<TestStore>();
let mut bob_store2 = [PreKey::new(PreKeyId::new(1))]
.into_iter()
.collect::<TestStore>();

let bob_prekey = PreKey::deserialise(
&bob_store1
Expand Down Expand Up @@ -2053,9 +2025,7 @@ mod tests {
let alice_ident = IdentityKeyPair::new();
let bob_ident = IdentityKeyPair::new();

let mut bob_store = TestStore {
prekeys: vec![PreKey::last_resort()],
};
let mut bob_store = [PreKey::last_resort()].into_iter().collect::<TestStore>();

let bob_prekey = PreKey::deserialise(
&bob_store
Expand Down
Loading
Loading