use crate::base::iana::{Opcode, Rcode};
use crate::base::message::{CopyRecordsError, ShortMessage};
use crate::base::message_builder::{
AdditionalBuilder, MessageBuilder, PushError,
};
use crate::base::opt::{ComposeOptData, LongOptData, OptRecord};
use crate::base::wire::{Composer, ParseError};
use crate::base::{Header, Message, ParsedName, Rtype, StaticCompressor};
use crate::rdata::AllRecordData;
use bytes::Bytes;
use octseq::Octets;
use std::boxed::Box;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::vec::Vec;
use std::{error, fmt};
use tracing::trace;
#[cfg(feature = "tsig")]
use crate::tsig;
/// A request that can be turned into wire format and adjusted before it is
/// handed to a transport that expects a single response.
pub trait ComposeRequest: Debug + Send + Sync {
    /// Writes the request into `target`, returning the builder positioned
    /// at the additional section.
    fn append_message<Target>(
        &self,
        target: Target,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError>
    where
        Target: Composer;

    /// Builds the complete request as a message backed by a `Vec<u8>`.
    fn to_message(&self) -> Result<Message<Vec<u8>>, Error>;

    /// Builds the complete request and returns its raw octets.
    fn to_vec(&self) -> Result<Vec<u8>, Error>;

    /// Returns the header that will be used when composing the request.
    fn header(&self) -> &Header;

    /// Returns mutable access to the header used when composing the request.
    fn header_mut(&mut self) -> &mut Header;

    /// Sets the UDP payload size carried by the request's OPT record.
    fn set_udp_payload_size(&mut self, value: u16);

    /// Sets or clears the DNSSEC OK flag carried by the request's OPT record.
    fn set_dnssec_ok(&mut self, value: bool);

    /// Appends an option to the request's OPT record.
    fn add_opt(
        &mut self,
        opt: &impl ComposeOptData,
    ) -> Result<(), LongOptData>;

    /// Returns whether `answer` is a reply to this request.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool;

    /// Returns whether the DNSSEC OK flag is set on this request.
    fn dnssec_ok(&self) -> bool;
}
/// Like [`ComposeRequest`], but for requests sent over transports that may
/// deliver several response messages for one request.
pub trait ComposeRequestMulti: Debug + Send + Sync {
    /// Writes the request into `target`, returning the builder positioned
    /// at the additional section.
    fn append_message<Target>(
        &self,
        target: Target,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError>
    where
        Target: Composer;

    /// Builds the complete request as a message backed by a `Vec<u8>`.
    fn to_message(&self) -> Result<Message<Vec<u8>>, Error>;

    /// Returns the header that will be used when composing the request.
    fn header(&self) -> &Header;

    /// Returns mutable access to the header used when composing the request.
    fn header_mut(&mut self) -> &mut Header;

    /// Sets the UDP payload size carried by the request's OPT record.
    fn set_udp_payload_size(&mut self, value: u16);

    /// Sets or clears the DNSSEC OK flag carried by the request's OPT record.
    fn set_dnssec_ok(&mut self, value: bool);

    /// Appends an option to the request's OPT record.
    fn add_opt(
        &mut self,
        opt: &impl ComposeOptData,
    ) -> Result<(), LongOptData>;

    /// Returns whether `answer` is a reply to this request.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool;

    /// Returns whether the DNSSEC OK flag is set on this request.
    fn dnssec_ok(&self) -> bool;
}
/// Submits requests of type `CR` to a transport that yields one response
/// per request.
pub trait SendRequest<CR> {
    /// Submits `request_msg` and returns a [`GetResponse`] handle through
    /// which the reply can be retrieved.
    fn send_request(&self, request_msg: CR)
        -> Box<dyn GetResponse + Send + Sync>;
}
/// Lets a boxed sender (including an unsized one such as a trait object)
/// be used wherever a [`SendRequest`] is expected, by forwarding to the
/// boxed value.
impl<T, Octs> SendRequest<RequestMessage<Octs>> for Box<T>
where
    T: SendRequest<RequestMessage<Octs>> + ?Sized,
    Octs: Octets,
{
    fn send_request(
        &self,
        request_msg: RequestMessage<Octs>,
    ) -> Box<dyn GetResponse + Send + Sync> {
        // Naming `T` explicitly selects the inner implementation; deref
        // coercion turns `&Box<T>` into `&T` for the call.
        <T as SendRequest<RequestMessage<Octs>>>::send_request(
            self,
            request_msg,
        )
    }
}
/// Submits requests of type `CR` to a transport that may yield several
/// responses per request.
pub trait SendRequestMulti<CR> {
    /// Submits `request_msg` and returns a [`GetResponseMulti`] handle
    /// through which the replies can be retrieved.
    fn send_request(&self, request_msg: CR)
        -> Box<dyn GetResponseMulti + Send + Sync>;
}
/// Lets a boxed multi-response sender (including an unsized one such as a
/// trait object) be used wherever a [`SendRequestMulti`] is expected, by
/// forwarding to the boxed value.
impl<T, Octs> SendRequestMulti<RequestMessage<Octs>> for Box<T>
where
    T: SendRequestMulti<RequestMessage<Octs>> + ?Sized,
    Octs: Octets,
{
    fn send_request(
        &self,
        request_msg: RequestMessage<Octs>,
    ) -> Box<dyn GetResponseMulti + Send + Sync> {
        // Naming `T` explicitly selects the inner implementation; deref
        // coercion turns `&Box<T>` into `&T` for the call.
        <T as SendRequestMulti<RequestMessage<Octs>>>::send_request(
            self,
            request_msg,
        )
    }
}
/// A handle for retrieving the single response to a submitted request.
pub trait GetResponse: Debug {
    /// Returns a boxed future that resolves to the response message or to
    /// an [`Error`]. The future borrows `self` for its lifetime.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Message<Bytes>, Error>>
                + Send
                + Sync
                + '_,
        >,
    >;
}
/// A handle for retrieving the possibly multiple responses to a submitted
/// request.
// The boxed-future return type trips clippy's `type_complexity` lint.
// Spelling it out keeps the trait object-safe without an extra alias.
#[allow(clippy::type_complexity)]
pub trait GetResponseMulti: Debug {
    /// Returns a boxed future that resolves to the next response, to
    /// `Ok(None)`, or to an [`Error`]. The future borrows `self` for its
    /// lifetime.
    ///
    /// NOTE(review): `Ok(None)` presumably signals that no further
    /// responses will follow; confirm against the implementors.
    fn get_response(
        &mut self,
    ) -> Pin<
        Box<
            dyn Future<Output = Result<Option<Message<Bytes>>, Error>>
                + Send
                + Sync
                + '_,
        >,
    >;
}
/// A request that expects a single response, built from an existing
/// message plus a separately editable header and OPT record.
#[derive(Clone, Debug)]
pub struct RequestMessage<Octs>
where
    Octs: AsRef<[u8]>,
{
    /// The message the request was created from. Its sections are copied
    /// when the request is composed.
    msg: Message<Octs>,

    /// The header used when composing. It starts as a copy of `msg`'s
    /// header.
    header: Header,

    /// The OPT record appended when composing, if one has been set up.
    opt: Option<OptRecord<Vec<u8>>>,
}
impl<Octs: AsRef<[u8]> + Debug + Octets> RequestMessage<Octs> {
    /// Creates a new request from `msg`.
    ///
    /// The header of `msg` is copied into a separate field so it can be
    /// modified (see [`ComposeRequest::header_mut`]) without touching
    /// `msg`. No OPT record is attached initially. An OPT record already
    /// present in `msg` is not carried over, because composing skips OPT
    /// records found in the additional section.
    ///
    /// # Errors
    ///
    /// Returns [`Error::FormError`] if `msg` has opcode `QUERY` and either
    /// has no question or its first question has type `AXFR`. Transfer
    /// requests belong in [`RequestMessageMulti`].
    pub fn new(msg: impl Into<Message<Octs>>) -> Result<Self, Error> {
        let msg = msg.into();

        // AXFR is refused here. A QUERY without any question is refused as
        // well, because the `?` on `first_question` bails out with
        // `FormError`.
        if msg.header().opcode() == Opcode::QUERY
            && msg.first_question().ok_or(Error::FormError)?.qtype()
                == Rtype::AXFR
        {
            return Err(Error::FormError);
        }

        let header = msg.header();
        Ok(Self {
            msg,
            header,
            opt: None,
        })
    }

