1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423
//! Looking up SRV records.
use super::host::lookup_host;
use crate::base::iana::{Class, Rtype};
use crate::base::message::Message;
use crate::base::name::{Name, ToName, ToRelativeName};
use crate::base::wire::ParseError;
use crate::rdata::{Aaaa, Srv, A};
use crate::resolv::resolver::Resolver;
use core::fmt;
use futures_util::stream::{self, Stream, StreamExt};
use octseq::octets::Octets;
use rand::distributions::{Distribution, Uniform};
use std::net::{IpAddr, SocketAddr};
use std::vec::Vec;
use std::{io, mem, ops};
// Looking up SRV records has three possible outcomes:
//
// * at least one SRV record with a regular target,
// * one single SRV record with the root target -- no such service,
// * no SRV records at all.
//
// In the first case we have a set of (target, port) pairs which we need to
// resolve further if there were no address records for the target in the
// additional section.
//
// In the second case we have nothing.
//
// In the third case we have a single (target, port) pair with the original
// host and the fallback port which we need to resolve further.
//------------ OctetsVec -----------------------------------------------------
/// Octets sequence backing the owned domain names kept in lookup results.
///
/// Without the `smallvec` feature this is a plain `Vec<u8>`.
#[cfg(not(feature = "smallvec"))]
type OctetsVec = Vec<u8>;
/// Octets sequence backing the owned domain names kept in lookup results.
///
/// With the `smallvec` feature enabled this is `octseq`’s `SmallOctets`.
#[cfg(feature = "smallvec")]
type OctetsVec = octseq::octets::SmallOctets;
//------------ lookup_srv ----------------------------------------------------
/// Creates a future that looks up SRV records.
///
/// The future will use the resolver given in `resolver` to query the
/// DNS for SRV records associated with domain name `name` and service
/// `service`.
///
/// The value returned upon success can be turned into a stream of
/// [`ResolvedSrvItem`]s corresponding to the found SRV records, ordered as per
/// the usage rules defined in [RFC 2782]. If no matching SRV record is found,
/// A/AAAA queries on the bare domain name `name` will be attempted, yielding
/// a single element upon success using the port given by `fallback_port`,
/// typically the standard port for the service in question.
///
/// Each item in the stream can be turned into an iterator over socket
/// addresses as accepted by, for instance, [`TcpStream::connect`].
///
/// The future resolves to `None` whenever the requested service is
/// “decidedly not available” at the requested domain, that is, there is a
/// single SRV record with the root label as its target.
///
/// # Errors
///
/// Returns [`SrvError::LongName`] if `service` and `name` combined are too
/// long for a domain name, [`SrvError::Query`] if the query itself fails,
/// and [`SrvError::MalformedAnswer`] if the answer cannot be processed.
///
/// [RFC 2782]: https://tools.ietf.org/html/rfc2782
/// [`TcpStream::connect`]: tokio::net::TcpStream::connect
pub async fn lookup_srv(
    resolver: &impl Resolver,
    service: impl ToRelativeName,
    name: impl ToName,
    fallback_port: u16,
) -> Result<Option<FoundSrvs>, SrvError> {
    // The query name is the service label(s) prepended to the domain.
    let full_name =
        (&service).chain(&name).map_err(|_| SrvError::LongName)?;
    let answer = resolver.query((full_name, Rtype::SRV)).await?;
    FoundSrvs::new(answer.as_ref().for_slice(), name, fallback_port)
}
//------------ FoundSrvs -----------------------------------------------------
/// The result of a successful SRV lookup, returned by [`lookup_srv`].
///
/// Use [`into_stream`][Self::into_stream] to resolve the targets into
/// socket addresses or [`into_srvs`][Self::into_srvs] to access the
/// SRV records themselves.
#[derive(Clone, Debug)]
pub struct FoundSrvs {
    /// The SRV items we found.
    ///
    /// If this is `Ok(some)`, there were SRV records, ordered as per the
    /// rules of RFC 2782. If this is `Err(some)`, there weren’t any SRV
    /// records and the sole item is the bare host with the fallback port.
    items: Result<Vec<SrvItem>, SrvItem>,
}
impl FoundSrvs {
    /// Converts the found SRV records into socket addresses.
    ///
    /// The method takes a reference to a resolver and returns a stream of
    /// socket addresses in the order prescribed by the SRV records. Each
    /// returned item provides the set of addresses for one host. Targets
    /// whose addresses were already part of the SRV answer’s additional
    /// section are not looked up again.
    ///
    /// Note that if you are using the
    /// [`StubResolver`][crate::resolv::stub::StubResolver], you will have to
    /// pass in a double reference since [`Resolver`] is implemented for a
    /// reference to it and this method requires a reference to that impl
    /// being passed. This quirk will be fixed in future versions.
    pub fn into_stream<R: Resolver>(
        self,
        resolver: &R,
    ) -> impl Stream<Item = Result<ResolvedSrvItem, io::Error>> + '_
    where
        R::Octets: Octets,
    {
        // Build a single iterator type for both cases: either the list of
        // found items or the sole fallback item. Each case is represented
        // as an optional iterator so that both match arms have one type.
        let iter = match self.items {
            Ok(vec) => {
                Some(vec.into_iter()).into_iter().flatten().chain(None)
            }
            Err(one) => None.into_iter().flatten().chain(Some(one)),
        };
        stream::iter(iter).then(move |item| item.resolve(resolver))
    }

