use crate::error::Error;
use crate::limited_cache;
use crate::msgs::handshake::CertificateChain;
use crate::server;
use crate::server::ClientHello;
use crate::sign;
use crate::webpki::{verify_server_name, ParsedCertificate};
use pki_types::{DnsName, ServerName};
use alloc::string::{String, ToString};
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::fmt::{Debug, Formatter};
use std::collections::HashMap;
use std::sync::Mutex;
/// A session store that keeps nothing.
///
/// Every `put` is refused, every lookup comes back empty, and `can_cache`
/// reports `false`.
#[derive(Debug)]
pub struct NoServerSessionStorage {}

impl server::StoresServerSessions for NoServerSessionStorage {
    /// Discards both arguments and returns `false`.
    fn put(&self, _id: Vec<u8>, _sec: Vec<u8>) -> bool {
        false
    }

    /// Always returns `None`, whatever `_id` is.
    fn get(&self, _id: &[u8]) -> Option<Vec<u8>> {
        None
    }

    /// Always returns `None`, whatever `_id` is.
    fn take(&self, _id: &[u8]) -> Option<Vec<u8>> {
        None
    }

    /// Always returns `false`.
    fn can_cache(&self) -> bool {
        false
    }
}
/// An in-memory session store: a `limited_cache::LimitedCache` of byte-vector
/// keys to byte-vector values, guarded by a `Mutex`.
pub struct ServerSessionMemoryCache {
    cache: Mutex<limited_cache::LimitedCache<Vec<u8>, Vec<u8>>>,
}

impl ServerSessionMemoryCache {
    /// Creates a new store, forwarding `size` to `LimitedCache::new`, and
    /// returns it wrapped in an `Arc`.
    pub fn new(size: usize) -> Arc<Self> {
        let cache = Mutex::new(limited_cache::LimitedCache::new(size));
        Arc::new(Self { cache })
    }
}

impl server::StoresServerSessions for ServerSessionMemoryCache {
    /// Inserts `value` under `key` and returns `true`.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn put(&self, key: Vec<u8>, value: Vec<u8>) -> bool {
        let mut guard = self.cache.lock().unwrap();
        guard.insert(key, value);
        true
    }

    /// Returns a copy of the value stored under `key`, leaving the entry in
    /// place.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn get(&self, key: &[u8]) -> Option<Vec<u8>> {
        let guard = self.cache.lock().unwrap();
        guard.get(key).cloned()
    }

    /// Removes the entry stored under `key` and returns its value, if any.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn take(&self, key: &[u8]) -> Option<Vec<u8>> {
        let mut guard = self.cache.lock().unwrap();
        guard.remove(key)
    }

    /// Always returns `true`.
    fn can_cache(&self) -> bool {
        true
    }
}

impl Debug for ServerSessionMemoryCache {
    /// Writes only the type name; the cached contents are not printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        f.write_str("ServerSessionMemoryCache")
    }
}
/// A ticket producer that is permanently switched off.
///
/// It reports itself as disabled, advertises a lifetime of `0`, and returns
/// `None` from both `encrypt` and `decrypt`.
#[derive(Debug)]
pub(super) struct NeverProducesTickets {}

impl server::ProducesTickets for NeverProducesTickets {
    /// Always returns `false`.
    fn enabled(&self) -> bool {
        false
    }

    /// Always returns `0`.
    fn lifetime(&self) -> u32 {
        0
    }

    /// Ignores the input and returns `None`.
    fn encrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
        None
    }

    /// Ignores the input and returns `None`.
    fn decrypt(&self, _bytes: &[u8]) -> Option<Vec<u8>> {
        None
    }
}
/// A certificate resolver that answers every `ClientHello` with the same
/// shared `CertifiedKey`.
#[derive(Debug)]
pub(super) struct AlwaysResolvesChain(Arc<sign::CertifiedKey>);

impl AlwaysResolvesChain {
    /// Builds a resolver from a signing key and the certificates in `chain`.
    pub(super) fn new(private_key: Arc<dyn sign::SigningKey>, chain: CertificateChain) -> Self {
        let certified = sign::CertifiedKey::new(chain.0, private_key);
        Self(Arc::new(certified))
    }

    /// Builds a resolver like [`Self::new`], additionally storing `ocsp` on
    /// the certified key.
    ///
    /// An empty `ocsp` is ignored, leaving the key's `ocsp` field as
    /// `CertifiedKey::new` set it.
    pub(super) fn new_with_extras(
        private_key: Arc<dyn sign::SigningKey>,
        chain: CertificateChain,
        ocsp: Vec<u8>,
    ) -> Self {
        // Finish configuring the key before it is placed behind the `Arc`.
        let mut certified = sign::CertifiedKey::new(chain.0, private_key);
        if !ocsp.is_empty() {
            certified.ocsp = Some(ocsp);
        }
        Self(Arc::new(certified))
    }
}