    /// Returns the OPT record to attach, creating a default one first if
    /// none exists yet.
    fn opt_mut(&mut self) -> &mut OptRecord<Vec<u8>> {
        self.opt.get_or_insert_with(Default::default)
    }

    /// Composes the request into `target`.
    ///
    /// Uses `self.header` as the header. Copies the question, answer and
    /// authority sections of the stored message unchanged, and copies the
    /// additional section except for OPT records. Finally appends
    /// `self.opt` if it is set.
    fn append_message_impl<Target: Composer>(
        &self,
        mut target: MessageBuilder<Target>,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
        let source = &self.msg;

        *target.header_mut() = self.header;

        // Question section.
        let source = source.question();
        let mut target = target.question();
        for rr in source {
            target.push(rr?)?;
        }

        // Answer section.
        let mut source = source.answer()?;
        let mut target = target.answer();
        for rr in &mut source {
            // Assumes parsing into `AllRecordData` always yields a record,
            // so `None` would indicate a bug.
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            target.push(rr)?;
        }

        // Authority section. It always follows the answer section.
        let mut source =
            source.next_section()?.expect("section should be present");
        let mut target = target.authority();
        for rr in &mut source {
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            target.push(rr)?;
        }

        // Additional section. It always follows the authority section.
        // OPT records are skipped so that only `self.opt` ends up in the
        // output.
        let source =
            source.next_section()?.expect("section should be present");
        let mut target = target.additional();
        for rr in source {
            let rr = rr?;
            if rr.rtype() != Rtype::OPT {
                let rr = rr
                    .into_record::<AllRecordData<_, ParsedName<_>>>()?
                    .expect("record expected");
                target.push(rr)?;
            }
        }
        if let Some(opt) = self.opt.as_ref() {
            target.push(opt.as_record())?;
        }

        Ok(target)
    }

