#[cfg(feature = "logging")]
use crate::bs_debug;
use crate::check::inappropriate_handshake_message;
use crate::common_state::{CommonState, State};
use crate::conn::ConnectionRandoms;
use crate::crypto::ActiveKeyExchange;
use crate::enums::{AlertDescription, CipherSuite, ContentType, HandshakeType, ProtocolVersion};
use crate::error::{Error, PeerIncompatible, PeerMisbehaved};
use crate::hash_hs::HandshakeHashBuffer;
#[cfg(feature = "logging")]
use crate::log::{debug, trace};
use crate::msgs::base::Payload;
use crate::msgs::enums::{Compression, ExtensionType};
use crate::msgs::enums::{ECPointFormat, PSKKeyExchangeMode};
use crate::msgs::handshake::ConvertProtocolNameList;
use crate::msgs::handshake::{CertificateStatusRequest, ClientSessionTicket};
use crate::msgs::handshake::{ClientExtension, HasServerExtensions};
use crate::msgs::handshake::{ClientHelloPayload, HandshakeMessagePayload, HandshakePayload};
use crate::msgs::handshake::{HelloRetryRequest, KeyShareEntry};
use crate::msgs::handshake::{Random, SessionId};
use crate::msgs::message::{Message, MessagePayload};
use crate::msgs::persist;
use crate::tls13::key_schedule::KeyScheduleEarly;
use crate::SupportedCipherSuite;
#[cfg(feature = "tls12")]
use super::tls12;
use super::Tls12Resumption;
use crate::client::client_conn::ClientConnectionData;
use crate::client::common::ClientHelloDetails;
use crate::client::{tls13, ClientConfig};
use pki_types::{ServerName, UnixTime};
use alloc::borrow::ToOwned;
use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::ops::Deref;
pub(super) type NextState = Box<dyn State<ClientConnectionData>>;
pub(super) type NextStateOrError = Result<NextState, Error>;
pub(super) type ClientContext<'a> = crate::common_state::Context<'a, ClientConnectionData>;
/// Looks up a cached session for `server_name` that can be offered for resumption.
///
/// A TLS 1.3 ticket is tried first (via `take_tls13_ticket`); failing that, and only
/// when the `tls12` feature is enabled, a stored TLS 1.2 session is consulted. The
/// candidate is stamped with the current time and discarded if it has expired.
///
/// For QUIC connections, when a session is found, `cx.common.quic.params` is set to
/// the transport parameters saved with a TLS 1.3 ticket (or `None` for TLS 1.2).
fn find_session(
    server_name: &ServerName<'static>,
    config: &ClientConfig,
    cx: &mut ClientContext<'_>,
) -> Option<persist::Retrieved<ClientSessionValue>> {
    let store = &config.resumption.store;

    #[allow(clippy::unnecessary_lazy_evaluations)]
    let candidate = store
        .take_tls13_ticket(server_name)
        .map(ClientSessionValue::Tls13)
        .or_else(|| {
            #[cfg(feature = "tls12")]
            {
                store
                    .tls12_session(server_name)
                    .map(ClientSessionValue::Tls12)
            }
            #[cfg(not(feature = "tls12"))]
            None
        });

    // Only sessions that are still valid right now are worth offering.
    let live = candidate
        .map(|value| persist::Retrieved::new(value, UnixTime::now()))
        .filter(|retrieved| !retrieved.has_expired());

    match &live {
        None => debug!("No cached session for {:?}", server_name),
        Some(retrieved) if cx.common.is_quic() => {
            cx.common.quic.params = retrieved
                .tls13()
                .map(|v| v.quic_params());
        }
        Some(_) => {}
    }

    live
}
/// Begins a client handshake with `server_name` by sending the first ClientHello.
///
/// Steps:
/// 1. Enables client-auth buffering in the transcript if client certificates exist.
/// 2. Looks up a resumable session with `find_session`.
/// 3. Generates an initial key share when TLS 1.3 is enabled.
/// 4. Chooses the legacy session id to send.
/// 5. Delegates construction and transmission to `emit_client_hello_for_retry`.
///
/// `extra_exts` are included in the ClientHello alongside the standard extensions.
///
/// # Errors
/// Propagates failures from key-share generation and from the configured secure
/// random source.
pub(super) fn start_handshake(
    server_name: ServerName<'static>,
    extra_exts: Vec<ClientExtension>,
    config: Arc<ClientConfig>,
    cx: &mut ClientContext<'_>,
) -> NextStateOrError {
    let mut transcript_buffer = HandshakeHashBuffer::new();
    if config.client_auth_cert_resolver.has_certs() {
        transcript_buffer.set_client_auth_enabled();
    }

    let mut resuming = find_session(&server_name, &config, cx);

    let tls13_enabled = config.supports_version(ProtocolVersion::TLSv1_3);
    let key_share = match tls13_enabled {
        true => Some(tls13::initial_key_share(&config, &server_name)?),
        false => None,
    };

    // A TLS 1.2 session being resumed dictates the session id we send.
    #[cfg_attr(not(feature = "tls12"), allow(unused_mut))]
    let mut resumed_session_id = None;
    match &mut resuming {
        Some(_stored) => {
            #[cfg(feature = "tls12")]
            if let ClientSessionValue::Tls12(inner) = &mut _stored.value {
                // When offering a ticket, a fresh random session id is what lets us
                // recognise an abbreviated handshake (RFC 5077, section 3.4).
                if !inner.ticket().is_empty() {
                    inner.session_id = SessionId::random(config.provider.secure_random)?;
                }
                resumed_session_id = Some(inner.session_id);
            }
            debug!("Resuming session");
        }
        None => debug!("Not resuming any session"),
    }

    // Otherwise a random id enables TLS 1.3 middlebox compatibility mode
    // (RFC 8446, appendix D.4). QUIC prohibits that mode (RFC 9001, section 8.4),
    // and it is irrelevant when TLS 1.3 is disabled, so those cases send an empty id.
    let session_id = if let Some(id) = resumed_session_id {
        id
    } else if cx.common.is_quic() || !tls13_enabled {
        SessionId::empty()
    } else {
        SessionId::random(config.provider.secure_random)?
    };

    let random = Random::new(config.provider.secure_random)?;

    let input = ClientHelloInput {
        config,
        resuming,
        random,
        #[cfg(feature = "tls12")]
        using_ems: false,
        sent_tls13_fake_ccs: false,
        hello: ClientHelloDetails::new(),
        session_id,
        server_name,
    };

    Ok(emit_client_hello_for_retry(
        transcript_buffer,
        None,
        key_share,
        extra_exts,
        None,
        input,
        cx,
    ))
}
/// Handshake state awaiting the server's ServerHello after our ClientHello was sent.
struct ExpectServerHello {
    /// Parameters of the ClientHello we sent, carried forward into the handshake.
    input: ClientHelloInput,
    /// Buffered handshake messages. They are hashed once the server's cipher suite
    /// (and therefore the hash function) is known.
    transcript_buffer: HandshakeHashBuffer,
    /// Early key schedule derived while offering a TLS 1.3 PSK, if one was offered.
    early_key_schedule: Option<KeyScheduleEarly>,
    /// The key share we offered. It is always `Some` when TLS 1.3 is enabled.
    offered_key_share: Option<Box<dyn ActiveKeyExchange>>,
    /// Cipher suite already selected by a HelloRetryRequest. It is `None` for the
    /// first ClientHello.
    suite: Option<SupportedCipherSuite>,
}
/// Handshake state after the first ClientHello when TLS 1.3 was offered.
///
/// The server may answer with either a ServerHello or a HelloRetryRequest.
struct ExpectServerHelloOrHelloRetryRequest {
    /// The state to continue with if a ServerHello arrives.
    next: ExpectServerHello,
    /// Caller-supplied extensions, included again in a retried ClientHello.
    extra_exts: Vec<ClientExtension>,
}
/// Parameters of the ClientHello we sent that later handshake states still need.
///
/// The same value is reused when a second ClientHello is sent after a
/// HelloRetryRequest.
struct ClientHelloInput {
    config: Arc<ClientConfig>,
    /// Cached session offered for resumption, if any.
    resuming: Option<persist::Retrieved<ClientSessionValue>>,
    /// Our ClientHello random, also sent unchanged in a retried ClientHello.
    random: Random,
    /// Passed on to the TLS 1.2 handshake (presumably: extended master secret in use);
    /// starts as `false`.
    #[cfg(feature = "tls12")]
    using_ems: bool,
    /// Whether the dummy ChangeCipherSpec (see `tls13::emit_fake_ccs`) has already been sent.
    sent_tls13_fake_ccs: bool,
    /// Record of what we sent (e.g. extension types), used to validate the server's reply.
    hello: ClientHelloDetails,
    /// The legacy session id we sent. A HelloRetryRequest must echo it.
    session_id: SessionId,
    server_name: ServerName<'static>,
}
/// Builds and sends a ClientHello, then returns the state that awaits the server's reply.
///
/// Used for both the initial ClientHello (`retryreq` is `None`) and the second
/// ClientHello sent in response to a HelloRetryRequest (`retryreq` is `Some`).
///
/// Parameters:
/// - `transcript_buffer`: handshake messages so far; the new ClientHello is appended.
/// - `key_share`: key share to offer, if any (only expected when TLS 1.3 is enabled).
/// - `extra_exts`: additional extensions. They are placed after the standard ones and
///   before anything `prepare_resumption` adds.
/// - `suite`: cipher suite selected by a HelloRetryRequest, if any.
/// - `input`: ClientHello parameters (random, session id, resumption state, ...).
///
/// Returns `ExpectServerHelloOrHelloRetryRequest` for an initial ClientHello with TLS 1.3
/// enabled (a retry may still follow); otherwise returns `ExpectServerHello`.
///
/// # Panics
/// If neither TLS 1.3 nor (outside QUIC) TLS 1.2 is enabled in the configuration.
fn emit_client_hello_for_retry(
    mut transcript_buffer: HandshakeHashBuffer,
    retryreq: Option<&HelloRetryRequest>,
    key_share: Option<Box<dyn ActiveKeyExchange>>,
    extra_exts: Vec<ClientExtension>,
    suite: Option<SupportedCipherSuite>,
    mut input: ClientHelloInput,
    cx: &mut ClientContext<'_>,
) -> NextState {
    let config = &input.config;
    // QUIC is TLS 1.3-only, so TLS 1.2 is never offered on a QUIC connection.
    let support_tls12 = config.supports_version(ProtocolVersion::TLSv1_2) && !cx.common.is_quic();
    let support_tls13 = config.supports_version(ProtocolVersion::TLSv1_3);

    // Listed in order of preference: TLS 1.3 first.
    let mut supported_versions = Vec::new();
    if support_tls13 {
        supported_versions.push(ProtocolVersion::TLSv1_3);
    }
    if support_tls12 {
        supported_versions.push(ProtocolVersion::TLSv1_2);
    }
    // NOTE(review): presumably unreachable if the config builder validated versions.
    assert!(!supported_versions.is_empty());

    let mut exts = vec![
        ClientExtension::SupportedVersions(supported_versions),
        ClientExtension::EcPointFormats(ECPointFormat::SUPPORTED.to_vec()),
        ClientExtension::NamedGroups(
            config
                .provider
                .kx_groups
                .iter()
                .map(|skxg| skxg.name())
                .collect(),
        ),
        ClientExtension::SignatureAlgorithms(
            config
                .verifier
                .supported_verify_schemes(),
        ),
        ClientExtension::ExtendedMasterSecretRequest,
        ClientExtension::CertificateStatusRequest(CertificateStatusRequest::build_ocsp()),
    ];

    // SNI is only sent when the server name is a DNS name and SNI is enabled.
    if let (ServerName::DnsName(dns), true) = (&input.server_name, config.enable_sni) {
        exts.push(ClientExtension::make_sni(dns));
    }

    if let Some(key_share) = &key_share {
        debug_assert!(support_tls13);
        let key_share = KeyShareEntry::new(key_share.group(), key_share.pub_key());
        exts.push(ClientExtension::KeyShare(vec![key_share]));
    }

    // A HelloRetryRequest cookie must be echoed in the retried ClientHello.
    if let Some(cookie) = retryreq.and_then(HelloRetryRequest::get_cookie) {
        exts.push(ClientExtension::Cookie(cookie.clone()));
    }

    if support_tls13 {
        // Only PSK-with-(EC)DHE is offered; plain PSK_KE would lack forward secrecy.
        let psk_modes = vec![PSKKeyExchangeMode::PSK_DHE_KE];
        exts.push(ClientExtension::PresharedKeyModes(psk_modes));
    }

    if !config.alpn_protocols.is_empty() {
        exts.push(ClientExtension::Protocols(Vec::from_slices(
            &config
                .alpn_protocols
                .iter()
                .map(|proto| &proto[..])
                .collect::<Vec<_>>(),
        )));
    }

    // Extra extensions go before the resumption extensions: the pre_shared_key
    // extension must be last in the ClientHello (RFC 8446, section 4.2.11), and it is
    // presumably appended by `prepare_resumption`.
    exts.extend(extra_exts.iter().cloned());
    let tls13_session = prepare_resumption(&input.resuming, &mut exts, suite, cx, config);

    // Remember what we sent so unsolicited server extensions can be detected.
    input.hello.sent_extensions = exts
        .iter()
        .map(ClientExtension::get_type)
        .collect();

    let mut cipher_suites: Vec<_> = config
        .provider
        .cipher_suites
        .iter()
        .filter_map(|cs| match cs.usable_for_protocol(cx.common.protocol) {
            true => Some(cs.suite()),
            false => None,
        })
        .collect();
    // Renegotiation is not supported; this SCSV signals that (RFC 5746).
    cipher_suites.push(CipherSuite::TLS_EMPTY_RENEGOTIATION_INFO_SCSV);

    let mut chp = HandshakeMessagePayload {
        typ: HandshakeType::ClientHello,
        payload: HandshakePayload::ClientHello(ClientHelloPayload {
            client_version: ProtocolVersion::TLSv1_2,
            random: input.random,
            session_id: input.session_id,
            cipher_suites,
            compression_methods: vec![Compression::Null],
            extensions: exts,
        }),
    };

    // When resuming TLS 1.3, the PSK binder is computed over the transcript so far plus
    // this (partial) ClientHello; it also yields the early key schedule.
    let early_key_schedule = if let Some(resuming) = tls13_session {
        let schedule = tls13::fill_in_psk_binder(&resuming, &transcript_buffer, &mut chp);
        Some((resuming.suite(), schedule))
    } else {
        None
    };

    let ch = Message {
        // RFC 8446, section 5.1: the record version is 0x0303 for everything except an
        // initial ClientHello, which may use 0x0301 for compatibility.
        version: if retryreq.is_some() {
            ProtocolVersion::TLSv1_2
        } else {
            ProtocolVersion::TLSv1_0
        },
        payload: MessagePayload::handshake(chp),
    };

    // Before a second ClientHello, send a dummy ChangeCipherSpec (middlebox
    // compatibility, RFC 8446 appendix D.4).
    if retryreq.is_some() {
        tls13::emit_fake_ccs(&mut input.sent_tls13_fake_ccs, cx.common);
    }

    trace!("Sending ClientHello {:#?}", ch);

    transcript_buffer.add_message(&ch);
    cx.common.send_msg(ch, false);

    // Now that the ClientHello is in the transcript, derive the early traffic secret
    // if early data is enabled.
    let early_key_schedule = early_key_schedule.map(|(resuming_suite, schedule)| {
        if !cx.data.early_data.is_enabled() {
            return schedule;
        }
        tls13::derive_early_traffic_secret(
            &*config.key_log,
            cx,
            resuming_suite,
            &schedule,
            &mut input.sent_tls13_fake_ccs,
            &transcript_buffer,
            &input.random.0,
        );
        schedule
    });

    let next = ExpectServerHello {
        input,
        transcript_buffer,
        early_key_schedule,
        offered_key_share: key_share,
        suite,
    };

    // Only one HelloRetryRequest is permitted, and only for TLS 1.3.
    if support_tls13 && retryreq.is_none() {
        Box::new(ExpectServerHelloOrHelloRetryRequest { next, extra_exts })
    } else {
        Box::new(next)
    }
}
/// Adds session-ticket / resumption extensions to `exts` for the cached session `resuming`.
///
/// Behaviour by case:
/// - **No session, or one with an empty ticket:** requests a fresh ticket when TLS 1.3 is
///   enabled or TLS 1.2 ticket resumption is configured. Returns `None`.
/// - **TLS 1.2 session with a ticket:** offers that ticket when TLS 1.2 is enabled and ticket
///   resumption is configured. Returns `None`.
/// - **TLS 1.3 ticket with TLS 1.3 enabled:** delegates to `tls13::prepare_resumption` and
///   returns the session so the caller can fill in the PSK binder.
///   - Exception: if `suite` (a suite already chosen by a HelloRetryRequest) is a TLS 1.2
///     suite, or cannot resume from the ticket's suite, returns `None` instead.
fn prepare_resumption<'a>(
    resuming: &'a Option<persist::Retrieved<ClientSessionValue>>,
    exts: &mut Vec<ClientExtension>,
    suite: Option<SupportedCipherSuite>,
    cx: &mut ClientContext<'_>,
    config: &ClientConfig,
) -> Option<persist::Retrieved<&'a persist::Tls13ClientSessionValue>> {
    let tls13_enabled = config.supports_version(ProtocolVersion::TLSv1_3);
    let tls12_tickets_enabled =
        config.resumption.tls12_resumption == Tls12Resumption::SessionIdOrTickets;

    let stored = match resuming {
        Some(stored) if !stored.ticket().is_empty() => stored,
        _ => {
            // Nothing to offer: ask the server for a ticket instead.
            if tls13_enabled || tls12_tickets_enabled {
                exts.push(ClientExtension::SessionTicket(ClientSessionTicket::Request));
            }
            return None;
        }
    };

    let tls13 = match stored.map(|value| value.tls13()) {
        Some(tls13) => tls13,
        None => {
            // A TLS 1.2 ticket: offer it if TLS 1.2 ticket resumption is possible.
            if config.supports_version(ProtocolVersion::TLSv1_2) && tls12_tickets_enabled {
                exts.push(ClientExtension::SessionTicket(ClientSessionTicket::Offer(
                    Payload::new(stored.ticket()),
                )));
            }
            return None;
        }
    };

    if !tls13_enabled {
        return None;
    }

    // After a HelloRetryRequest that selected a TLS 1.2 suite, resumption is impossible.
    let retry_suite = match suite {
        None => None,
        Some(SupportedCipherSuite::Tls13(chosen)) => Some(chosen),
        #[cfg(feature = "tls12")]
        Some(SupportedCipherSuite::Tls12(_)) => return None,
    };

    // The already-chosen suite must be able to resume from the ticket's suite.
    if let Some(chosen) = retry_suite {
        chosen.can_resume_from(tls13.suite())?;
    }

    tls13::prepare_resumption(config, cx, &tls13, exts, retry_suite.is_some());
    Some(tls13)
}
/// Records the application protocol selected by the server via ALPN and validates it.
///
/// The selection (or `None`) is stored in `common.alpn_protocol`.
///
/// # Errors
/// - Sends a fatal `IllegalParameter` alert if the server picked a protocol that is not
///   in `config.alpn_protocols`.
/// - Over QUIC, if any protocols were configured but the server selected none, sends a
///   fatal `NoApplicationProtocol` alert (RFC 9001, section 8.1).
pub(super) fn process_alpn_protocol(
    common: &mut CommonState,
    config: &ClientConfig,
    proto: Option<&[u8]>,
) -> Result<(), Error> {
    common.alpn_protocol = proto.map(ToOwned::to_owned);

    let selection_was_offered = match &common.alpn_protocol {
        None => true,
        Some(selected) => config.alpn_protocols.contains(selected),
    };
    if !selection_was_offered {
        return Err(common.send_fatal_alert(
            AlertDescription::IllegalParameter,
            PeerMisbehaved::SelectedUnofferedApplicationProtocol,
        ));
    }

    // Over QUIC, configuring any ALPN protocols is taken to mean one must be negotiated.
    let alpn_required = common.is_quic() && !config.alpn_protocols.is_empty();
    if alpn_required && common.alpn_protocol.is_none() {
        return Err(common.send_fatal_alert(
            AlertDescription::NoApplicationProtocol,
            Error::NoApplicationProtocol,
        ));
    }

    debug!(
        "ALPN protocol is {:?}",
        common
            .alpn_protocol
            .as_ref()
            .map(|v| bs_debug::BsDebug(v))
    );
    Ok(())
}
/// Processing of the server's ServerHello.
///
/// Validates, in order:
/// - the negotiated protocol version;
/// - the compression method;
/// - the extensions;
/// - the cipher suite.
///
/// Each violation sends a fatal alert. After validation, starts the handshake hash
/// (using the chosen suite's hash) over the buffered messages plus the ServerHello, and
/// hands off to the TLS 1.3 or TLS 1.2 specific handling.
impl State<ClientConnectionData> for ExpectServerHello {
    fn handle(mut self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
        let server_hello =
            require_handshake_msg!(m, HandshakeType::ServerHello, HandshakePayload::ServerHello)?;
        trace!("We got ServerHello {:#?}", server_hello);

        use crate::ProtocolVersion::{TLSv1_2, TLSv1_3};
        let config = &self.input.config;
        let tls13_supported = config.supports_version(TLSv1_3);

        // When legacy_version is TLS 1.2, the real version may be carried in the
        // supported_versions extension (that is how TLS 1.3 is negotiated).
        let server_version = if server_hello.legacy_version == TLSv1_2 {
            server_hello
                .get_supported_versions()
                .unwrap_or(server_hello.legacy_version)
        } else {
            server_hello.legacy_version
        };

        let version = match server_version {
            TLSv1_3 if tls13_supported => TLSv1_3,
            TLSv1_2 if config.supports_version(TLSv1_2) => {
                // Having sent early data, a fall back to TLS 1.2 gets a dedicated error.
                if cx.data.early_data.is_enabled() && cx.common.early_traffic {
                    return Err(PeerMisbehaved::OfferedEarlyDataWithOldProtocolVersion.into());
                }

                // TLS 1.2 must not be selected via the TLS 1.3 supported_versions extension.
                if server_hello
                    .get_supported_versions()
                    .is_some()
                {
                    return Err({
                        cx.common.send_fatal_alert(
                            AlertDescription::IllegalParameter,
                            PeerMisbehaved::SelectedTls12UsingTls13VersionExtension,
                        )
                    });
                }

                TLSv1_2
            }
            _ => {
                // Distinguish "we disabled it" from "server only speaks something else".
                let reason = match server_version {
                    TLSv1_2 | TLSv1_3 => PeerIncompatible::ServerTlsVersionIsDisabledByOurConfig,
                    _ => PeerIncompatible::ServerDoesNotSupportTls12Or13,
                };
                return Err(cx
                    .common
                    .send_fatal_alert(AlertDescription::ProtocolVersion, reason));
            }
        };

        // We only ever offer null compression.
        if server_hello.compression_method != Compression::Null {
            return Err({
                cx.common.send_fatal_alert(
                    AlertDescription::IllegalParameter,
                    PeerMisbehaved::SelectedUnofferedCompression,
                )
            });
        }

        if server_hello.has_duplicate_extension() {
            return Err(cx.common.send_fatal_alert(
                AlertDescription::DecodeError,
                PeerMisbehaved::DuplicateServerHelloExtensions,
            ));
        }

        // The server may only reply with extensions we sent, except RenegotiationInfo.
        let allowed_unsolicited = [ExtensionType::RenegotiationInfo];
        if self
            .input
            .hello
            .server_sent_unsolicited_extensions(&server_hello.extensions, &allowed_unsolicited)
        {
            return Err(cx.common.send_fatal_alert(
                AlertDescription::UnsupportedExtension,
                PeerMisbehaved::UnsolicitedServerHelloExtension,
            ));
        }

        cx.common.negotiated_version = Some(version);

        // For TLS 1.2 the ALPN choice is in ServerHello; TLS 1.3 is handled elsewhere
        // (presumably with EncryptedExtensions).
        if !cx.common.is_tls13() {
            process_alpn_protocol(cx.common, config, server_hello.get_alpn_protocol())?;
        }

        // An ECPointFormats extension may be omitted, but if present it must include
        // Uncompressed.
        if let Some(point_fmts) = server_hello.get_ecpoints_extension() {
            if !point_fmts.contains(&ECPointFormat::Uncompressed) {
                return Err(cx.common.send_fatal_alert(
                    AlertDescription::HandshakeFailure,
                    PeerMisbehaved::ServerHelloMustOfferUncompressedEcPoints,
                ));
            }
        }

        let suite = config
            .find_cipher_suite(server_hello.cipher_suite)
            .ok_or_else(|| {
                cx.common.send_fatal_alert(
                    AlertDescription::HandshakeFailure,
                    PeerMisbehaved::SelectedUnofferedCipherSuite,
                )
            })?;

        // The suite must belong to the negotiated protocol version.
        if version != suite.version().version {
            return Err({
                cx.common.send_fatal_alert(
                    AlertDescription::IllegalParameter,
                    PeerMisbehaved::SelectedUnusableCipherSuiteForVersion,
                )
            });
        }

        // After a HelloRetryRequest, the ServerHello must keep the suite it chose.
        match self.suite {
            Some(prev_suite) if prev_suite != suite => {
                return Err({
                    cx.common.send_fatal_alert(
                        AlertDescription::IllegalParameter,
                        PeerMisbehaved::SelectedDifferentCipherSuiteAfterRetry,
                    )
                });
            }
            _ => {
                debug!("Using ciphersuite {:?}", suite);
                self.suite = Some(suite);
                cx.common.suite = Some(suite);
            }
        }

        // Start the handshake hash with the suite's hash, then add the ServerHello.
        let mut transcript = self
            .transcript_buffer
            .start_hash(suite.hash_provider());
        transcript.add_message(&m);

        let randoms = ConnectionRandoms::new(self.input.random, server_hello.random);

        match suite {
            SupportedCipherSuite::Tls13(suite) => {
                // Only a TLS 1.3 cached session can be used with a TLS 1.3 suite.
                #[allow(clippy::bind_instead_of_map)]
                let resuming_session = self
                    .input
                    .resuming
                    .and_then(|resuming| match resuming.value {
                        ClientSessionValue::Tls13(inner) => Some(inner),
                        #[cfg(feature = "tls12")]
                        ClientSessionValue::Tls12(_) => None,
                    });

                tls13::handle_server_hello(
                    self.input.config,
                    cx,
                    server_hello,
                    resuming_session,
                    self.input.server_name,
                    randoms,
                    suite,
                    transcript,
                    self.early_key_schedule,
                    self.input.hello,
                    // A key share is always offered when TLS 1.3 is enabled.
                    self.offered_key_share.unwrap(),
                    self.input.sent_tls13_fake_ccs,
                )
            }
            #[cfg(feature = "tls12")]
            SupportedCipherSuite::Tls12(suite) => {
                // Only a TLS 1.2 cached session can be used with a TLS 1.2 suite.
                let resuming_session = self
                    .input
                    .resuming
                    .and_then(|resuming| match resuming.value {
                        ClientSessionValue::Tls12(inner) => Some(inner),
                        ClientSessionValue::Tls13(_) => None,
                    });

                tls12::CompleteServerHelloHandling {
                    config: self.input.config,
                    resuming_session,
                    server_name: self.input.server_name,
                    randoms,
                    using_ems: self.input.using_ems,
                    transcript,
                }
                .handle_server_hello(cx, suite, server_hello, tls13_supported)
            }
        }
    }
}
impl ExpectServerHelloOrHelloRetryRequest {
    /// Gives up on a possible HelloRetryRequest and continues as plain `ExpectServerHello`.
    fn into_expect_server_hello(self) -> NextState {
        Box::new(self.next)
    }

