diff --git a/qoi-bench/src/main.rs b/qoi-bench/src/main.rs index 49553e3..3ca5643 100644 --- a/qoi-bench/src/main.rs +++ b/qoi-bench/src/main.rs @@ -126,8 +126,8 @@ impl Codec for CodecQoiFast { Ok(qoi_fast::qoi_encode_to_vec(&img.data, img.width, img.height, img.channels, 0)?) } - fn decode(data: &[u8], img: &Image) -> Result> { - Ok(qoi_fast::qoi_decode_to_vec(data, img.channels)?.1) + fn decode(data: &[u8], _img: &Image) -> Result> { + Ok(qoi_fast::qoi_decode_to_vec(data)?.1) } } diff --git a/src/decode.rs b/src/decode.rs index 15f6a3b..d88af6a 100644 --- a/src/decode.rs +++ b/src/decode.rs @@ -1,9 +1,9 @@ // TODO: can be removed once https://github.com/rust-lang/rust/issues/74985 is stable -use bytemuck::{cast_slice_mut, Pod}; +use bytemuck::{cast_slice, cast_slice_mut, Pod}; use crate::consts::{ QOI_HEADER_SIZE, QOI_OP_DIFF, QOI_OP_INDEX, QOI_OP_LUMA, QOI_OP_RGB, QOI_OP_RGBA, QOI_OP_RUN, - QOI_PADDING_SIZE, + QOI_PADDING, QOI_PADDING_SIZE, }; use crate::error::{Error, Result}; use crate::header::Header; @@ -49,12 +49,13 @@ macro_rules! decode { #[inline] fn qoi_decode_impl_slice( data: &[u8], out: &mut [u8], -) -> Result<()> +) -> Result where Pixel: SupportedChannels, [u8; N]: Pod, { let mut pixels = cast_slice_mut::<_, [u8; N]>(out); + let data_len = data.len(); let mut data = data; let mut index = [[0_u8; N]; 256]; @@ -112,31 +113,19 @@ where *px_out = px; } - Ok(()) -} - -pub trait MaybeChannels { - fn maybe_channels(self) -> Option; -} - -impl MaybeChannels for u8 { - #[inline] - fn maybe_channels(self) -> Option { - Some(self) + if unlikely(data.len() < QOI_PADDING_SIZE) { + return Err(Error::UnexpectedBufferEnd); + } else if unlikely(cast_slice::<_, [u8; QOI_PADDING_SIZE]>(data)[0] != QOI_PADDING) { + return Err(Error::InvalidPadding); } -} -impl MaybeChannels for Option { - #[inline] - fn maybe_channels(self) -> Option { - self - } + Ok(data_len.saturating_sub(data.len()).saturating_sub(QOI_PADDING_SIZE)) } #[inline] fn qoi_decode_impl_slice_all( data: &[u8], out: &mut [u8], channels: u8, src_channels: u8, -) -> Result<()> { +) -> Result { match (channels, src_channels) { (3, 3) => qoi_decode_impl_slice::<3, false>(data, out), (3, 4) => qoi_decode_impl_slice::<3, true>(data, out), @@ -150,33 +139,82 @@ fn qoi_decode_impl_slice_all( } #[inline] -pub fn qoi_decode_to_buf( - mut out: impl AsMut<[u8]>, data: impl AsRef<[u8]>, channels: impl MaybeChannels, -) -> Result
{ - let (out, data) = (out.as_mut(), data.as_ref()); - let header = Header::decode(data)?; - let channels = channels.maybe_channels().unwrap_or(header.channels); - let size = header.n_pixels() * channels as usize; - if unlikely(out.len() < size) { - return Err(Error::OutputBufferTooSmall { size: out.len(), required: size }); - } - let data = &data[QOI_HEADER_SIZE..]; // can't panic - qoi_decode_impl_slice_all(data, out, header.channels, channels).map(|_| header) +pub fn qoi_decode_to_buf(buf: impl AsMut<[u8]>, data: impl AsRef<[u8]>) -> Result
{ + let mut decoder = QoiDecoder::new(&data)?; + decoder.decode_to_buf(buf)?; + Ok(*decoder.header()) } #[inline] -pub fn qoi_decode_to_vec( - data: impl AsRef<[u8]>, channels: impl MaybeChannels, -) -> Result<(Header, Vec)> { - let data = data.as_ref(); - let header = Header::decode(data)?; - let channels = channels.maybe_channels().unwrap_or(header.channels); - let mut out = vec![0; header.n_pixels() * channels as usize]; - let data = &data[QOI_HEADER_SIZE..]; // can't panic - qoi_decode_impl_slice_all(data, &mut out, header.channels, channels).map(|_| (header, out)) +pub fn qoi_decode_to_vec(data: impl AsRef<[u8]>) -> Result<(Header, Vec)> { + let mut decoder = QoiDecoder::new(&data)?; + let out = decoder.decode_to_vec()?; + Ok((*decoder.header(), out)) } #[inline] pub fn qoi_decode_header(data: impl AsRef<[u8]>) -> Result
{ Header::decode(data) } + +#[derive(Clone)] +pub struct QoiDecoder<'a> { + data: &'a [u8], + header: Header, + channels: u8, +} + +impl<'a> QoiDecoder<'a> { + #[inline] + pub fn new(data: &'a (impl AsRef<[u8]> + ?Sized)) -> Result { + let data = data.as_ref(); + let header = Header::decode(data)?; + let data = &data[QOI_HEADER_SIZE..]; // can't panic + Ok(Self { data, header, channels: header.channels }) + } + + #[inline] + pub fn with_channels(mut self, channels: u8) -> Self { + self.channels = channels; + self + } + + #[inline] + pub fn channels(&self) -> u8 { + self.channels + } + + #[inline] + pub fn header(&self) -> &Header { + &self.header + } + + #[inline] + pub fn data(self) -> &'a [u8] { + self.data + } + + #[inline] + pub fn decode_to_buf(&mut self, mut buf: impl AsMut<[u8]>) -> Result<()> { + let buf = buf.as_mut(); + let size = self.header.n_pixels() * self.channels as usize; + if unlikely(buf.len() < size) { + return Err(Error::OutputBufferTooSmall { size: buf.len(), required: size }); + } + let n_read = + qoi_decode_impl_slice_all(self.data, buf, self.channels, self.header.channels)?; + self.data = &self.data[n_read..]; // can't panic + Ok(()) + } + + #[inline] + pub fn decode_to_vec(&mut self) -> Result> { + if unlikely(self.channels > 4) { + // prevent accidental over-allocations + cold(); + return Err(Error::InvalidChannels { channels: self.channels }); + } + let mut out = vec![0; self.header.n_pixels() * self.channels as usize]; + self.decode_to_buf(&mut out).map(|_| out) + } +} diff --git a/src/error.rs b/src/error.rs index 0f710e8..80ae6d9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,6 +16,7 @@ pub enum Error { InvalidMagic { magic: u32 }, UnexpectedBufferEnd, InvalidColorSpace { colorspace: u8 }, + InvalidPadding, } pub type Result = StdResult; @@ -51,6 +52,9 @@ impl Display for Error { Self::InvalidColorSpace { colorspace } => { write!(f, "invalid color space: {} (expected 0 or 1)", colorspace) } + Self::InvalidPadding => { + write!(f, "invalid padding (stream end marker)") + } } } } diff --git a/tests/test_ref.rs b/tests/test_ref.rs index c281684..a908dff 100644 --- a/tests/test_ref.rs +++ b/tests/test_ref.rs @@ -98,7 +98,7 @@ fn test_reference_images() -> Result<()> { let encoded = qoi_encode_to_vec(&img.data, img.width, img.height, img.channels, 0)?; let expected = fs::read(qoi_path)?; compare_slices(&png_name, "encoding", &encoded, &expected)?; - let (_header, decoded) = qoi_decode_to_vec(&expected, img.channels)?; + let (_header, decoded) = qoi_decode_to_vec(&expected)?; compare_slices(&png_name, "decoding", &decoded, &img.data)?; }