    /// Composes the request into a fresh `Vec<u8>` (with name compression)
    /// and parses the result back into a [`Message`].
    fn to_message_impl(&self) -> Result<Message<Vec<u8>>, Error> {
        let target =
            MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
                .expect("Vec is expected to have enough space");
        let target = self.append_message_impl(target)?;

        // Consume the builder directly. Cloning it first would copy the
        // whole message buffer for nothing.
        let octets = target.finish().into_target();
        let msg = Message::from_octets(octets).expect(
            "Message should be able to parse output from MessageBuilder",
        );
        Ok(msg)
    }
}
impl<Octs: AsRef<[u8]> + Debug + Octets + Send + Sync> ComposeRequest
    for RequestMessage<Octs>
{
    /// Composes the request into `target`.
    ///
    /// A target too small to create a message builder is reported as
    /// `CopyRecordsError::Push(PushError::ShortBuf)`.
    fn append_message<Target: Composer>(
        &self,
        target: Target,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
        let target = MessageBuilder::from_target(target)
            .map_err(|_| CopyRecordsError::Push(PushError::ShortBuf))?;
        let builder = self.append_message_impl(target)?;
        Ok(builder)
    }

    /// Returns the composed request as raw octets.
    fn to_vec(&self) -> Result<Vec<u8>, Error> {
        // The message is owned here, so move its buffer out rather than
        // cloning it.
        Ok(self.to_message()?.into_octets())
    }

    /// Returns the composed request as a parsed message.
    fn to_message(&self) -> Result<Message<Vec<u8>>, Error> {
        self.to_message_impl()
    }

    fn header(&self) -> &Header {
        &self.header
    }

    fn header_mut(&mut self) -> &mut Header {
        &mut self.header
    }

    /// Sets the UDP payload size on the OPT record, creating one if needed.
    fn set_udp_payload_size(&mut self, value: u16) {
        self.opt_mut().set_udp_payload_size(value);
    }

    /// Sets the DO flag on the OPT record, creating one if needed.
    fn set_dnssec_ok(&mut self, value: bool) {
        self.opt_mut().set_dnssec_ok(value);
    }

    /// Pushes `opt` into the OPT record, creating one if needed.
    fn add_opt(
        &mut self,
        opt: &impl ComposeOptData,
    ) -> Result<(), LongOptData> {
        self.opt_mut().push(opt).map_err(|e| e.unlimited_buf())
    }

    /// Checks whether `answer` is a reply to this request.
    ///
    /// The answer must have QR set and the same ID as `self.header`. An
    /// answer with a non-NOERROR rcode and all four sections empty is
    /// accepted as is. Otherwise the question count and the question
    /// section must match those of the stored message.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool {
        let answer_header = answer.header();
        let answer_hcounts = answer.header_counts();

        if !answer_header.qr() || answer_header.id() != self.header.id() {
            trace!(
                "Wrong QR or ID: QR={}, answer ID={}, self ID={}",
                answer_header.qr(),
                answer_header.id(),
                self.header.id()
            );
            return false;
        }

        // An error reply may legitimately come back without a question.
        if answer_header.rcode() != Rcode::NOERROR
            && answer_hcounts.qdcount() == 0
            && answer_hcounts.ancount() == 0
            && answer_hcounts.nscount() == 0
            && answer_hcounts.arcount() == 0
        {
            return true;
        }

        if answer_hcounts.qdcount() != self.msg.header_counts().qdcount() {
            trace!("Wrong QD count");
            false
        } else {
            let res = answer.question() == self.msg.for_slice().question();
            if !res {
                trace!("Wrong question");
            }
            res
        }
    }

