// base64/engine/general_purpose/decode_suffix.rs
use crate::{
engine::{general_purpose::INVALID_VALUE, DecodeMetadata, DecodePaddingMode},
DecodeError, DecodeSliceError, PAD_BYTE,
};
/// Decodes the final group of at most 4 input bytes (`input[input_index..]`), enforcing the
/// configured padding and trailing-bit rules.
///
/// Decoded bytes are stored in `output` starting at `output_index`. `decode_table` maps each
/// input byte to its 6-bit value, with `INVALID_VALUE` marking bytes that are not valid symbols.
///
/// On success, returns a [`DecodeMetadata`] built from the output index just past the last byte
/// written (so it includes whatever the caller already counted in `output_index`) and the index
/// into `input` of the first padding byte, if there was one.
///
/// # Errors
///
/// Checks run in the order listed:
///
/// - [`DecodeError::InvalidByte`], found while scanning the suffix left to right, when:
///   - a padding byte is the first or second byte of the suffix,
///   - a non-padding byte follows padding (reported at the first padding byte), or
///   - a byte maps to `INVALID_VALUE` in `decode_table`.
/// - [`DecodeError::InvalidLength`] when `input` is non-empty and fewer than 2 symbols were
///   found in the suffix.
/// - [`DecodeError::InvalidPadding`] when `padding_mode` is `RequireCanonical` and symbols plus
///   padding do not add up to a multiple of 4, or when it is `RequireNone` and padding is present.
/// - [`DecodeError::InvalidLastSymbol`] when `decode_allow_trailing_bits` is false and the last
///   symbol carries bits that do not end up in any output byte.
/// - [`DecodeSliceError::OutputSliceTooSmall`] when `output` cannot hold every decoded byte.
///   Bytes that did fit have already been written by then.
///
/// # Panics
///
/// Panics if `input_index > input.len()`. Debug builds additionally assert that no more than
/// 4 bytes remain after `input_index`.
pub(crate) fn decode_suffix(
    input: &[u8],
    input_index: usize,
    output: &mut [u8],
    mut output_index: usize,
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeSliceError> {
    debug_assert!((input.len() - input_index) <= 4);

    // 6-bit values of the symbols seen so far, in input order.
    let mut sextets = [0_u8; 4];
    let mut symbol_count = 0_usize;
    let mut final_symbol = 0_u8;
    let mut pad_count = 0_usize;
    // Index into `input` (not relative to `input_index`) of the first '=' seen, if any.
    let mut first_pad: Option<usize> = None;

    for (offset, &byte) in input[input_index..].iter().enumerate() {
        let position = input_index + offset;

        if byte == PAD_BYTE {
            // Padding can be wrong in several ways:
            //   1. a symbol follows it,
            //   2. it shows up after only zero or one symbols of the quad,
            //   3. there are more than two padding bytes -- within a single quad that is just
            //      an instance of (2); spilling over from an earlier quad it is an invalid byte
            //      in that earlier quad,
            //   4. it is non-canonical (e.g. one '=' where two belong), which is left to the
            //      `padding_mode` check after the scan.
            if offset < 2 {
                // Case 2. Any earlier '=' in this suffix would already have returned from
                // here, so this is the first offending byte.
                debug_assert!(offset == 0 || (offset == 1 && pad_count == 0));
                return Err(DecodeError::InvalidByte(position, byte).into());
            }

            if first_pad.is_none() {
                first_pad = Some(position);
            }
            pad_count += 1;
            continue;
        }

        // Case 1: '=' is only accepted as a pure suffix, consistent with how it is treated
        // ahead of the final quad. Blame the first padding byte rather than this symbol.
        if let Some(pad_position) = first_pad {
            return Err(DecodeError::InvalidByte(pad_position, PAD_BYTE).into());
        }

        let sextet = decode_table[usize::from(byte)];
        if sextet == INVALID_VALUE {
            return Err(DecodeError::InvalidByte(position, byte).into());
        }

        final_symbol = byte;
        sextets[symbol_count] = sextet;
        symbol_count += 1;
    }

    // A lone symbol cannot encode a whole byte. This is only reported once the scan has found
    // no invalid byte, so that the more specific error wins.
    if symbol_count < 2 && !input.is_empty() {
        return Err(DecodeError::InvalidLength(input_index + symbol_count).into());
    }

    // Deferred until after the scan so that padding problems that should surface as
    // `InvalidByte` are reported as such.
    let padding_acceptable = match padding_mode {
        DecodePaddingMode::Indifferent => true,
        // An empty suffix (0 + 0) passes as well.
        DecodePaddingMode::RequireCanonical => (pad_count + symbol_count) % 4 == 0,
        DecodePaddingMode::RequireNone => pad_count == 0,
    };
    if !padding_acceptable {
        return Err(DecodeError::InvalidPadding.into());
    }

    // Each symbol contributes 6 bits; only whole bytes are emitted (2 symbols -> 1 byte,
    // 3 -> 2, 4 -> 3).
    let decoded_byte_count = symbol_count * 6 / 8;

    // Pack the (up to 24) bits into the top of a u32 so the decoded bytes are its most
    // significant bytes; slots for absent symbols contribute zero.
    let mut packed = (u32::from(sextets[0]) << 26)
        | (u32::from(sextets[1]) << 20)
        | (u32::from(sextets[2]) << 14)
        | (u32::from(sextets[3]) << 8);

    // Example: the single byte 0xFF encodes as "/w" -- '/' is 63 (all six bits set) and 'w' is
    // 48 (only the top two of its six bits set). Any last symbol above 'w' sharing those top
    // two bits decodes to the same byte, but its extra low bits can never become output and
    // make the encoding non-canonical. Everything below the bytes about to be emitted is
    // therefore required to be zero unless trailing bits are explicitly allowed.
    let trailing_bits_mask = u32::MAX >> (decoded_byte_count * 8);
    if !decode_allow_trailing_bits && (packed & trailing_bits_mask) != 0 {
        // The offending symbol is the last one decoded.
        return Err(
            DecodeError::InvalidLastSymbol(input_index + symbol_count - 1, final_symbol).into(),
        );
    }

    // Emit the most significant byte, then shift the next one into place. The shift-and-store
    // form is retained from the previous implementation, whose author noted it benchmarked
    // better than the alternatives tried.
    for _ in 0..decoded_byte_count {
        match output.get_mut(output_index) {
            Some(slot) => *slot = (packed >> 24) as u8,
            None => return Err(DecodeSliceError::OutputSliceTooSmall),
        }
        packed <<= 8;
        output_index += 1;
    }

    Ok(DecodeMetadata::new(output_index, first_pad))
}