impl server::ResolvesServerCert for AlwaysResolvesChain {
    /// Returns a new handle to the stored key; the `ClientHello` is ignored.
    fn resolve(&self, _client_hello: ClientHello) -> Option<Arc<sign::CertifiedKey>> {
        let shared = Arc::clone(&self.0);
        Some(shared)
    }
}
#[derive(Debug)]
pub struct ResolvesServerCertUsingSni {
by_name: HashMap<String, Arc<sign::CertifiedKey>>,
}
impl ResolvesServerCertUsingSni {
pub fn new() -> Self {
Self {
by_name: HashMap::new(),
}
}
pub fn add(&mut self, name: &str, ck: sign::CertifiedKey) -> Result<(), Error> {
let server_name = {
let checked_name = DnsName::try_from(name)
.map_err(|_| Error::General("Bad DNS name".into()))
.map(|name| name.to_lowercase_owned())?;
ServerName::DnsName(checked_name)
};
ck.end_entity_cert()
.and_then(ParsedCertificate::try_from)
.and_then(|cert| verify_server_name(&cert, &server_name))?;
if let ServerName::DnsName(name) = server_name {
self.by_name
.insert(name.as_ref().to_string(), Arc::new(ck));
}
Ok(())
}
}
impl server::ResolvesServerCert for ResolvesServerCertUsingSni {
fn resolve(&self, client_hello: ClientHello) -> Option<Arc<sign::CertifiedKey>> {
if let Some(name) = client_hello.server_name() {
self.by_name.get(name).cloned()
} else {
None
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::{ProducesTickets, ResolvesServerCert, StoresServerSessions};
    use std::vec;

    // --- NoServerSessionStorage -------------------------------------------

    #[test]
    fn test_noserversessionstorage_drops_put() {
        let storage = NoServerSessionStorage {};
        let accepted = storage.put(vec![0x01], vec![0x02]);
        assert!(!accepted);
    }

    #[test]
    fn test_noserversessionstorage_denies_gets() {
        let storage = NoServerSessionStorage {};
        // Even after a `put`, nothing can be read back.
        storage.put(vec![0x01], vec![0x02]);
        assert!(storage.get(&[]).is_none());
        assert!(storage.get(&[0x01]).is_none());
        assert!(storage.get(&[0x02]).is_none());
    }

    #[test]
    fn test_noserversessionstorage_denies_takes() {
        let storage = NoServerSessionStorage {};
        assert!(storage.take(&[]).is_none());
        assert!(storage.take(&[0x01]).is_none());
        assert!(storage.take(&[0x02]).is_none());
    }

    // --- ServerSessionMemoryCache -----------------------------------------

    #[test]
    fn test_serversessionmemorycache_accepts_put() {
        let cache = ServerSessionMemoryCache::new(4);
        let accepted = cache.put(vec![0x01], vec![0x02]);
        assert!(accepted);
    }

    #[test]
    fn test_serversessionmemorycache_persists_put() {
        let cache = ServerSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
        // Reading twice shows that `get` does not consume the entry.
        let first = cache.get(&[0x01]);
        let second = cache.get(&[0x01]);
        assert_eq!(first, Some(vec![0x02]));
        assert_eq!(second, Some(vec![0x02]));
    }

    #[test]
    fn test_serversessionmemorycache_overwrites_put() {
        let cache = ServerSessionMemoryCache::new(4);
        assert!(cache.put(vec![0x01], vec![0x02]));
        assert!(cache.put(vec![0x01], vec![0x04]));
        let latest = cache.get(&[0x01]);
        assert_eq!(latest, Some(vec![0x04]));
    }

    #[test]
    fn test_serversessionmemorycache_drops_to_maintain_size_invariant() {
        let cache = ServerSessionMemoryCache::new(2);
        assert!(cache.put(vec![0x01], vec![0x02]));
        assert!(cache.put(vec![0x03], vec![0x04]));
        assert!(cache.put(vec![0x05], vec![0x06]));
        assert!(cache.put(vec![0x07], vec![0x08]));
        assert!(cache.put(vec![0x09], vec![0x0a]));

        // Count how many of the five inserted keys are still retrievable.
        let keys = [0x01u8, 0x03, 0x05, 0x07, 0x09];
        let surviving = keys
            .iter()
            .filter(|&&key| cache.get(&[key]).is_some())
            .count();
        assert!(surviving < 5);
    }

    // --- NeverProducesTickets ---------------------------------------------

    #[test]
    fn test_neverproducestickets_does_nothing() {
        let ticketer = NeverProducesTickets {};
        assert!(!ticketer.enabled());
        assert_eq!(ticketer.lifetime(), 0);
        assert!(ticketer.encrypt(&[]).is_none());
        assert!(ticketer.decrypt(&[]).is_none());
    }

    // --- ResolvesServerCertUsingSni ---------------------------------------

    #[test]
    fn test_resolvesservercertusingsni_requires_sni() {
        let resolver = ResolvesServerCertUsingSni::new();
        let resolved = resolver.resolve(ClientHello::new(&None, &[], None, &[]));
        assert!(resolved.is_none());
    }

    #[test]
    fn test_resolvesservercertusingsni_handles_unknown_name() {
        let resolver = ResolvesServerCertUsingSni::new();
        let name = DnsName::try_from("hello.com")
            .unwrap()
            .to_owned();
        let resolved = resolver.resolve(ClientHello::new(&Some(name), &[], None, &[]));
        assert!(resolved.is_none());
    }
}