    /// Returns the DO flag of the OPT record, or `false` if there is none.
    fn dnssec_ok(&self) -> bool {
        match &self.opt {
            None => false,
            Some(opt) => opt.dnssec_ok(),
        }
    }
}
/// A request that may receive several responses, built from an existing
/// message plus a separately editable header and OPT record.
#[derive(Clone, Debug)]
pub struct RequestMessageMulti<Octs: AsRef<[u8]>> {
    /// The message the request was created from. Its sections are copied
    /// when the request is composed.
    msg: Message<Octs>,

    /// The header used when composing. It starts as a copy of `msg`'s
    /// header.
    header: Header,

    /// The OPT record appended when composing, if one has been set up.
    opt: Option<OptRecord<Vec<u8>>>,
}
impl<Octs: AsRef<[u8]> + Debug + Octets> RequestMessageMulti<Octs> {
    /// Creates a new multi-response request from `msg`.
    ///
    /// The header of `msg` is copied into a separate field so it can be
    /// modified without touching `msg`. No OPT record is attached
    /// initially. An OPT record already present in `msg` is not carried
    /// over, because composing skips OPT records in the additional section.
    ///
    /// # Errors
    ///
    /// Returns [`Error::FormError`] if `Message::is_xfr` returns `false`
    /// for `msg`.
    pub fn new(msg: impl Into<Message<Octs>>) -> Result<Self, Error> {
        let msg = msg.into();

        if !msg.is_xfr() {
            return Err(Error::FormError);
        }

        let header = msg.header();
        Ok(Self {
            msg,
            header,
            opt: None,
        })
    }

    /// Returns the OPT record to attach, creating a default one first if
    /// none exists yet.
    fn opt_mut(&mut self) -> &mut OptRecord<Vec<u8>> {
        self.opt.get_or_insert_with(Default::default)
    }