    /// Converts the value into an iterator over the found SRV records.
    ///
    /// If results were found, this returns them in the order prescribed by
    /// the SRV records.
    ///
    /// If no results were found, the iterator will yield a single entry
    /// with the bare host and the default fallback port.
    pub fn into_srvs(self) -> impl Iterator<Item = Srv<Name<OctetsVec>>> {
        let (left, right) = match self.items {
            Ok(ok) => (Some(ok.into_iter()), None),
            Err(err) => (None, Some(std::iter::once(err))),
        };
        left.into_iter()
            .flatten()
            .chain(right.into_iter().flatten())
            .map(|item| item.srv)
    }

    /// Merges all results from `other` into `self`.
    ///
    /// Reorders merged results as if they were from a single query. A sole
    /// fallback item on either side is merged like a regular SRV item.
    pub fn merge(&mut self, other: &Self) {
        // Take the items out of `self`, turning a sole fallback item into a
        // regular list so both cases can be handled uniformly.
        let mut items = match mem::replace(&mut self.items, Ok(Vec::new())) {
            Ok(items) => items,
            Err(one) => {
                let mut items = Vec::new();
                items.push(one);
                items
            }
        };
        match other.items {
            Ok(ref vec) => items.extend_from_slice(vec),
            Err(ref one) => items.push(one.clone()),
        }
        Self::reorder_items(&mut items);
        self.items = Ok(items);
    }
}
impl FoundSrvs {
    /// Creates the lookup result from the answer to an SRV query.
    ///
    /// Returns `Ok(None)` if the answer contains exactly one SRV record for
    /// the canonical name and its target is the root name, i.e., the
    /// service is not available. If there are no SRV records for the
    /// canonical name, the result holds a single fallback item made from
    /// `fallback_name` and `fallback_port`. Otherwise, addresses from the
    /// additional section are attached to the items and the items are
    /// ordered as per RFC 2782.
    ///
    /// Fails with [`SrvError::MalformedAnswer`] if the answer has no
    /// canonical name or its sections cannot be parsed.
    fn new(
        answer: &Message<[u8]>,
        fallback_name: impl ToName,
        fallback_port: u16,
    ) -> Result<Option<Self>, SrvError> {
        let name =
            answer.canonical_name().ok_or(SrvError::MalformedAnswer)?;
        let mut items = Self::process_records(answer, &name)?;
        if items.is_empty() {
            return Ok(Some(FoundSrvs {
                items: Err(SrvItem::fallback(fallback_name, fallback_port)),
            }));
        }
        if items.len() == 1 && items[0].target().is_root() {
            // Exactly one record with target "." indicates no service.
            return Ok(None);
        }
        // Build results including potentially resolved IP addresses
        Self::process_additional(&mut items, answer)?;
        Self::reorder_items(&mut items);
        Ok(Some(FoundSrvs { items: Ok(items) }))
    }