    /// Handles a HelloRetryRequest (HRR) from the server.
    ///
    /// The request is rejected with a fatal alert if it:
    /// - repeats our group without a cookie;
    /// - carries an empty cookie;
    /// - has unknown or duplicated extensions;
    /// - changes nothing;
    /// - fails to echo our legacy session id;
    /// - does not select TLS 1.3;
    /// - picks a cipher suite or group we do not support.
    ///
    /// On success:
    /// - the transcript is restarted with the selected suite's hash;
    /// - early data is marked rejected;
    /// - a second ClientHello is sent. It uses a fresh key share if the server asked for
    ///   a different group, otherwise the key share already offered.
    ///
    /// # Errors
    /// Returns the alert-derived error for any violation above. Also propagates the
    /// crypto provider's own error if starting the new key exchange fails.
    ///
    /// # Panics
    /// If no key share was offered. This state is only created when TLS 1.3 is
    /// enabled, in which case a key share is always generated.
    fn handle_hello_retry_request(
        self,
        cx: &mut ClientContext<'_>,
        m: Message,
    ) -> NextStateOrError {
        let hrr = require_handshake_msg!(
            m,
            HandshakeType::HelloRetryRequest,
            HandshakePayload::HelloRetryRequest
        )?;
        trace!("Got HRR {:?}", hrr);

        cx.common.check_aligned_handshake()?;

        let cookie = hrr.get_cookie();
        let req_group = hrr.get_requested_key_share_group();

        // A key share is always sent when TLS 1.3 is enabled, which is the only way
        // this state is reached.
        let offered_key_share = self.next.offered_key_share.unwrap();

        // Illegal: no cookie, and it asks us to retry with the group we already sent.
        if cookie.is_none() && req_group == Some(offered_key_share.group()) {
            return Err(cx.common.send_fatal_alert(
                AlertDescription::IllegalParameter,
                PeerMisbehaved::IllegalHelloRetryRequestWithOfferedGroup,
            ));
        }

        // Illegal: an empty cookie.
        if let Some(cookie) = cookie {
            if cookie.0.is_empty() {
                return Err(cx.common.send_fatal_alert(
                    AlertDescription::IllegalParameter,
                    PeerMisbehaved::IllegalHelloRetryRequestWithEmptyCookie,
                ));
            }
        }

        // Unsupported: an extension we do not recognise.
        if hrr.has_unknown_extension() {
            return Err(cx.common.send_fatal_alert(
                AlertDescription::UnsupportedExtension,
                PeerIncompatible::ServerSentHelloRetryRequestWithUnknownExtension,
            ));
        }

        // Illegal: the same extension more than once.
        if hrr.has_duplicate_extension() {
            return Err(cx.common.send_fatal_alert(
                AlertDescription::IllegalParameter,
                PeerMisbehaved::DuplicateHelloRetryRequestExtensions,
            ));
        }

        // Illegal: it asks us to change nothing.
        if cookie.is_none() && req_group.is_none() {
            return Err(cx.common.send_fatal_alert(
                AlertDescription::IllegalParameter,
                PeerMisbehaved::IllegalHelloRetryRequestWithNoChanges,
            ));
        }

        // Illegal: legacy_session_id_echo must match what we sent
        // (RFC 8446, sections 4.1.3 and 4.1.4).
        if hrr.session_id != self.next.input.session_id {
            return Err(cx.common.send_fatal_alert(
                AlertDescription::IllegalParameter,
                PeerMisbehaved::IllegalHelloRetryRequestWithWrongSessionId,
            ));
        }

        // HelloRetryRequest only exists in TLS 1.3; any other version is illegal.
        match hrr.get_supported_versions() {
            Some(ProtocolVersion::TLSv1_3) => {
                cx.common.negotiated_version = Some(ProtocolVersion::TLSv1_3);
            }
            _ => {
                return Err(cx.common.send_fatal_alert(
                    AlertDescription::IllegalParameter,
                    PeerMisbehaved::IllegalHelloRetryRequestWithUnsupportedVersion,
                ));
            }
        }

        // The HRR selects the cipher suite; it must be one we know.
        let config = &self.next.input.config;
        let cs = config
            .find_cipher_suite(hrr.cipher_suite)
            .ok_or_else(|| {
                cx.common.send_fatal_alert(
                    AlertDescription::IllegalParameter,
                    PeerMisbehaved::IllegalHelloRetryRequestWithUnofferedCipherSuite,
                )
            })?;
        cx.common.suite = Some(cs);

        // Restart the transcript under the selected suite's hash, convert it into the
        // form used after an HRR, then add the HRR itself.
        let transcript = self
            .next
            .transcript_buffer
            .start_hash(cs.hash_provider());
        let mut transcript_buffer = transcript.into_hrr_buffer();
        transcript_buffer.add_message(&m);

        // Early data is not allowed after a HelloRetryRequest.
        if cx.data.early_data.is_enabled() {
            cx.data.early_data.rejected();
        }

        // Only generate a new key share if the server asked for a different group.
        let key_share = match req_group {
            Some(group) if group != offered_key_share.group() => {
                let skxg = config.find_kx_group(group).ok_or_else(|| {
                    cx.common.send_fatal_alert(
                        AlertDescription::IllegalParameter,
                        PeerMisbehaved::IllegalHelloRetryRequestWithUnofferedNamedGroup,
                    )
                })?;
                // Propagate the provider's error unchanged (as `initial_key_share` does)
                // rather than relabelling every failure as an RNG failure.
                skxg.start()?
            }
            _ => offered_key_share,
        };

        Ok(emit_client_hello_for_retry(
            transcript_buffer,
            Some(hrr),
            Some(key_share),
            self.extra_exts,
            Some(cs),
            self.next.input,
            cx,
        ))
    }
}
/// Dispatches the server's first reply.
///
/// - A ServerHello continues as `ExpectServerHello`.
/// - A HelloRetryRequest triggers a second ClientHello.
/// - Anything else is an inappropriate-message error.
impl State<ClientConnectionData> for ExpectServerHelloOrHelloRetryRequest {
    fn handle(self: Box<Self>, cx: &mut ClientContext<'_>, m: Message) -> NextStateOrError {
        // Classify the message before handing ownership of `m` to a handler.
        let (is_server_hello, is_retry_request) = match &m.payload {
            MessagePayload::Handshake { parsed, .. } => (
                matches!(parsed.payload, HandshakePayload::ServerHello(..)),
                matches!(parsed.payload, HandshakePayload::HelloRetryRequest(..)),
            ),
            _ => (false, false),
        };

        if is_server_hello {
            self.into_expect_server_hello()
                .handle(cx, m)
        } else if is_retry_request {
            self.handle_hello_retry_request(cx, m)
        } else {
            Err(inappropriate_handshake_message(
                &m.payload,
                &[ContentType::Handshake],
                &[HandshakeType::ServerHello, HandshakeType::HelloRetryRequest],
            ))
        }
    }
}
/// A cached client session retrieved for resumption, for either protocol version.
enum ClientSessionValue {
    /// A TLS 1.3 session (ticket).
    Tls13(persist::Tls13ClientSessionValue),
    /// A TLS 1.2 session (session id and/or ticket).
    #[cfg(feature = "tls12")]
    Tls12(persist::Tls12ClientSessionValue),
}
impl ClientSessionValue {
    /// Returns the data shared by sessions of both protocol versions.
    fn common(&self) -> &persist::ClientSessionCommon {
        match self {
            #[cfg(feature = "tls12")]
            Self::Tls12(session) => &session.common,
            Self::Tls13(session) => &session.common,
        }
    }

    /// Returns the TLS 1.3 session, or `None` if this is a TLS 1.2 session.
    fn tls13(&self) -> Option<&persist::Tls13ClientSessionValue> {
        match self {
            #[cfg(feature = "tls12")]
            Self::Tls12(_) => None,
            Self::Tls13(session) => Some(session),
        }
    }
}
/// Exposes the fields and methods of `persist::ClientSessionCommon` directly on a
/// `ClientSessionValue`, whatever its protocol version.
impl Deref for ClientSessionValue {
    type Target = persist::ClientSessionCommon;

    fn deref(&self) -> &Self::Target {
        Self::common(self)
    }
}