    /// Composes the request into `target`.
    ///
    /// Uses `self.header` as the header. Copies the question, answer and
    /// authority sections of the stored message unchanged, and copies the
    /// additional section except for OPT records. Finally appends
    /// `self.opt` if it is set.
    fn append_message_impl<Target: Composer>(
        &self,
        mut target: MessageBuilder<Target>,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
        let source = &self.msg;

        *target.header_mut() = self.header;

        // Question section.
        let source = source.question();
        let mut target = target.question();
        for rr in source {
            target.push(rr?)?;
        }

        // Answer section.
        let mut source = source.answer()?;
        let mut target = target.answer();
        for rr in &mut source {
            // Assumes parsing into `AllRecordData` always yields a record,
            // so `None` would indicate a bug.
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            target.push(rr)?;
        }

        // Authority section. It always follows the answer section.
        let mut source =
            source.next_section()?.expect("section should be present");
        let mut target = target.authority();
        for rr in &mut source {
            let rr = rr?
                .into_record::<AllRecordData<_, ParsedName<_>>>()?
                .expect("record expected");
            target.push(rr)?;
        }

        // Additional section. It always follows the authority section.
        // OPT records are skipped so that only `self.opt` ends up in the
        // output.
        let source =
            source.next_section()?.expect("section should be present");
        let mut target = target.additional();
        for rr in source {
            let rr = rr?;
            if rr.rtype() != Rtype::OPT {
                let rr = rr
                    .into_record::<AllRecordData<_, ParsedName<_>>>()?
                    .expect("record expected");
                target.push(rr)?;
            }
        }
        if let Some(opt) = self.opt.as_ref() {
            target.push(opt.as_record())?;
        }

        Ok(target)
    }

    /// Composes the request into a fresh `Vec<u8>` (with name compression)
    /// and parses the result back into a [`Message`].
    fn to_message_impl(&self) -> Result<Message<Vec<u8>>, Error> {
        let target =
            MessageBuilder::from_target(StaticCompressor::new(Vec::new()))
                .expect("Vec is expected to have enough space");
        let target = self.append_message_impl(target)?;

        // Consume the builder directly. Cloning it first would copy the
        // whole message buffer for nothing.
        let octets = target.finish().into_target();
        let msg = Message::from_octets(octets).expect(
            "Message should be able to parse output from MessageBuilder",
        );
        Ok(msg)
    }
}
impl<Octs> ComposeRequestMulti for RequestMessageMulti<Octs>
where
    Octs: AsRef<[u8]> + Debug + Octets + Send + Sync,
{
    /// Composes the request into `target`.
    ///
    /// A target too small to create a message builder is reported as
    /// `CopyRecordsError::Push(PushError::ShortBuf)`.
    fn append_message<Target: Composer>(
        &self,
        target: Target,
    ) -> Result<AdditionalBuilder<Target>, CopyRecordsError> {
        MessageBuilder::from_target(target)
            .map_err(|_| CopyRecordsError::Push(PushError::ShortBuf))
            .and_then(|builder| self.append_message_impl(builder))
    }

    /// Returns the composed request as a parsed message.
    fn to_message(&self) -> Result<Message<Vec<u8>>, Error> {
        self.to_message_impl()
    }

    fn header(&self) -> &Header {
        &self.header
    }

    fn header_mut(&mut self) -> &mut Header {
        &mut self.header
    }

    /// Sets the UDP payload size on the OPT record, creating one if needed.
    fn set_udp_payload_size(&mut self, value: u16) {
        let opt = self.opt_mut();
        opt.set_udp_payload_size(value);
    }

    /// Sets the DO flag on the OPT record, creating one if needed.
    fn set_dnssec_ok(&mut self, value: bool) {
        let opt = self.opt_mut();
        opt.set_dnssec_ok(value);
    }

    /// Pushes `opt` into the OPT record, creating one if needed.
    fn add_opt(
        &mut self,
        opt: &impl ComposeOptData,
    ) -> Result<(), LongOptData> {
        let record = self.opt_mut();
        record.push(opt).map_err(|err| err.unlimited_buf())
    }

    /// Checks whether `answer` is a reply to this request.
    ///
    /// The answer must have QR set and the same ID as `self.header`. An
    /// answer with a non-NOERROR rcode and all four sections empty is
    /// accepted. For AXFR requests, an answer without a question is
    /// accepted. Otherwise the question count and question section must
    /// match those of the stored message.
    fn is_answer(&self, answer: &Message<[u8]>) -> bool {
        let hdr = answer.header();
        let counts = answer.header_counts();

        let same_exchange = hdr.qr() && hdr.id() == self.header.id();
        if !same_exchange {
            trace!(
                "Wrong QR or ID: QR={}, answer ID={}, self ID={}",
                hdr.qr(),
                hdr.id(),
                self.header.id()
            );
            return false;
        }

        let all_sections_empty = counts.qdcount() == 0
            && counts.ancount() == 0
            && counts.nscount() == 0
            && counts.arcount() == 0;
        if hdr.rcode() != Rcode::NOERROR && all_sections_empty {
            return true;
        }

        // Responses in an AXFR stream may omit the question section.
        if self.msg.qtype() == Some(Rtype::AXFR) && counts.qdcount() == 0 {
            return true;
        }

        if counts.qdcount() != self.msg.header_counts().qdcount() {
            trace!("Wrong QD count");
            return false;
        }

        let same_question =
            answer.question() == self.msg.for_slice().question();
        if !same_question {
            trace!("Wrong question");
        }
        same_question
    }

    /// Returns the DO flag of the OPT record, or `false` if there is none.
    fn dnssec_ok(&self) -> bool {
        matches!(&self.opt, Some(opt) if opt.dnssec_ok())
    }
}
/// Errors reported while composing requests or exchanging them over a
/// transport.
#[derive(Clone, Debug)]
pub enum Error {
    /// The connection was closed.
    ConnectionClosed,