    /// Collects the IN-class SRV records of the answer section owned by
    /// `name`. Records that fail to parse are skipped.
    fn process_records(
        answer: &Message<[u8]>,
        name: &impl ToName,
    ) -> Result<Vec<SrvItem>, SrvError> {
        let mut res = Vec::new();
        // XXX We could also error out if any SRV record is broken?
        for record in answer.answer()?.limit_to_in::<Srv<_>>().flatten() {
            if record.owner() == name {
                res.push(SrvItem::from_rdata(record.data()))
            }
        }
        Ok(res)
    }

    /// Attaches addresses from the additional section to the items.
    ///
    /// For each item, all IN-class A and AAAA records in the additional
    /// section owned by the item’s target are collected. If any were found,
    /// they become the item’s resolved addresses. Records that fail to parse
    /// are skipped.
    fn process_additional(
        items: &mut [SrvItem],
        answer: &Message<[u8]>,
    ) -> Result<(), SrvError> {
        let additional = answer.additional()?;
        for item in items {
            let mut addrs = Vec::new();
            for record in additional {
                let record = match record {
                    Ok(record) => record,
                    Err(_) => continue,
                };
                if record.class() != Class::IN
                    || record.owner() != item.target()
                {
                    continue;
                }
                // A record is either A or AAAA, so only try the latter if
                // the former didn't match.
                if let Ok(Some(record)) = record.to_record::<A>() {
                    addrs.push(record.data().addr().into())
                } else if let Ok(Some(record)) = record.to_record::<Aaaa>()
                {
                    addrs.push(record.data().addr().into())
                }
            }
            if !addrs.is_empty() {
                item.resolved = Some(addrs)
            }
        }
        Ok(())
    }

    /// Orders the items as prescribed by RFC 2782.
    ///
    /// Items are sorted by ascending priority. Within each priority, the
    /// order is randomized according to the weights via
    /// [`reorder_by_weight`][Self::reorder_by_weight].
    fn reorder_items(items: &mut [SrvItem]) {
        // Sorting by (priority, weight) groups the items by priority and
        // puts the weight-0 items at the beginning of each group.
        items.sort_by_key(|k| (k.priority(), k.weight()));
        let mut start = 0;
        while start < items.len() {
            let priority = items[start].priority();
            let len = items[start..]
                .iter()
                .take_while(|item| item.priority() == priority)
                .count();
            let group = &mut items[start..start + len];
            let weight_sum: u32 =
                group.iter().map(|item| u32::from(item.weight())).sum();
            Self::reorder_by_weight(group, weight_sum);
            start += len;
        }
    }

