use core::task::{Context, Poll};
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use futures_util::ready;
use hyper::server::accept::Accept;
use hyper::server::conn::{AddrIncoming, AddrStream};
use rustls::{ServerConfig, ServerConnection};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
mod builder;
pub use builder::AcceptorBuilder;
use builder::WantsTlsConfig;
/// Accepts connections from an inner acceptor `A` and wraps each one in a
/// server-side [`TlsStream`] that uses a shared [`ServerConfig`].
pub struct TlsAcceptor<A = AddrIncoming> {
    /// TLS configuration; every accepted connection receives its own `Arc` clone.
    config: Arc<ServerConfig>,
    /// Source of the raw (not yet encrypted) connections.
    acceptor: A,
}
impl TlsAcceptor {
    /// Returns a fresh [`AcceptorBuilder`] in its initial `WantsTlsConfig` state.
    pub fn builder() -> AcceptorBuilder<WantsTlsConfig> {
        AcceptorBuilder::new()
    }

    /// Creates an acceptor that wraps every connection produced by `incoming`
    /// in TLS, using `config`.
    pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> Self {
        Self { config, acceptor: incoming }
    }
}
impl<A> Accept for TlsAcceptor<A>
where
A: Accept<Error = io::Error> + Unpin,
A::Conn: AsyncRead + AsyncWrite + Unpin,
{
type Conn = TlsStream<A::Conn>;
type Error = io::Error;
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
let pin = self.get_mut();
Poll::Ready(match ready!(Pin::new(&mut pin.acceptor).poll_accept(cx)) {
Some(Ok(sock)) => Some(Ok(TlsStream::new(sock, pin.config.clone()))),
Some(Err(e)) => Some(Err(e)),
None => None,
})
}
}
/// Builds a [`TlsAcceptor`] from a `(config, incoming)` pair, converting each
/// half with `Into`.
impl<C, I> From<(C, I)> for TlsAcceptor
where
    C: Into<Arc<ServerConfig>>,
    I: Into<AddrIncoming>,
{
    fn from(parts: (C, I)) -> Self {
        let (config, incoming) = parts;
        // Convert the config first, then the incoming listener (same order as before).
        let config: Arc<ServerConfig> = config.into();
        let incoming: AddrIncoming = incoming.into();
        Self::new(config, incoming)
    }
}
/// A server-side TLS stream over transport `C`. It starts in the handshake
/// phase and switches to encrypted streaming once the handshake completes.
pub struct TlsStream<C = AddrStream> {
    /// Current phase of the connection (see `State`).
    state: State<C>,
}
impl<C: AsyncRead + AsyncWrite + Unpin> TlsStream<C> {
    /// Wraps `stream` in a pending server-side TLS handshake that uses `config`.
    ///
    /// The handshake future is only stored here. It is driven by the first
    /// `poll_read`/`poll_write`.
    fn new(stream: C, config: Arc<ServerConfig>) -> Self {
        let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
        Self {
            state: State::Handshaking(accept),
        }
    }

    /// Returns the underlying transport, if it is still reachable.
    ///
    /// Returns `None` after a failed handshake. During the handshake it also
    /// returns `None` whenever the pending handshake future does not expose
    /// the transport.
    pub fn io(&self) -> Option<&C> {
        match &self.state {
            State::Handshaking(accept) => accept.get_ref(),
            State::Streaming(stream) => Some(stream.get_ref().0),
            State::Failed => None,
        }
    }

    /// Returns the rustls connection once the handshake has completed
    /// successfully, and `None` otherwise.
    pub fn connection(&self) -> Option<&ServerConnection> {
        match &self.state {
            State::Handshaking(_) | State::Failed => None,
            State::Streaming(stream) => Some(stream.get_ref().1),
        }
    }

    /// Drives the handshake to completion if it is still in progress, then
    /// returns the established TLS stream.
    ///
    /// # Errors
    ///
    /// Returns the handshake error the first time the handshake fails. Every
    /// later call returns a `NotConnected` error. The finished handshake future
    /// is never polled again, because a completed `Future` must not be re-polled.
    fn poll_established(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<&mut tokio_rustls::server::TlsStream<C>>> {
        if let State::Handshaking(accept) = &mut self.state {
            match ready!(Pin::new(accept).poll(cx)) {
                Ok(stream) => self.state = State::Streaming(stream),
                Err(err) => {
                    // Leave the Handshaking state so the completed future is dropped.
                    self.state = State::Failed;
                    return Poll::Ready(Err(err));
                }
            }
        }
        match &mut self.state {
            State::Streaming(stream) => Poll::Ready(Ok(stream)),
            State::Failed => Poll::Ready(Err(io::Error::new(
                io::ErrorKind::NotConnected,
                "TLS handshake failed",
            ))),
            // The block above either moved us to Streaming or returned early.
            State::Handshaking(_) => unreachable!("handshake completion always leaves Handshaking"),
        }
    }
}

impl<C: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<C> {
    /// Reads decrypted application data. If the handshake has not finished
    /// yet, it is completed first.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        match ready!(self.get_mut().poll_established(cx)) {
            Ok(stream) => Pin::new(stream).poll_read(cx, buf),
            Err(err) => Poll::Ready(Err(err)),
        }
    }
}

impl<C: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<C> {
    /// Writes application data to be encrypted. If the handshake has not
    /// finished yet, it is completed first.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        match ready!(self.get_mut().poll_established(cx)) {
            Ok(stream) => Pin::new(stream).poll_write(cx, buf),
            Err(err) => Poll::Ready(Err(err)),
        }
    }

    /// Flushes the TLS stream. This is a no-op while the handshake is pending
    /// or after it has failed.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut self.state {
            State::Handshaking(_) | State::Failed => Poll::Ready(Ok(())),
            State::Streaming(stream) => Pin::new(stream).poll_flush(cx),
        }
    }

    /// Shuts down the TLS stream. This is a no-op while the handshake is
    /// pending or after it has failed (the `Failed` state holds no transport).
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        match &mut self.state {
            State::Handshaking(_) | State::Failed => Poll::Ready(Ok(())),
            State::Streaming(stream) => Pin::new(stream).poll_shutdown(cx),
        }
    }
}

/// Phases of a [`TlsStream`].
enum State<C> {
    /// The server-side TLS handshake is still in progress.
    Handshaking(tokio_rustls::Accept<C>),
    /// The handshake succeeded; application data flows through this stream.
    Streaming(tokio_rustls::server::TlsStream<C>),
    /// The handshake failed. Its future has already completed and must not be
    /// polled again, so nothing is kept here.
    Failed,
}