    /// The OPT record grew too long.
    OptTooLong,

    /// Pushing data into a `MessageBuilder` failed.
    MessageBuilderPushError,

    /// Parsing a message failed.
    MessageParseError,

    /// The underlying transport was not found in a redundant connection.
    RedundantTransportNotFound,

    /// The message violates a constraint.
    FormError,

    /// The octet sequence is too short to be a valid message.
    ShortMessage,

    /// The message is too long for a stream transport.
    StreamLongMessage,

    /// The stream was idle for too long.
    StreamIdleTimeout,

    /// Receiving a reply failed.
    StreamReceiveError,

    /// Reading from the stream failed with the contained I/O error.
    StreamReadError(Arc<std::io::Error>),

    /// Reading from the stream timed out.
    StreamReadTimeout,

    /// Too many queries are outstanding on the stream.
    StreamTooManyOutstandingQueries,

    /// Writing to the stream failed with the contained I/O error.
    StreamWriteError(Arc<std::io::Error>),

    /// Data ended unexpectedly.
    StreamUnexpectedEndOfData,

    /// The reply does not match the query.
    WrongReplyForQuery,

    /// No transport is available.
    NoTransportAvailable,

    /// An error from the `dgram` module's query handling.
    Dgram(Arc<super::dgram::QueryError>),

    /// Writing to a zone failed.
    #[cfg(feature = "unstable-server-transport")]
    ZoneWrite,

    /// A TSIG validation error.
    #[cfg(feature = "tsig")]
    Authentication(tsig::ValidationError),