    /// Orders the items of a single priority group by their weight.
    ///
    /// This implements the weighted selection of RFC 2782: while there are
    /// unordered items, pick a uniformly distributed random number between
    /// 0 and the sum of their weights (inclusive) and select the first
    /// unordered item whose running weight sum is greater than or equal to
    /// that number. The selected item is appended to the ordered prefix.
    ///
    /// `items` is expected to be sorted by weight (so weight-0 items come
    /// first) and `weight_sum` must be the sum of all weights in `items`.
    fn reorder_by_weight(items: &mut [SrvItem], weight_sum: u32) {
        let mut rng = rand::thread_rng();
        let mut weight_sum = weight_sum;
        for i in 0..items.len() {
            let pick = Uniform::new(0, weight_sum + 1).sample(&mut rng);
            // Only the items from `i` on are still unordered; the prefix
            // holds the items already selected and must not be revisited.
            let mut running_sum: u32 = 0;
            let selected = items[i..]
                .iter()
                .position(|item| {
                    running_sum += u32::from(item.weight());
                    running_sum >= pick
                })
                // The running sum over all unordered items reaches
                // `weight_sum >= pick`, so a match is always found. Fall
                // back to the last item rather than panicking.
                .map_or(items.len() - 1, |pos| i + pos);
            weight_sum = weight_sum
                .saturating_sub(u32::from(items[selected].weight()));
            // Rotate rather than swap so the remaining unordered items keep
            // their relative order; in particular, weight-0 items stay at
            // the front as RFC 2782 requires.
            items[i..=selected].rotate_right(1);
        }
    }
}
//------------ SrvItem -------------------------------------------------------
/// A single target found by an SRV lookup.
///
/// The item holds the SRV record data and, if the additional section of the
/// SRV answer contained A or AAAA records for the target, the addresses
/// found there. It dereferences to the SRV record data.
#[derive(Clone, Debug)]
pub struct SrvItem {
    /// The SRV record.
    srv: Srv<Name<OctetsVec>>,
    /// Whether this item was made from the bare host name and the fallback
    /// port because no SRV records were found.
    #[allow(dead_code)] // XXX Check if we can actually remove it.
    fallback: bool,
    /// The target’s addresses if they were part of the SRV answer.
    resolved: Option<Vec<IpAddr>>,
}
impl SrvItem {
fn from_rdata(srv: &Srv<impl ToName>) -> Self {
SrvItem {
srv: Srv::new(
srv.priority(),
srv.weight(),
srv.port(),
srv.target().to_name(),
),
fallback: false,
resolved: None,
}
}
fn fallback(name: impl ToName, fallback_port: u16) -> Self {
SrvItem {
srv: Srv::new(0, 0, fallback_port, name.to_name()),
fallback: true,
resolved: None,
}
}
// Resolves the target.
pub async fn resolve<R: Resolver>(
self,
resolver: &R,
) -> Result<ResolvedSrvItem, io::Error>
where
R::Octets: Octets,
{
let port = self.port();
if let Some(resolved) = self.resolved {
return Ok(ResolvedSrvItem {
srv: self.srv,
resolved: {
resolved
.into_iter()
.map(|addr| SocketAddr::new(addr, port))
.collect()
},
});
}
let resolved = lookup_host(resolver, self.target()).await?;
Ok(ResolvedSrvItem {
srv: self.srv,
resolved: {
resolved
.iter()
.map(|addr| SocketAddr::new(addr, port))
.collect()
},
})
}
}
impl AsRef<Srv<Name<OctetsVec>>> for SrvItem {
fn as_ref(&self) -> &Srv<Name<OctetsVec>> {
&self.srv
}
}
impl ops::Deref for SrvItem {
type Target = Srv<Name<OctetsVec>>;
fn deref(&self) -> &Self::Target {
self.as_ref()
}
}
//------------ ResolvedSrvItem -----------------------------------------------
/// An SRV record whose target has been resolved into [`SocketAddr`]s.
///
/// The item dereferences to the SRV record data.
#[derive(Clone, Debug)]
pub struct ResolvedSrvItem {
    /// The SRV record.
    srv: Srv<Name<OctetsVec>>,
    /// The target’s addresses, each combined with the record’s port.
    resolved: Vec<SocketAddr>,
}

impl ResolvedSrvItem {
    /// Returns the resolved socket addresses for this record.
    pub fn resolved(&self) -> &[SocketAddr] {
        &self.resolved
    }
}

impl AsRef<Srv<Name<OctetsVec>>> for ResolvedSrvItem {
    fn as_ref(&self) -> &Srv<Name<OctetsVec>> {
        &self.srv
    }
}

impl ops::Deref for ResolvedSrvItem {
    type Target = Srv<Name<OctetsVec>>;

    fn deref(&self) -> &Self::Target {
        &self.srv
    }
}
//------------ SrvError ------------------------------------------------------
/// An error happened while looking up SRV records.
#[derive(Debug)]
pub enum SrvError {
    /// The service name combined with the domain name is too long.
    LongName,
    /// The answer to the SRV query lacked a canonical name or could not be
    /// parsed.
    MalformedAnswer,
    /// Executing the query failed.
    Query(io::Error),
}

impl fmt::Display for SrvError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SrvError::LongName => write!(f, "name too long"),
            SrvError::MalformedAnswer => write!(f, "malformed answer"),
            SrvError::Query(e) => write!(f, "error executing query {}", e),
        }
    }
}

impl std::error::Error for SrvError {
    /// Exposes the underlying I/O error of a failed query so that error
    /// chains can be followed.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            SrvError::Query(err) => Some(err),
            SrvError::LongName | SrvError::MalformedAnswer => None,
        }
    }
}

impl From<io::Error> for SrvError {
    fn from(err: io::Error) -> SrvError {
        SrvError::Query(err)
    }
}
impl From<ParseError> for SrvError {
fn from(_: ParseError) -> SrvError {
SrvError::MalformedAnswer
}
}