    /// Validating the response failed.
    #[cfg(feature = "unstable-validator")]
    Validation(crate::validator::context::Error),
}
impl From<LongOptData> for Error {
fn from(_: LongOptData) -> Self {
Self::OptTooLong
}
}
impl From<ParseError> for Error {
fn from(_: ParseError) -> Self {
Self::MessageParseError
}
}
impl From<ShortMessage> for Error {
fn from(_: ShortMessage) -> Self {
Self::ShortMessage
}
}
impl From<super::dgram::QueryError> for Error {
fn from(err: super::dgram::QueryError) -> Self {
Self::Dgram(err.into())
}
}
/// Wraps a validator error as [`Error::Validation`].
#[cfg(feature = "unstable-validator")]
impl From<crate::validator::context::Error> for Error {
    fn from(source: crate::validator::context::Error) -> Self {
        Error::Validation(source)
    }
}
/// Human-readable descriptions of each error.
///
/// Wrapped errors from the `dgram` module and TSIG are displayed with their
/// own message. Other wrapped errors are reachable through
/// `std::error::Error::source`.
impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::ConnectionClosed => write!(f, "connection closed"),
            Error::OptTooLong => write!(f, "OPT record is too long"),
            Error::MessageBuilderPushError => {
                write!(f, "PushError from MessageBuilder")
            }
            Error::MessageParseError => write!(f, "ParseError from Message"),
            Error::RedundantTransportNotFound => write!(
                f,
                "Underlying transport not found in redundant connection"
            ),
            Error::ShortMessage => {
                write!(f, "octet sequence too short to be a valid message")
            }
            Error::FormError => {
                write!(f, "message violates a constraint")
            }
            Error::StreamLongMessage => {
                write!(f, "message too long for stream transport")
            }
            Error::StreamIdleTimeout => {
                write!(f, "stream was idle for too long")
            }
            Error::StreamReceiveError => write!(f, "error receiving a reply"),
            Error::StreamReadError(_) => {
                write!(f, "error reading from stream")
            }
            Error::StreamReadTimeout => {
                write!(f, "timeout reading from stream")
            }
            Error::StreamTooManyOutstandingQueries => {
                write!(f, "too many outstanding queries on stream")
            }
            Error::StreamWriteError(_) => {
                write!(f, "error writing to stream")
            }
            Error::StreamUnexpectedEndOfData => {
                write!(f, "unexpected end of data")
            }
            Error::WrongReplyForQuery => {
                write!(f, "reply does not match query")
            }
            Error::NoTransportAvailable => {
                write!(f, "no transport available")
            }
            Error::Dgram(err) => fmt::Display::fmt(err, f),
            #[cfg(feature = "unstable-server-transport")]
            Error::ZoneWrite => write!(f, "error writing to zone"),
            #[cfg(feature = "tsig")]
            Error::Authentication(err) => fmt::Display::fmt(err, f),
            #[cfg(feature = "unstable-validator")]
            Error::Validation(_) => {
                write!(f, "error validating response")
            }
        }
    }
}
impl From<CopyRecordsError> for Error {
fn from(err: CopyRecordsError) -> Self {
match err {
CopyRecordsError::Parse(_) => Self::MessageParseError,
CopyRecordsError::Push(_) => Self::MessageBuilderPushError,
}
}
}
impl error::Error for Error {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match self {
Error::ConnectionClosed => None,
Error::OptTooLong => None,
Error::MessageBuilderPushError => None,
Error::MessageParseError => None,
Error::RedundantTransportNotFound => None,
Error::ShortMessage => None,
Error::FormError => None,
Error::StreamLongMessage => None,
Error::StreamIdleTimeout => None,
Error::StreamReceiveError => None,
Error::StreamReadError(e) => Some(e),
Error::StreamReadTimeout => None,
Error::StreamTooManyOutstandingQueries => None,
Error::StreamWriteError(e) => Some(e),
Error::StreamUnexpectedEndOfData => None,
Error::WrongReplyForQuery => None,
Error::NoTransportAvailable => None,
Error::Dgram(err) => Some(err),
#[cfg(feature = "unstable-server-transport")]
Error::ZoneWrite => None,
#[cfg(feature = "tsig")]
Error::Authentication(e) => Some(e),
#[cfg(feature = "unstable-validator")]
Error::Validation(e) => Some(e),
}
}
}